Merge pull request #1 from AnishSarkar22/feat/test-ci

testing backend test CI workflow
This commit is contained in:
Anish Sarkar 2026-03-08 02:57:13 +05:30 committed by GitHub
commit f53759f0e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
220 changed files with 16886 additions and 6950 deletions

112
.cursor/skills/tdd/SKILL.md Normal file
View file

@ -0,0 +1,112 @@
---
name: tdd
description: Strict Python TDD workflow using pytest (Red-Green-Refactor).
---
---
name: tdd
description: Test-driven development with red-green-refactor loop. Use when user wants to build features or fix bugs using TDD, mentions "red-green-refactor", wants integration tests, or asks for test-first development.
---
# Test-Driven Development
## Philosophy
**Core principle**: Tests should verify behavior through public interfaces, not implementation details. Code can change entirely; tests shouldn't.
**Good tests** are integration-style: they exercise real code paths through public APIs. They describe _what_ the system does, not _how_ it does it. A good test reads like a specification - "user can checkout with valid cart" tells you exactly what capability exists. These tests survive refactors because they don't care about internal structure.
**Bad tests** are coupled to implementation. They mock internal collaborators, test private methods, or verify through external means (like querying a database directly instead of using the interface). The warning sign: your test breaks when you refactor, but behavior hasn't changed. If you rename an internal function and tests fail, those tests were testing implementation, not behavior.
See [tests.md](tests.md) for examples and [mocking.md](mocking.md) for mocking guidelines.
## Anti-Pattern: Horizontal Slices
**DO NOT write all tests first, then all implementation.** This is "horizontal slicing" - treating RED as "write all tests" and GREEN as "write all code."
This produces **crap tests**:
- Tests written in bulk test _imagined_ behavior, not _actual_ behavior
- You end up testing the _shape_ of things (data structures, function signatures) rather than user-facing behavior
- Tests become insensitive to real changes - they pass when behavior breaks, fail when behavior is fine
- You outrun your headlights, committing to test structure before understanding the implementation
**Correct approach**: Vertical slices via tracer bullets. One test → one implementation → repeat. Each test responds to what you learned from the previous cycle. Because you just wrote the code, you know exactly what behavior matters and how to verify it.
```
WRONG (horizontal):
RED: test1, test2, test3, test4, test5
GREEN: impl1, impl2, impl3, impl4, impl5
RIGHT (vertical):
RED→GREEN: test1→impl1
RED→GREEN: test2→impl2
RED→GREEN: test3→impl3
...
```
## Workflow
### 1. Planning
Before writing any code:
- [ ] Confirm with user what interface changes are needed
- [ ] Confirm with user which behaviors to test (prioritize)
- [ ] Identify opportunities for [deep modules](deep-modules.md) (small interface, deep implementation)
- [ ] Design interfaces for [testability](interface-design.md)
- [ ] List the behaviors to test (not implementation steps)
- [ ] Get user approval on the plan
Ask: "What should the public interface look like? Which behaviors are most important to test?"
**You can't test everything.** Confirm with the user exactly which behaviors matter most. Focus testing effort on critical paths and complex logic, not every possible edge case.
### 2. Tracer Bullet
Write ONE test that confirms ONE thing about the system:
```
RED: Write test for first behavior → test fails
GREEN: Write minimal code to pass → test passes
```
This is your tracer bullet - proves the path works end-to-end.
### 3. Incremental Loop
For each remaining behavior:
```
RED: Write next test → fails
GREEN: Minimal code to pass → passes
```
Rules:
- One test at a time
- Only enough code to pass current test
- Don't anticipate future tests
- Keep tests focused on observable behavior
### 4. Refactor
After all tests pass, look for [refactor candidates](refactoring.md):
- [ ] Extract duplication
- [ ] Deepen modules (move complexity behind simple interfaces)
- [ ] Apply SOLID principles where natural
- [ ] Consider what new code reveals about existing code
- [ ] Run tests after each refactor step
**Never refactor while RED.** Get to GREEN first.
## Checklist Per Cycle
```
[ ] Test describes behavior, not implementation
[ ] Test uses public interface only
[ ] Test would survive internal refactor
[ ] Code is minimal for this test
[ ] No speculative features added
```

View file

@ -0,0 +1,33 @@
# Deep Modules
From "A Philosophy of Software Design":
**Deep module** = small interface + lots of implementation
```
┌─────────────────────┐
│ Small Interface │ ← Few methods, simple params
├─────────────────────┤
│ │
│ │
│ Deep Implementation│ ← Complex logic hidden
│ │
│ │
└─────────────────────┘
```
**Shallow module** = large interface + little implementation (avoid)
```
┌─────────────────────────────────┐
│ Large Interface │ ← Many methods, complex params
├─────────────────────────────────┤
│ Thin Implementation │ ← Just passes through
└─────────────────────────────────┘
```
When designing interfaces, ask:
- Can I reduce the number of methods?
- Can I simplify the parameters?
- Can I hide more complexity inside?

View file

@ -0,0 +1,33 @@
# Interface Design for Testability
Good interfaces make testing natural:
1. **Accept dependencies, don't create them**
```python
# Testable
def process_order(order, payment_gateway):
pass
# Hard to test
def process_order(order):
gateway = StripeGateway()
```
2. **Return results, don't produce side effects**
```python
# Testable
def calculate_discount(cart) -> float:
return discount
# Hard to test
def apply_discount(cart) -> None:
cart.total -= discount
```
3. **Small surface area**
* Fewer methods = fewer tests needed
* Fewer params = simpler test setup

View file

@ -0,0 +1,69 @@
# When to Mock
Mock at **system boundaries** only:
* External APIs (payment, email, etc.)
* Databases (sometimes - prefer test DB)
* Time/randomness
* File system (sometimes)
Don't mock:
* Your own classes/modules
* Internal collaborators
* Anything you control
## Designing for Mockability
At system boundaries, design interfaces that are easy to mock:
**1. Use dependency injection**
Pass external dependencies in rather than creating them internally:
```python
import os
# Easy to mock
def process_payment(order, payment_client):
return payment_client.charge(order.total)
# Hard to mock
def process_payment(order):
client = StripeClient(os.getenv("STRIPE_KEY"))
return client.charge(order.total)
```
**2. Prefer SDK-style interfaces over generic fetchers**
Create specific functions for each external operation instead of one generic function with conditional logic:
```python
import requests
# GOOD: Each function is independently mockable
class UserAPI:
def get_user(self, user_id):
return requests.get(f"/users/{user_id}")
def get_orders(self, user_id):
return requests.get(f"/users/{user_id}/orders")
def create_order(self, data):
return requests.post("/orders", json=data)
# BAD: Mocking requires conditional logic inside the mock
class GenericAPI:
def fetch(self, endpoint, method="GET", data=None):
return requests.request(method, endpoint, json=data)
```
The SDK approach means:
* Each mock returns one specific shape
* No conditional logic in test setup
* Easier to see which endpoints a test exercises
* Type safety per endpoint

View file

@ -0,0 +1,10 @@
# Refactor Candidates
After TDD cycle, look for:
- **Duplication** → Extract function/class
- **Long methods** → Break into private helpers (keep tests on public interface)
- **Shallow modules** → Combine or deepen
- **Feature envy** → Move logic to where data lives
- **Primitive obsession** → Introduce value objects
- **Existing code** the new code reveals as problematic

View file

@ -0,0 +1,60 @@
# Good and Bad Tests
## Good Tests
**Integration-style**: Test through real interfaces, not mocks of internal parts.
```python
# GOOD: Tests observable behavior
def test_user_can_checkout_with_valid_cart():
cart = create_cart()
cart.add(product)
result = checkout(cart, payment_method)
assert result.status == "confirmed"
```
Characteristics:
* Tests behavior users/callers care about
* Uses public API only
* Survives internal refactors
* Describes WHAT, not HOW
* One logical assertion per test
## Bad Tests
**Implementation-detail tests**: Coupled to internal structure.
```python
# BAD: Tests implementation details
def test_checkout_calls_payment_service_process():
mock_payment = MagicMock()
checkout(cart, mock_payment)
mock_payment.process.assert_called_with(cart.total)
```
Red flags:
* Mocking internal collaborators
* Testing private methods
* Asserting on call counts/order
* Test breaks when refactoring without behavior change
* Test name describes HOW not WHAT
* Verifying through external means instead of interface
```python
# BAD: Bypasses interface to verify
def test_create_user_saves_to_database():
create_user({"name": "Alice"})
row = db.query("SELECT * FROM users WHERE name = ?", ["Alice"])
assert row is not None
# GOOD: Verifies through interface
def test_create_user_makes_user_retrievable():
user = create_user({"name": "Alice"})
retrieved = get_user(user.id)
assert retrieved.name == "Alice"
```

View file

@ -1,41 +0,0 @@
# Docker Specific Env's Only - Can skip if needed
# Celery Config
REDIS_PORT=6379
FLOWER_PORT=5555
# Frontend Configuration
FRONTEND_PORT=3000
NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000 (Default: http://localhost:8000)
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL or GOOGLE (Default: LOCAL)
NEXT_PUBLIC_ETL_SERVICE=UNSTRUCTURED or LLAMACLOUD or DOCLING (Default: DOCLING)
# Backend Configuration
BACKEND_PORT=8000
# Auth type for backend login flow (Default: LOCAL)
# Set to GOOGLE if using Google OAuth
AUTH_TYPE=LOCAL
# Frontend URL used by backend for CORS allowed origins and OAuth redirects
# Must match the URL your browser uses to access the frontend
NEXT_FRONTEND_URL=http://localhost:3000
# Database Configuration
POSTGRES_USER=postgres
POSTGRES_PASSWORD=postgres
POSTGRES_DB=surfsense
POSTGRES_PORT=5432
# Electric-SQL Configuration
ELECTRIC_PORT=5133
# PostgreSQL host for Electric connection
# - 'db' for Docker PostgreSQL (service name in docker-compose)
# - 'host.docker.internal' for local PostgreSQL (recommended when Electric runs in Docker)
# Note: host.docker.internal works on Docker Desktop (Mac/Windows) and can be enabled on Linux
POSTGRES_HOST=db
ELECTRIC_DB_USER=electric
ELECTRIC_DB_PASSWORD=electric_password
NEXT_PUBLIC_ELECTRIC_URL=http://localhost:5133
# pgAdmin Configuration
PGADMIN_PORT=5050
PGADMIN_DEFAULT_EMAIL=admin@surfsense.com
PGADMIN_DEFAULT_PASSWORD=surfsense

161
.github/workflows/backend-tests.yml vendored Normal file
View file

@ -0,0 +1,161 @@
name: Backend Tests
on:
pull_request:
branches: [main, dev]
types: [opened, synchronize, reopened, ready_for_review]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
unit-tests:
name: Unit Tests
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
env:
EMBEDDING_MODEL: sentence-transformers/all-MiniLM-L6-v2
steps:
- name: Checkout code
uses: actions/checkout@v6
- name: Check if backend files changed
id: backend-changes
uses: dorny/paths-filter@v3
with:
filters: |
backend:
- 'surfsense_backend/**'
- name: Set up Python
if: steps.backend-changes.outputs.backend == 'true'
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install UV
if: steps.backend-changes.outputs.backend == 'true'
uses: astral-sh/setup-uv@v7
- name: Cache dependencies
if: steps.backend-changes.outputs.backend == 'true'
uses: actions/cache@v5
with:
path: |
~/.cache/uv
surfsense_backend/.venv
key: python-deps-${{ hashFiles('surfsense_backend/uv.lock') }}
restore-keys: |
python-deps-
- name: Cache HuggingFace models
if: steps.backend-changes.outputs.backend == 'true'
uses: actions/cache@v5
with:
path: ~/.cache/huggingface
key: hf-models-${{ env.EMBEDDING_MODEL }}
- name: Install dependencies
if: steps.backend-changes.outputs.backend == 'true'
working-directory: surfsense_backend
run: uv sync
- name: Run unit tests
if: steps.backend-changes.outputs.backend == 'true'
working-directory: surfsense_backend
run: uv run pytest -m unit
integration-tests:
name: Integration Tests
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
env:
EMBEDDING_MODEL: sentence-transformers/all-MiniLM-L6-v2
services:
postgres:
image: pgvector/pgvector:pg17
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: surfsense_test
ports:
- 5432:5432
options: >-
--health-cmd "pg_isready -U postgres -d surfsense_test"
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- name: Checkout code
uses: actions/checkout@v6
- name: Check if backend files changed
id: backend-changes
uses: dorny/paths-filter@v3
with:
filters: |
backend:
- 'surfsense_backend/**'
- name: Set up Python
if: steps.backend-changes.outputs.backend == 'true'
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install UV
if: steps.backend-changes.outputs.backend == 'true'
uses: astral-sh/setup-uv@v7
- name: Cache dependencies
if: steps.backend-changes.outputs.backend == 'true'
uses: actions/cache@v5
with:
path: |
~/.cache/uv
surfsense_backend/.venv
key: python-deps-${{ hashFiles('surfsense_backend/uv.lock') }}
restore-keys: |
python-deps-
- name: Cache HuggingFace models
if: steps.backend-changes.outputs.backend == 'true'
uses: actions/cache@v5
with:
path: ~/.cache/huggingface
key: hf-models-${{ env.EMBEDDING_MODEL }}
- name: Install dependencies
if: steps.backend-changes.outputs.backend == 'true'
working-directory: surfsense_backend
run: uv sync
- name: Run integration tests
if: steps.backend-changes.outputs.backend == 'true'
working-directory: surfsense_backend
env:
TEST_DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test
SECRET_KEY: ci-test-secret-key-not-for-production
ETL_SERVICE: DOCLING
run: uv run pytest -m integration
test-gate:
name: Test Gate
runs-on: ubuntu-latest
needs: [unit-tests, integration-tests]
if: always()
steps:
- name: Check all test jobs
run: |
if [[ "${{ needs.unit-tests.result }}" == "failure" ||
"${{ needs.integration-tests.result }}" == "failure" ]]; then
echo "Backend tests failed"
exit 1
else
echo "All backend tests passed"
fi

View file

@ -1,6 +1,13 @@
name: Build and Push Docker Image
name: Build and Push Docker Images
on:
push:
branches:
- main
- dev
paths:
- 'surfsense_backend/**'
- 'surfsense_web/**'
workflow_dispatch:
inputs:
branch:
@ -8,6 +15,10 @@ on:
required: false
default: ''
concurrency:
group: docker-build
cancel-in-progress: false
permissions:
contents: write
packages: write
@ -28,33 +39,28 @@ jobs:
- name: Read app version and calculate next Docker build version
id: tag_version
run: |
# Read version from pyproject.toml
APP_VERSION=$(grep -E '^version = ' surfsense_backend/pyproject.toml | sed 's/version = "\(.*\)"/\1/')
echo "App version from pyproject.toml: $APP_VERSION"
if [ -z "$APP_VERSION" ]; then
echo "Error: Could not read version from surfsense_backend/pyproject.toml"
exit 1
fi
# Fetch all tags
git fetch --tags
# Find the latest docker build tag for this app version (format: APP_VERSION.BUILD_NUMBER)
# Tags follow pattern: 0.0.11.1, 0.0.11.2, etc.
LATEST_BUILD_TAG=$(git tag --list "${APP_VERSION}.*" --sort='-v:refname' | head -n 1)
if [ -z "$LATEST_BUILD_TAG" ]; then
echo "No previous Docker build tag found for version ${APP_VERSION}. Starting with ${APP_VERSION}.1"
NEXT_VERSION="${APP_VERSION}.1"
else
echo "Latest Docker build tag found: $LATEST_BUILD_TAG"
# Extract the build number (4th component)
BUILD_NUMBER=$(echo "$LATEST_BUILD_TAG" | rev | cut -d. -f1 | rev)
NEXT_BUILD=$((BUILD_NUMBER + 1))
NEXT_VERSION="${APP_VERSION}.${NEXT_BUILD}"
fi
echo "Calculated next Docker version: $NEXT_VERSION"
echo "next_version=$NEXT_VERSION" >> $GITHUB_OUTPUT
@ -78,67 +84,35 @@ jobs:
git ls-remote --tags origin | grep "refs/tags/${{ steps.tag_version.outputs.next_version }}" || (echo "Tag push verification failed!" && exit 1)
echo "Tag successfully pushed."
# Build for AMD64 on native x64 runner
build_amd64:
runs-on: ubuntu-latest
build:
needs: tag_release
runs-on: ${{ matrix.os }}
permissions:
packages: write
contents: read
outputs:
digest: ${{ steps.build.outputs.digest }}
strategy:
fail-fast: false
matrix:
platform: [linux/amd64, linux/arm64]
image: [backend, web]
include:
- platform: linux/amd64
suffix: amd64
os: ubuntu-latest
- platform: linux/arm64
suffix: arm64
os: ubuntu-24.04-arm
- image: backend
name: surfsense-backend
context: ./surfsense_backend
file: ./surfsense_backend/Dockerfile
- image: web
name: surfsense-web
context: ./surfsense_web
file: ./surfsense_web/Dockerfile
env:
REGISTRY_IMAGE: ghcr.io/${{ github.repository_owner }}/surfsense
steps:
- name: Checkout code
uses: actions/checkout@v4
REGISTRY_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ matrix.name }}
- name: Set lowercase image name
id: image
run: echo "name=${REGISTRY_IMAGE,,}" >> $GITHUB_OUTPUT
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Free up disk space
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf /opt/ghc
sudo rm -rf /usr/local/share/boost
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
docker system prune -af
- name: Build and push AMD64 image
id: build
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile.allinone
push: true
tags: ${{ steps.image.outputs.name }}:${{ needs.tag_release.outputs.new_tag }}-amd64
platforms: linux/amd64
cache-from: type=gha,scope=amd64
cache-to: type=gha,mode=max,scope=amd64
provenance: false
# Build for ARM64 on native arm64 runner (no QEMU emulation!)
build_arm64:
runs-on: ubuntu-24.04-arm
needs: tag_release
permissions:
packages: write
contents: read
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: ghcr.io/${{ github.repository_owner }}/surfsense
steps:
- name: Checkout code
uses: actions/checkout@v4
@ -165,28 +139,41 @@ jobs:
sudo rm -rf "$AGENT_TOOLSDIRECTORY" || true
docker system prune -af
- name: Build and push ARM64 image
- name: Build and push ${{ matrix.name }} (${{ matrix.suffix }})
id: build
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
with:
context: .
file: ./Dockerfile.allinone
context: ${{ matrix.context }}
file: ${{ matrix.file }}
push: true
tags: ${{ steps.image.outputs.name }}:${{ needs.tag_release.outputs.new_tag }}-arm64
platforms: linux/arm64
cache-from: type=gha,scope=arm64
cache-to: type=gha,mode=max,scope=arm64
tags: ${{ steps.image.outputs.name }}:${{ needs.tag_release.outputs.new_tag }}-${{ matrix.suffix }}
platforms: ${{ matrix.platform }}
cache-from: type=gha,scope=${{ matrix.image }}-${{ matrix.suffix }}
cache-to: type=gha,mode=max,scope=${{ matrix.image }}-${{ matrix.suffix }}
provenance: false
build-args: |
${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__' || '' }}
${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__' || '' }}
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__' || '' }}
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ELECTRIC_URL=__NEXT_PUBLIC_ELECTRIC_URL__' || '' }}
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ELECTRIC_AUTH_MODE=__NEXT_PUBLIC_ELECTRIC_AUTH_MODE__' || '' }}
${{ matrix.image == 'web' && 'NEXT_PUBLIC_DEPLOYMENT_MODE=__NEXT_PUBLIC_DEPLOYMENT_MODE__' || '' }}
# Create multi-arch manifest combining both platform images
create_manifest:
runs-on: ubuntu-latest
needs: [tag_release, build_amd64, build_arm64]
needs: [tag_release, build]
permissions:
packages: write
contents: read
strategy:
fail-fast: false
matrix:
include:
- name: surfsense-backend
- name: surfsense-web
env:
REGISTRY_IMAGE: ghcr.io/${{ github.repository_owner }}/surfsense
REGISTRY_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ matrix.name }}
steps:
- name: Set lowercase image name
id: image
@ -203,28 +190,31 @@ jobs:
run: |
VERSION_TAG="${{ needs.tag_release.outputs.new_tag }}"
IMAGE="${{ steps.image.outputs.name }}"
# Create manifest for version tag
APP_VERSION=$(echo "$VERSION_TAG" | rev | cut -d. -f2- | rev)
docker manifest create ${IMAGE}:${VERSION_TAG} \
${IMAGE}:${VERSION_TAG}-amd64 \
${IMAGE}:${VERSION_TAG}-arm64
docker manifest push ${IMAGE}:${VERSION_TAG}
# Create/update latest tag if on default branch
if [[ "${{ github.ref }}" == "refs/heads/${{ github.event.repository.default_branch }}" ]] || [[ "${{ github.event.inputs.branch }}" == "${{ github.event.repository.default_branch }}" ]]; then
docker manifest create ${IMAGE}:${APP_VERSION} \
${IMAGE}:${VERSION_TAG}-amd64 \
${IMAGE}:${VERSION_TAG}-arm64
docker manifest push ${IMAGE}:${APP_VERSION}
docker manifest create ${IMAGE}:latest \
${IMAGE}:${VERSION_TAG}-amd64 \
${IMAGE}:${VERSION_TAG}-arm64
docker manifest push ${IMAGE}:latest
fi
- name: Clean up architecture-specific tags (optional)
continue-on-error: true
- name: Summary
run: |
# Note: GHCR doesn't support tag deletion via API easily
# The arch-specific tags will remain but users should use the main tags
echo "Multi-arch manifest created successfully!"
echo "Users should pull: ${{ steps.image.outputs.name }}:${{ needs.tag_release.outputs.new_tag }}"
echo "Or for latest: ${{ steps.image.outputs.name }}:latest"
echo "Multi-arch manifest created for ${{ matrix.name }}!"
echo "Versioned: ${{ steps.image.outputs.name }}:${{ needs.tag_release.outputs.new_tag }}"
echo "App version: ${{ steps.image.outputs.name }}:$(echo '${{ needs.tag_release.outputs.new_tag }}' | rev | cut -d. -f2- | rev)"
echo "Latest: ${{ steps.image.outputs.name }}:latest"

1
.gitignore vendored
View file

@ -5,3 +5,4 @@ node_modules/
.ruff_cache/
.venv
.pnpm-store
.DS_Store

View file

@ -1,285 +0,0 @@
# SurfSense All-in-One Docker Image
# This image bundles PostgreSQL+pgvector, Redis, Electric SQL, Backend, and Frontend
# Usage: docker run -d -p 3000:3000 -p 8000:8000 -p 5133:5133 -v surfsense-data:/data --name surfsense ghcr.io/modsetter/surfsense:latest
#
# Included Services (all run locally by default):
# - PostgreSQL 14 + pgvector (vector database)
# - Redis (task queue)
# - Electric SQL (real-time sync)
# - Docling (document processing, CPU-only, OCR disabled)
# - Kokoro TTS (local text-to-speech for podcasts)
# - Faster-Whisper (local speech-to-text for audio files)
# - Playwright Chromium (web scraping)
#
# Note: This is the CPU-only version. A :cuda tagged image with GPU support
# will be available in the future for faster AI inference.
# ====================
# Stage 1: Get Electric SQL Binary
# ====================
FROM electricsql/electric:latest AS electric-builder
# ====================
# Stage 2: Build Frontend
# ====================
FROM node:20-alpine AS frontend-builder
WORKDIR /app
# Install pnpm
RUN corepack enable pnpm
# Copy package files
COPY surfsense_web/package.json surfsense_web/pnpm-lock.yaml* ./
COPY surfsense_web/source.config.ts ./
COPY surfsense_web/content ./content
# Install dependencies (skip postinstall which requires all source files)
RUN pnpm install --frozen-lockfile --ignore-scripts
# Copy source
COPY surfsense_web/ ./
# Run fumadocs-mdx postinstall now that source files are available
RUN pnpm fumadocs-mdx
# Build with placeholder values that will be replaced at runtime
# These unique strings allow runtime substitution via entrypoint script
ENV NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__
ENV NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__
ENV NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__
ENV NEXT_PUBLIC_ELECTRIC_URL=__NEXT_PUBLIC_ELECTRIC_URL__
ENV NEXT_PUBLIC_ELECTRIC_AUTH_MODE=__NEXT_PUBLIC_ELECTRIC_AUTH_MODE__
ENV NEXT_PUBLIC_DEPLOYMENT_MODE=__NEXT_PUBLIC_DEPLOYMENT_MODE__
# Build
RUN pnpm run build
# ====================
# Stage 3: Runtime Image
# ====================
FROM ubuntu:22.04 AS runtime
# Prevent interactive prompts
ENV DEBIAN_FRONTEND=noninteractive
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
# PostgreSQL
postgresql-14 \
postgresql-contrib-14 \
# Build tools for pgvector
build-essential \
postgresql-server-dev-14 \
git \
# Redis
redis-server \
# Node.js prerequisites
curl \
ca-certificates \
gnupg \
# Backend dependencies
gcc \
wget \
unzip \
dos2unix \
# For PPAs
software-properties-common \
# ============================
# Local TTS (Kokoro) dependencies
# ============================
espeak-ng \
libespeak-ng1 \
# ============================
# Local STT (Faster-Whisper) dependencies
# ============================
ffmpeg \
# ============================
# Audio processing (soundfile)
# ============================
libsndfile1 \
# ============================
# Image/OpenCV dependencies (for Docling)
# ============================
libgl1 \
libglib2.0-0 \
libsm6 \
libxext6 \
libxrender1 \
# ============================
# Playwright browser dependencies
# ============================
libnspr4 \
libnss3 \
libatk1.0-0 \
libatk-bridge2.0-0 \
libcups2 \
libxkbcommon0 \
libatspi2.0-0 \
libxcomposite1 \
libxdamage1 \
libxrandr2 \
libgbm1 \
libcairo2 \
libpango-1.0-0 \
&& rm -rf /var/lib/apt/lists/*
# Install Pandoc 3.x from GitHub (apt ships 2.9 which has broken table rendering).
RUN ARCH=$(dpkg --print-architecture) && \
wget -qO /tmp/pandoc.deb "https://github.com/jgm/pandoc/releases/download/3.9/pandoc-3.9-1-${ARCH}.deb" && \
dpkg -i /tmp/pandoc.deb && \
rm /tmp/pandoc.deb
# Install Node.js 20.x (for running frontend)
RUN curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
&& apt-get install -y nodejs \
&& rm -rf /var/lib/apt/lists/*
# Install Python 3.12 from deadsnakes PPA
RUN add-apt-repository ppa:deadsnakes/ppa -y \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
python3.12 \
python3.12-venv \
python3.12-dev \
&& rm -rf /var/lib/apt/lists/*
# Set Python 3.12 as default
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1 \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1
# Install pip for Python 3.12
RUN python3.12 -m ensurepip --upgrade \
&& python3.12 -m pip install --upgrade pip
# Install supervisor via pip (system package incompatible with Python 3.12)
RUN pip install --no-cache-dir supervisor
# Build and install pgvector
RUN cd /tmp \
&& git clone --branch v0.7.4 https://github.com/pgvector/pgvector.git \
&& cd pgvector \
&& make \
&& make install \
&& rm -rf /tmp/pgvector
# Update certificates
RUN update-ca-certificates
# Create data directories
RUN mkdir -p /data/postgres /data/redis /data/surfsense \
&& chown -R postgres:postgres /data/postgres
# ====================
# Copy Frontend Build
# ====================
WORKDIR /app/frontend
# Copy only the standalone build (not node_modules)
COPY --from=frontend-builder /app/.next/standalone ./
COPY --from=frontend-builder /app/.next/static ./.next/static
COPY --from=frontend-builder /app/public ./public
COPY surfsense_web/content/docs /app/surfsense_web/content/docs
# ====================
# Copy Electric SQL Release
# ====================
COPY --from=electric-builder /app /app/electric-release
# ====================
# Setup Backend
# ====================
WORKDIR /app/backend
# Copy backend dependency files
COPY surfsense_backend/pyproject.toml surfsense_backend/uv.lock ./
# Install PyTorch CPU-only (Docling needs it but OCR is disabled, no GPU needed)
RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
# Install python dependencies
RUN pip install --no-cache-dir certifi pip-system-certs uv \
&& uv pip install --system --no-cache-dir -e .
# Set SSL environment variables
RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") \
&& echo "export SSL_CERT_FILE=$CERTIFI_PATH" >> /etc/profile.d/ssl.sh \
&& echo "export REQUESTS_CA_BUNDLE=$CERTIFI_PATH" >> /etc/profile.d/ssl.sh
# Note: EasyOCR models NOT downloaded - OCR is disabled in docling_service.py
# GPU support will be added in a future :cuda tagged image
# Install Playwright browsers
RUN pip install --no-cache-dir playwright \
&& playwright install chromium \
&& rm -rf /root/.cache/ms-playwright/ffmpeg*
# Copy backend source
COPY surfsense_backend/ ./
# ====================
# Configuration
# ====================
WORKDIR /app
# Copy supervisor configuration
COPY scripts/docker/supervisor-allinone.conf /etc/supervisor/conf.d/surfsense.conf
# Copy entrypoint script
COPY scripts/docker/entrypoint-allinone.sh /app/entrypoint.sh
RUN dos2unix /app/entrypoint.sh && chmod +x /app/entrypoint.sh
# PostgreSQL initialization script
COPY scripts/docker/init-postgres.sh /app/init-postgres.sh
RUN dos2unix /app/init-postgres.sh && chmod +x /app/init-postgres.sh
# Clean up build dependencies to reduce image size
RUN apt-get purge -y build-essential postgresql-server-dev-14 \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Environment variables with defaults
ENV POSTGRES_USER=surfsense
ENV POSTGRES_PASSWORD=surfsense
ENV POSTGRES_DB=surfsense
ENV DATABASE_URL=postgresql+asyncpg://surfsense:surfsense@localhost:5432/surfsense
ENV CELERY_BROKER_URL=redis://localhost:6379/0
ENV CELERY_RESULT_BACKEND=redis://localhost:6379/0
ENV CELERY_TASK_DEFAULT_QUEUE=surfsense
ENV PYTHONPATH=/app/backend
ENV NEXT_FRONTEND_URL=http://localhost:3000
ENV AUTH_TYPE=LOCAL
ENV ETL_SERVICE=DOCLING
ENV EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
# Frontend configuration (can be overridden at runtime)
# These are injected into the Next.js build at container startup
ENV NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000
ENV NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL
ENV NEXT_PUBLIC_ETL_SERVICE=DOCLING
# Electric SQL configuration (ELECTRIC_DATABASE_URL is built dynamically by entrypoint from these values)
ENV ELECTRIC_DB_USER=electric
ENV ELECTRIC_DB_PASSWORD=electric_password
# Note: ELECTRIC_DATABASE_URL is NOT set here - entrypoint builds it dynamically from ELECTRIC_DB_USER/PASSWORD
ENV ELECTRIC_INSECURE=true
ENV ELECTRIC_WRITE_TO_PG_MODE=direct
ENV ELECTRIC_PORT=5133
ENV PORT=5133
ENV NEXT_PUBLIC_ELECTRIC_URL=http://localhost:5133
ENV NEXT_PUBLIC_ELECTRIC_AUTH_MODE=insecure
# Data volume
VOLUME ["/data"]
# Expose ports (Frontend: 3000, Backend: 8000, Electric: 5133)
EXPOSE 3000 8000 5133
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=120s --retries=3 \
CMD curl -f http://localhost:3000 || exit 1
# Run entrypoint
CMD ["/app/entrypoint.sh"]

View file

@ -81,13 +81,21 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
Ejecuta SurfSense en tu propia infraestructura para control total de datos y privacidad.
**Inicio Rápido (Docker en un solo comando):**
**Requisitos previos:** [Docker Desktop](https://www.docker.com/products/docker-desktop/) debe estar instalado y en ejecución.
#### Para usuarios de Linux/MacOS:
```bash
docker run -d -p 3000:3000 -p 8000:8000 -p 5133:5133 -v surfsense-data:/data --name surfsense --restart unless-stopped ghcr.io/modsetter/surfsense:latest
curl -fsSL https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.sh | bash
```
Después de iniciar, abre [http://localhost:3000](http://localhost:3000) en tu navegador.
#### Para usuarios de Windows:
```powershell
irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.ps1 | iex
```
El script de instalación configura [Watchtower](https://github.com/nicholas-fedor/watchtower) automáticamente para actualizaciones diarias. Para omitirlo, agrega la bandera `--no-watchtower`.
Para Docker Compose, instalación manual y otras opciones de despliegue, consulta la [documentación](https://www.surfsense.com/docs/).

View file

@ -81,13 +81,21 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
पूर्ण डेटा नियंत्रण और गोपनीयता के लिए SurfSense को अपने स्वयं के बुनियादी ढांचे पर चलाएं।
**त्वरित शुरुआत (Docker एक कमांड में):**
**आवश्यकताएँ:** [Docker Desktop](https://www.docker.com/products/docker-desktop/) इंस्टॉल और चालू होना चाहिए।
#### Linux/MacOS उपयोगकर्ताओं के लिए:
```bash
docker run -d -p 3000:3000 -p 8000:8000 -p 5133:5133 -v surfsense-data:/data --name surfsense --restart unless-stopped ghcr.io/modsetter/surfsense:latest
curl -fsSL https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.sh | bash
```
शुरू करने के बाद, अपने ब्राउज़र में [http://localhost:3000](http://localhost:3000) खोलें।
#### Windows उपयोगकर्ताओं के लिए:
```powershell
irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.ps1 | iex
```
इंस्टॉल स्क्रिप्ट दैनिक ऑटो-अपडेट के लिए स्वचालित रूप से [Watchtower](https://github.com/nicholas-fedor/watchtower) सेटअप करती है। इसे छोड़ने के लिए, `--no-watchtower` फ्लैग जोड़ें।
Docker Compose, मैनुअल इंस्टॉलेशन और अन्य डिप्लॉयमेंट विकल्पों के लिए, [डॉक्स](https://www.surfsense.com/docs/) देखें।

View file

@ -81,21 +81,23 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
Run SurfSense on your own infrastructure for full data control and privacy.
**Quick Start (Docker one-liner):**
**Prerequisites:** [Docker Desktop](https://www.docker.com/products/docker-desktop/) must be installed and running.
#### For Linux/MacOS users:
```bash
docker run -d -p 3000:3000 -p 8000:8000 -p 5133:5133 -v surfsense-data:/data --name surfsense --restart unless-stopped ghcr.io/modsetter/surfsense:latest
curl -fsSL https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.sh | bash
```
After starting, open [http://localhost:3000](http://localhost:3000) in your browser.
**Update (Automatic updates with Watchtower):**
#### For Windows users:
```bash
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock nickfedor/watchtower --run-once surfsense
irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.ps1 | iex
```
For Docker Compose, manual installation, and other deployment options, check the [docs](https://www.surfsense.com/docs/).
The install script sets up [Watchtower](https://github.com/nicholas-fedor/watchtower) automatically for daily auto-updates. To skip it, add the `--no-watchtower` flag.
For Docker Compose, manual installation, and other deployment options, see the [docs](https://www.surfsense.com/docs/).
### How to Realtime Collaborate (Beta)

View file

@ -81,13 +81,21 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
Execute o SurfSense na sua própria infraestrutura para controle total de dados e privacidade.
**Início Rápido (Docker em um único comando):**
**Pré-requisitos:** [Docker Desktop](https://www.docker.com/products/docker-desktop/) deve estar instalado e em execução.
#### Para usuários de Linux/MacOS:
```bash
docker run -d -p 3000:3000 -p 8000:8000 -p 5133:5133 -v surfsense-data:/data --name surfsense --restart unless-stopped ghcr.io/modsetter/surfsense:latest
curl -fsSL https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.sh | bash
```
Após iniciar, abra [http://localhost:3000](http://localhost:3000) no seu navegador.
#### Para usuários do Windows:
```powershell
irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.ps1 | iex
```
O script de instalação configura o [Watchtower](https://github.com/nicholas-fedor/watchtower) automaticamente para atualizações diárias. Para pular, adicione a flag `--no-watchtower`.
Para Docker Compose, instalação manual e outras opções de implantação, consulte a [documentação](https://www.surfsense.com/docs/).

View file

@ -81,13 +81,21 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
在您自己的基础设施上运行 SurfSense实现完全的数据控制和隐私保护。
**快速开始Docker 一行命令):**
**前置条件:** 需要安装并运行 [Docker Desktop](https://www.docker.com/products/docker-desktop/)。
#### Linux/MacOS 用户:
```bash
docker run -d -p 3000:3000 -p 8000:8000 -p 5133:5133 -v surfsense-data:/data --name surfsense --restart unless-stopped ghcr.io/modsetter/surfsense:latest
curl -fsSL https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.sh | bash
```
启动后,在浏览器中打开 [http://localhost:3000](http://localhost:3000)。
#### Windows 用户:
```powershell
irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.ps1 | iex
```
安装脚本会自动配置 [Watchtower](https://github.com/nicholas-fedor/watchtower) 以实现每日自动更新。如需跳过,请添加 `--no-watchtower` 参数。
如需 Docker Compose、手动安装及其他部署方式请查看[文档](https://www.surfsense.com/docs/)。

View file

@ -1,80 +0,0 @@
# SurfSense Quick Start Docker Compose
#
# This is a simplified docker-compose for quick local deployment using pre-built images.
# For production or customized deployments, use the main docker-compose.yml
#
# Usage:
# 1. (Optional) Create a .env file with your configuration
# 2. Run: docker compose -f docker-compose.quickstart.yml up -d
# 3. Access SurfSense at http://localhost:3000
#
# All Environment Variables are Optional:
# - SECRET_KEY: JWT secret key (auto-generated and persisted if not set)
# - EMBEDDING_MODEL: Embedding model to use (default: sentence-transformers/all-MiniLM-L6-v2)
# - ETL_SERVICE: Document parsing service - DOCLING, UNSTRUCTURED, or LLAMACLOUD (default: DOCLING)
# - TTS_SERVICE: Text-to-speech service for podcasts (default: local/kokoro)
# - STT_SERVICE: Speech-to-text service with model size (default: local/base)
# - FIRECRAWL_API_KEY: For web crawling features
version: "3.8"
services:
# All-in-one SurfSense container
surfsense:
image: ghcr.io/modsetter/surfsense:latest
container_name: surfsense
ports:
- "${FRONTEND_PORT:-3000}:3000"
- "${BACKEND_PORT:-8000}:8000"
volumes:
- surfsense-data:/data
environment:
# Authentication (auto-generated if not set)
- SECRET_KEY=${SECRET_KEY:-}
# Auth Configuration
- AUTH_TYPE=${AUTH_TYPE:-LOCAL}
- GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
- GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
# AI/ML Configuration
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
- RERANKERS_ENABLED=${RERANKERS_ENABLED:-FALSE}
- RERANKERS_MODEL_NAME=${RERANKERS_MODEL_NAME:-}
- RERANKERS_MODEL_TYPE=${RERANKERS_MODEL_TYPE:-}
# Document Processing
- ETL_SERVICE=${ETL_SERVICE:-DOCLING}
- UNSTRUCTURED_API_KEY=${UNSTRUCTURED_API_KEY:-}
- LLAMA_CLOUD_API_KEY=${LLAMA_CLOUD_API_KEY:-}
# Audio Services
- TTS_SERVICE=${TTS_SERVICE:-local/kokoro}
- TTS_SERVICE_API_KEY=${TTS_SERVICE_API_KEY:-}
- STT_SERVICE=${STT_SERVICE:-local/base}
- STT_SERVICE_API_KEY=${STT_SERVICE_API_KEY:-}
# Web Crawling
- FIRECRAWL_API_KEY=${FIRECRAWL_API_KEY:-}
# Optional Features
- REGISTRATION_ENABLED=${REGISTRATION_ENABLED:-TRUE}
- SCHEDULE_CHECKER_INTERVAL=${SCHEDULE_CHECKER_INTERVAL:-1m}
# LangSmith Observability (optional)
- LANGSMITH_TRACING=${LANGSMITH_TRACING:-false}
- LANGSMITH_ENDPOINT=${LANGSMITH_ENDPOINT:-}
- LANGSMITH_API_KEY=${LANGSMITH_API_KEY:-}
- LANGSMITH_PROJECT=${LANGSMITH_PROJECT:-}
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:3000", "&&", "curl", "-f", "http://localhost:8000/docs"]
interval: 30s
timeout: 10s
retries: 3
start_period: 120s
volumes:
surfsense-data:
name: surfsense-data

View file

@ -1,167 +0,0 @@
version: "3.8"
services:
db:
image: ankane/pgvector:latest
ports:
- "${POSTGRES_PORT:-5432}:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
- ./scripts/docker/postgresql.conf:/etc/postgresql/postgresql.conf:ro
- ./scripts/docker/init-electric-user.sh:/docker-entrypoint-initdb.d/init-electric-user.sh:ro
environment:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-postgres}
- POSTGRES_DB=${POSTGRES_DB:-surfsense}
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
command: postgres -c config_file=/etc/postgresql/postgresql.conf
pgadmin:
image: dpage/pgadmin4
ports:
- "${PGADMIN_PORT:-5050}:80"
environment:
- PGADMIN_DEFAULT_EMAIL=${PGADMIN_DEFAULT_EMAIL:-admin@surfsense.com}
- PGADMIN_DEFAULT_PASSWORD=${PGADMIN_DEFAULT_PASSWORD:-surfsense}
volumes:
- pgadmin_data:/var/lib/pgadmin
depends_on:
- db
redis:
image: redis:7-alpine
ports:
- "${REDIS_PORT:-6379}:6379"
volumes:
- redis_data:/data
command: redis-server --appendonly yes
backend:
build: ./surfsense_backend
# image: ghcr.io/modsetter/surfsense_backend:latest
ports:
- "${BACKEND_PORT:-8000}:8000"
volumes:
- ./surfsense_backend/app:/app/app
- shared_temp:/tmp
# Uncomment and edit the line below to enable Obsidian vault indexing
# - /path/to/your/obsidian/vault:/obsidian-vault:ro
env_file:
- ./surfsense_backend/.env
environment:
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-surfsense}
- CELERY_BROKER_URL=redis://redis:${REDIS_PORT:-6379}/0
- CELERY_RESULT_BACKEND=redis://redis:${REDIS_PORT:-6379}/0
- REDIS_APP_URL=redis://redis:${REDIS_PORT:-6379}/0
# Queue name isolation - prevents task collision if Redis is shared with other apps
- CELERY_TASK_DEFAULT_QUEUE=surfsense
- PYTHONPATH=/app
- UVICORN_LOOP=asyncio
- UNSTRUCTURED_HAS_PATCHED_LOOP=1
- LANGCHAIN_TRACING_V2=false
- LANGSMITH_TRACING=false
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
- AUTH_TYPE=${AUTH_TYPE:-LOCAL}
- NEXT_FRONTEND_URL=${NEXT_FRONTEND_URL:-http://localhost:3000}
depends_on:
- db
- redis
# Run these services separately in production
# celery_worker:
# build: ./surfsense_backend
# # image: ghcr.io/modsetter/surfsense_backend:latest
# command: celery -A app.celery_app worker --loglevel=info --concurrency=1 --pool=solo
# volumes:
# - ./surfsense_backend:/app
# - shared_temp:/tmp
# env_file:
# - ./surfsense_backend/.env
# environment:
# - DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-surfsense}
# - CELERY_BROKER_URL=redis://redis:${REDIS_PORT:-6379}/0
# - CELERY_RESULT_BACKEND=redis://redis:${REDIS_PORT:-6379}/0
# - PYTHONPATH=/app
# depends_on:
# - db
# - redis
# - backend
# celery_beat:
# build: ./surfsense_backend
# # image: ghcr.io/modsetter/surfsense_backend:latest
# command: celery -A app.celery_app beat --loglevel=info
# volumes:
# - ./surfsense_backend:/app
# - shared_temp:/tmp
# env_file:
# - ./surfsense_backend/.env
# environment:
# - DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-surfsense}
# - CELERY_BROKER_URL=redis://redis:${REDIS_PORT:-6379}/0
# - CELERY_RESULT_BACKEND=redis://redis:${REDIS_PORT:-6379}/0
# - PYTHONPATH=/app
# depends_on:
# - db
# - redis
# - celery_worker
# flower:
# build: ./surfsense_backend
# # image: ghcr.io/modsetter/surfsense_backend:latest
# command: celery -A app.celery_app flower --port=5555
# ports:
# - "${FLOWER_PORT:-5555}:5555"
# env_file:
# - ./surfsense_backend/.env
# environment:
# - CELERY_BROKER_URL=redis://redis:${REDIS_PORT:-6379}/0
# - CELERY_RESULT_BACKEND=redis://redis:${REDIS_PORT:-6379}/0
# - PYTHONPATH=/app
# depends_on:
# - redis
# - celery_worker
electric:
image: electricsql/electric:latest
ports:
- "${ELECTRIC_PORT:-5133}:3000"
environment:
- DATABASE_URL=${ELECTRIC_DATABASE_URL:-postgresql://${ELECTRIC_DB_USER:-electric}:${ELECTRIC_DB_PASSWORD:-electric_password}@${POSTGRES_HOST:-db}:${POSTGRES_PORT:-5432}/${POSTGRES_DB:-surfsense}?sslmode=disable}
- ELECTRIC_INSECURE=true
- ELECTRIC_WRITE_TO_PG_MODE=direct
restart: unless-stopped
# depends_on:
# - db
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:3000/v1/health"]
interval: 10s
timeout: 5s
retries: 5
frontend:
build:
context: ./surfsense_web
# image: ghcr.io/modsetter/surfsense_ui:latest
args:
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000}
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL}
NEXT_PUBLIC_ETL_SERVICE: ${NEXT_PUBLIC_ETL_SERVICE:-DOCLING}
ports:
- "${FRONTEND_PORT:-3000}:3000"
env_file:
- ./surfsense_web/.env
environment:
- NEXT_PUBLIC_ELECTRIC_URL=${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:5133}
- NEXT_PUBLIC_ELECTRIC_AUTH_MODE=insecure
depends_on:
- backend
- electric
volumes:
postgres_data:
pgadmin_data:
redis_data:
shared_temp:

256
docker/.env.example Normal file
View file

@ -0,0 +1,256 @@
# ==============================================================================
# SurfSense Docker Configuration
# ==============================================================================
# Database, Redis, and internal service wiring are handled automatically.
# ==============================================================================
# SurfSense version (use "latest", a clean version like "0.0.14", or a specific build like "0.0.14.1")
SURFSENSE_VERSION=latest
# ------------------------------------------------------------------------------
# Core Settings
# ------------------------------------------------------------------------------
# REQUIRED: Generate a secret key with: openssl rand -base64 32
SECRET_KEY=replace_me_with_a_random_string
# Auth type: LOCAL (email/password) or GOOGLE (OAuth)
AUTH_TYPE=LOCAL
# Allow new user registrations (TRUE or FALSE)
# REGISTRATION_ENABLED=TRUE
# Document parsing service: DOCLING, UNSTRUCTURED, or LLAMACLOUD
ETL_SERVICE=DOCLING
# Embedding model for vector search
# Local: sentence-transformers/all-MiniLM-L6-v2
# OpenAI: openai://text-embedding-ada-002 (set OPENAI_API_KEY below)
# Cohere: cohere://embed-english-light-v3.0 (set COHERE_API_KEY below)
EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
# ------------------------------------------------------------------------------
# Ports (change to avoid conflicts with other services on your machine)
# ------------------------------------------------------------------------------
# BACKEND_PORT=8000
# FRONTEND_PORT=3000
# ELECTRIC_PORT=5133
# FLOWER_PORT=5555
# ==============================================================================
# DEV COMPOSE ONLY (docker-compose.dev.yml)
# You only need them only if you are running `docker-compose.dev.yml`.
# ==============================================================================
# -- pgAdmin (database GUI) --
# PGADMIN_PORT=5050
# PGADMIN_DEFAULT_EMAIL=admin@surfsense.com
# PGADMIN_DEFAULT_PASSWORD=surfsense
# -- Redis exposed port (dev only; Redis is internal-only in prod) --
# REDIS_PORT=6379
# -- Frontend Build Args --
# In dev, the frontend is built from source and these are passed as build args.
# In prod, they are automatically derived from AUTH_TYPE, ETL_SERVICE, and the port settings above.
# NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL
# NEXT_PUBLIC_ETL_SERVICE=DOCLING
# NEXT_PUBLIC_DEPLOYMENT_MODE=self-hosted
# NEXT_PUBLIC_ELECTRIC_AUTH_MODE=insecure
# ------------------------------------------------------------------------------
# Custom Domain / Reverse Proxy
# ------------------------------------------------------------------------------
# ONLY set these if you are serving SurfSense on a real domain via a reverse
# proxy (e.g. Caddy, Nginx, Cloudflare Tunnel).
# For standard localhost deployments, leave all of these commented out —
# they are automatically derived from the port settings above.
#
# NEXT_FRONTEND_URL=https://app.yourdomain.com
# BACKEND_URL=https://api.yourdomain.com
# NEXT_PUBLIC_FASTAPI_BACKEND_URL=https://api.yourdomain.com
# NEXT_PUBLIC_ELECTRIC_URL=https://electric.yourdomain.com
# ------------------------------------------------------------------------------
# Database (defaults work out of the box, change for security)
# ------------------------------------------------------------------------------
# DB_USER=surfsense
# DB_PASSWORD=surfsense
# DB_NAME=surfsense
# DB_HOST=db
# DB_PORT=5432
# SSL mode for database connections: disable, require, verify-ca, verify-full
# DB_SSLMODE=disable
# Full DATABASE_URL override — when set, takes precedence over the individual
# DB_USER / DB_PASSWORD / DB_NAME / DB_HOST / DB_PORT settings above.
# Use this for managed databases (AWS RDS, GCP Cloud SQL, Supabase, etc.)
# DATABASE_URL=postgresql+asyncpg://user:password@your-rds-host:5432/surfsense?sslmode=require
# ------------------------------------------------------------------------------
# Redis (defaults work out of the box)
# ------------------------------------------------------------------------------
# Full Redis URL override for Celery broker, result backend, and app cache.
# Use this for managed Redis (AWS ElastiCache, Redis Cloud, etc.)
# Supports auth: redis://:password@host:port/0
# Supports TLS: rediss://:password@host:6380/0
# REDIS_URL=redis://redis:6379/0
# ------------------------------------------------------------------------------
# Electric SQL (real-time sync credentials)
# ------------------------------------------------------------------------------
# These must match on the db, backend, and electric services.
# Change for security; defaults work out of the box.
# ELECTRIC_DB_USER=electric
# ELECTRIC_DB_PASSWORD=electric_password
# Full override for the Electric → Postgres connection URL.
# Leave commented out to use the Docker-managed `db` container (default).
# Uncomment and set `db` to `host.docker.internal` when pointing Electric at a local Postgres instance (e.g. Postgres.app on macOS):
# ELECTRIC_DATABASE_URL=postgresql://electric:electric_password@db:5432/surfsense?sslmode=disable
# ------------------------------------------------------------------------------
# TTS & STT (Text-to-Speech / Speech-to-Text)
# ------------------------------------------------------------------------------
# Local Kokoro TTS (default) or LiteLLM provider
TTS_SERVICE=local/kokoro
# TTS_SERVICE_API_KEY=
# TTS_SERVICE_API_BASE=
# Local Faster-Whisper STT: local/MODEL_SIZE (tiny, base, small, medium, large-v3)
STT_SERVICE=local/base
# Or use LiteLLM: openai/whisper-1
# STT_SERVICE_API_KEY=
# STT_SERVICE_API_BASE=
# ------------------------------------------------------------------------------
# Rerankers (optional, disabled by default)
# ------------------------------------------------------------------------------
# RERANKERS_ENABLED=TRUE
# RERANKERS_MODEL_NAME=ms-marco-MiniLM-L-12-v2
# RERANKERS_MODEL_TYPE=flashrank
# ------------------------------------------------------------------------------
# Google OAuth (only if AUTH_TYPE=GOOGLE)
# ------------------------------------------------------------------------------
# GOOGLE_OAUTH_CLIENT_ID=
# GOOGLE_OAUTH_CLIENT_SECRET=
# ------------------------------------------------------------------------------
# Connector OAuth Keys (uncomment connectors you want to use)
# ------------------------------------------------------------------------------
# -- Google Connectors --
# GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/calendar/connector/callback
# GOOGLE_GMAIL_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/gmail/connector/callback
# GOOGLE_DRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/drive/connector/callback
# -- Notion --
# NOTION_CLIENT_ID=
# NOTION_CLIENT_SECRET=
# NOTION_REDIRECT_URI=http://localhost:8000/api/v1/auth/notion/connector/callback
# -- Slack --
# SLACK_CLIENT_ID=
# SLACK_CLIENT_SECRET=
# SLACK_REDIRECT_URI=http://localhost:8000/api/v1/auth/slack/connector/callback
# -- Discord --
# DISCORD_CLIENT_ID=
# DISCORD_CLIENT_SECRET=
# DISCORD_REDIRECT_URI=http://localhost:8000/api/v1/auth/discord/connector/callback
# DISCORD_BOT_TOKEN=
# -- Atlassian (Jira & Confluence) --
# ATLASSIAN_CLIENT_ID=
# ATLASSIAN_CLIENT_SECRET=
# JIRA_REDIRECT_URI=http://localhost:8000/api/v1/auth/jira/connector/callback
# CONFLUENCE_REDIRECT_URI=http://localhost:8000/api/v1/auth/confluence/connector/callback
# -- Linear --
# LINEAR_CLIENT_ID=
# LINEAR_CLIENT_SECRET=
# LINEAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/linear/connector/callback
# -- ClickUp --
# CLICKUP_CLIENT_ID=
# CLICKUP_CLIENT_SECRET=
# CLICKUP_REDIRECT_URI=http://localhost:8000/api/v1/auth/clickup/connector/callback
# -- Airtable --
# AIRTABLE_CLIENT_ID=
# AIRTABLE_CLIENT_SECRET=
# AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callback
# -- Microsoft Teams --
# TEAMS_CLIENT_ID=
# TEAMS_CLIENT_SECRET=
# TEAMS_REDIRECT_URI=http://localhost:8000/api/v1/auth/teams/connector/callback
# -- Composio --
# COMPOSIO_API_KEY=
# COMPOSIO_ENABLED=TRUE
# COMPOSIO_REDIRECT_URI=http://localhost:8000/api/v1/auth/composio/connector/callback
# ------------------------------------------------------------------------------
# Daytona Sandbox (optional — cloud code execution for the deep agent)
# ------------------------------------------------------------------------------
# Set DAYTONA_SANDBOX_ENABLED=TRUE and provide credentials to give the agent
# an isolated code execution environment via the Daytona cloud API.
# DAYTONA_SANDBOX_ENABLED=FALSE
# DAYTONA_API_KEY=
# DAYTONA_API_URL=https://app.daytona.io/api
# DAYTONA_TARGET=us
# ------------------------------------------------------------------------------
# External API Keys (optional)
# ------------------------------------------------------------------------------
# Firecrawl (web scraping)
# FIRECRAWL_API_KEY=
# Unstructured (if ETL_SERVICE=UNSTRUCTURED)
# UNSTRUCTURED_API_KEY=
# LlamaCloud (if ETL_SERVICE=LLAMACLOUD)
# LLAMA_CLOUD_API_KEY=
# ------------------------------------------------------------------------------
# Observability (optional)
# ------------------------------------------------------------------------------
# LANGSMITH_TRACING=true
# LANGSMITH_ENDPOINT=https://api.smith.langchain.com
# LANGSMITH_API_KEY=
# LANGSMITH_PROJECT=surfsense
# ------------------------------------------------------------------------------
# Advanced (optional)
# ------------------------------------------------------------------------------
# Periodic connector sync interval (default: 5m)
# SCHEDULE_CHECKER_INTERVAL=5m
# JWT token lifetimes
# ACCESS_TOKEN_LIFETIME_SECONDS=86400
# REFRESH_TOKEN_LIFETIME_SECONDS=1209600
# Pages limit per user for ETL (default: unlimited)
# PAGES_LIMIT=500
# Connector indexing lock TTL in seconds (default: 28800 = 8 hours)
# CONNECTOR_INDEXING_LOCK_TTL_SECONDS=28800
# Residential proxy for web crawling
# RESIDENTIAL_PROXY_USERNAME=
# RESIDENTIAL_PROXY_PASSWORD=
# RESIDENTIAL_PROXY_HOSTNAME=
# RESIDENTIAL_PROXY_LOCATION=
# RESIDENTIAL_PROXY_TYPE=1

View file

@ -0,0 +1,206 @@
# =============================================================================
# SurfSense — Development Docker Compose
# =============================================================================
# Usage (from repo root):
# docker compose -f docker/docker-compose.dev.yml up --build
#
# This file builds from source and includes dev tools like pgAdmin.
# For production with prebuilt images, use docker/docker-compose.yml instead.
# =============================================================================
name: surfsense
services:
db:
image: pgvector/pgvector:pg17
ports:
- "${POSTGRES_PORT:-5432}:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
- ./postgresql.conf:/etc/postgresql/postgresql.conf:ro
- ./scripts/init-electric-user.sh:/docker-entrypoint-initdb.d/init-electric-user.sh:ro
environment:
- POSTGRES_USER=${DB_USER:-postgres}
- POSTGRES_PASSWORD=${DB_PASSWORD:-postgres}
- POSTGRES_DB=${DB_NAME:-surfsense}
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
command: postgres -c config_file=/etc/postgresql/postgresql.conf
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${DB_USER:-postgres} -d ${DB_NAME:-surfsense}"]
interval: 10s
timeout: 5s
retries: 5
pgadmin:
image: dpage/pgadmin4
ports:
- "${PGADMIN_PORT:-5050}:80"
environment:
- PGADMIN_DEFAULT_EMAIL=${PGADMIN_DEFAULT_EMAIL:-admin@surfsense.com}
- PGADMIN_DEFAULT_PASSWORD=${PGADMIN_DEFAULT_PASSWORD:-surfsense}
volumes:
- pgadmin_data:/var/lib/pgadmin
depends_on:
- db
redis:
image: redis:8-alpine
ports:
- "${REDIS_PORT:-6379}:6379"
volumes:
- redis_data:/data
command: redis-server --appendonly yes
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
backend:
build: ../surfsense_backend
ports:
- "${BACKEND_PORT:-8000}:8000"
volumes:
- ../surfsense_backend/app:/app/app
- shared_temp:/shared_tmp
env_file:
- ../surfsense_backend/.env
environment:
- DATABASE_URL=${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
- CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}
- CELERY_RESULT_BACKEND=${REDIS_URL:-redis://redis:6379/0}
- REDIS_APP_URL=${REDIS_URL:-redis://redis:6379/0}
- CELERY_TASK_DEFAULT_QUEUE=surfsense
- PYTHONPATH=/app
- UVICORN_LOOP=asyncio
- UNSTRUCTURED_HAS_PATCHED_LOOP=1
- LANGCHAIN_TRACING_V2=false
- LANGSMITH_TRACING=false
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
- AUTH_TYPE=${AUTH_TYPE:-LOCAL}
- NEXT_FRONTEND_URL=${NEXT_FRONTEND_URL:-http://localhost:3000}
# Daytona Sandbox uncomment and set credentials to enable cloud code execution
# - DAYTONA_SANDBOX_ENABLED=TRUE
# - DAYTONA_API_KEY=${DAYTONA_API_KEY:-}
# - DAYTONA_API_URL=${DAYTONA_API_URL:-https://app.daytona.io/api}
# - DAYTONA_TARGET=${DAYTONA_TARGET:-us}
- SERVICE_ROLE=api
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 15s
timeout: 5s
retries: 30
start_period: 200s
celery_worker:
build: ../surfsense_backend
volumes:
- ../surfsense_backend/app:/app/app
- shared_temp:/shared_tmp
env_file:
- ../surfsense_backend/.env
environment:
- DATABASE_URL=${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
- CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}
- CELERY_RESULT_BACKEND=${REDIS_URL:-redis://redis:6379/0}
- REDIS_APP_URL=${REDIS_URL:-redis://redis:6379/0}
- CELERY_TASK_DEFAULT_QUEUE=surfsense
- PYTHONPATH=/app
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
- SERVICE_ROLE=worker
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
backend:
condition: service_healthy
celery_beat:
build: ../surfsense_backend
env_file:
- ../surfsense_backend/.env
environment:
- DATABASE_URL=${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
- CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}
- CELERY_RESULT_BACKEND=${REDIS_URL:-redis://redis:6379/0}
- CELERY_TASK_DEFAULT_QUEUE=surfsense
- PYTHONPATH=/app
- SERVICE_ROLE=beat
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
celery_worker:
condition: service_started
# flower:
# build: ../surfsense_backend
# ports:
# - "${FLOWER_PORT:-5555}:5555"
# env_file:
# - ../surfsense_backend/.env
# environment:
# - CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}
# - CELERY_RESULT_BACKEND=${REDIS_URL:-redis://redis:6379/0}
# - PYTHONPATH=/app
# command: celery -A app.celery_app flower --port=5555
# depends_on:
# - redis
# - celery_worker
electric:
image: electricsql/electric:1.4.10
ports:
- "${ELECTRIC_PORT:-5133}:3000"
# depends_on:
# - db
environment:
- DATABASE_URL=${ELECTRIC_DATABASE_URL:-postgresql://${ELECTRIC_DB_USER:-electric}:${ELECTRIC_DB_PASSWORD:-electric_password}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
- ELECTRIC_INSECURE=true
- ELECTRIC_WRITE_TO_PG_MODE=direct
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:3000/v1/health"]
interval: 10s
timeout: 5s
retries: 5
frontend:
build:
context: ../surfsense_web
args:
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000}
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL}
NEXT_PUBLIC_ETL_SERVICE: ${NEXT_PUBLIC_ETL_SERVICE:-DOCLING}
NEXT_PUBLIC_ELECTRIC_URL: ${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:5133}
NEXT_PUBLIC_ELECTRIC_AUTH_MODE: ${NEXT_PUBLIC_ELECTRIC_AUTH_MODE:-insecure}
NEXT_PUBLIC_DEPLOYMENT_MODE: ${NEXT_PUBLIC_DEPLOYMENT_MODE:-self-hosted}
ports:
- "${FRONTEND_PORT:-3000}:3000"
env_file:
- ../surfsense_web/.env
depends_on:
backend:
condition: service_healthy
electric:
condition: service_healthy
volumes:
postgres_data:
name: surfsense-postgres
pgadmin_data:
name: surfsense-pgadmin
redis_data:
name: surfsense-redis
shared_temp:
name: surfsense-shared-temp

195
docker/docker-compose.yml Normal file
View file

@ -0,0 +1,195 @@
# =============================================================================
# SurfSense — Production Docker Compose
# Docs: https://docs.surfsense.com/docs/docker-installation
# =============================================================================
# Usage:
# 1. Copy .env.example to .env and edit the required values
# 2. docker compose up -d
# =============================================================================
name: surfsense
services:
db:
image: pgvector/pgvector:pg17
volumes:
- postgres_data:/var/lib/postgresql/data
- ./postgresql.conf:/etc/postgresql/postgresql.conf:ro
- ./scripts/init-electric-user.sh:/docker-entrypoint-initdb.d/init-electric-user.sh:ro
environment:
POSTGRES_USER: ${DB_USER:-surfsense}
POSTGRES_PASSWORD: ${DB_PASSWORD:-surfsense}
POSTGRES_DB: ${DB_NAME:-surfsense}
ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
command: postgres -c config_file=/etc/postgresql/postgresql.conf
restart: unless-stopped
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${DB_USER:-surfsense} -d ${DB_NAME:-surfsense}"]
interval: 10s
timeout: 5s
retries: 5
redis:
image: redis:8-alpine
volumes:
- redis_data:/data
command: redis-server --appendonly yes
restart: unless-stopped
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
backend:
image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}
ports:
- "${BACKEND_PORT:-8000}:8000"
volumes:
- shared_temp:/shared_tmp
env_file:
- .env
environment:
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}
CELERY_RESULT_BACKEND: ${REDIS_URL:-redis://redis:6379/0}
REDIS_APP_URL: ${REDIS_URL:-redis://redis:6379/0}
CELERY_TASK_DEFAULT_QUEUE: surfsense
PYTHONPATH: /app
UVICORN_LOOP: asyncio
UNSTRUCTURED_HAS_PATCHED_LOOP: "1"
ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
NEXT_FRONTEND_URL: ${NEXT_FRONTEND_URL:-http://localhost:${FRONTEND_PORT:-3000}}
# Daytona Sandbox uncomment and set credentials to enable cloud code execution
# DAYTONA_SANDBOX_ENABLED: "TRUE"
# DAYTONA_API_KEY: ${DAYTONA_API_KEY:-}
# DAYTONA_API_URL: ${DAYTONA_API_URL:-https://app.daytona.io/api}
# DAYTONA_TARGET: ${DAYTONA_TARGET:-us}
SERVICE_ROLE: api
labels:
- "com.centurylinklabs.watchtower.enable=true"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 15s
timeout: 5s
retries: 30
start_period: 200s
celery_worker:
image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}
volumes:
- shared_temp:/shared_tmp
env_file:
- .env
environment:
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}
CELERY_RESULT_BACKEND: ${REDIS_URL:-redis://redis:6379/0}
REDIS_APP_URL: ${REDIS_URL:-redis://redis:6379/0}
CELERY_TASK_DEFAULT_QUEUE: surfsense
PYTHONPATH: /app
ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
SERVICE_ROLE: worker
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
backend:
condition: service_healthy
labels:
- "com.centurylinklabs.watchtower.enable=true"
restart: unless-stopped
celery_beat:
image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}
env_file:
- .env
environment:
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}
CELERY_RESULT_BACKEND: ${REDIS_URL:-redis://redis:6379/0}
CELERY_TASK_DEFAULT_QUEUE: surfsense
PYTHONPATH: /app
SERVICE_ROLE: beat
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
celery_worker:
condition: service_started
labels:
- "com.centurylinklabs.watchtower.enable=true"
restart: unless-stopped
# flower:
# image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}
# ports:
# - "${FLOWER_PORT:-5555}:5555"
# env_file:
# - .env
# environment:
# CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}
# CELERY_RESULT_BACKEND: ${REDIS_URL:-redis://redis:6379/0}
# PYTHONPATH: /app
# command: celery -A app.celery_app flower --port=5555
# depends_on:
# - redis
# - celery_worker
# restart: unless-stopped
electric:
image: electricsql/electric:1.4.10
ports:
- "${ELECTRIC_PORT:-5133}:3000"
environment:
DATABASE_URL: ${ELECTRIC_DATABASE_URL:-postgresql://${ELECTRIC_DB_USER:-electric}:${ELECTRIC_DB_PASSWORD:-electric_password}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
ELECTRIC_INSECURE: "true"
ELECTRIC_WRITE_TO_PG_MODE: direct
restart: unless-stopped
depends_on:
db:
condition: service_healthy
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:3000/v1/health"]
interval: 10s
timeout: 5s
retries: 5
frontend:
image: ghcr.io/modsetter/surfsense-web:${SURFSENSE_VERSION:-latest}
ports:
- "${FRONTEND_PORT:-3000}:3000"
environment:
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:${BACKEND_PORT:-8000}}
NEXT_PUBLIC_ELECTRIC_URL: ${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:${ELECTRIC_PORT:-5133}}
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${AUTH_TYPE:-LOCAL}
NEXT_PUBLIC_ETL_SERVICE: ${ETL_SERVICE:-DOCLING}
NEXT_PUBLIC_DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted}
NEXT_PUBLIC_ELECTRIC_AUTH_MODE: ${NEXT_PUBLIC_ELECTRIC_AUTH_MODE:-insecure}
labels:
- "com.centurylinklabs.watchtower.enable=true"
depends_on:
backend:
condition: service_healthy
electric:
condition: service_healthy
restart: unless-stopped
volumes:
postgres_data:
name: surfsense-postgres
redis_data:
name: surfsense-redis
shared_temp:
name: surfsense-shared-temp

View file

@ -1,26 +1,9 @@
#!/bin/sh
# ============================================================================
# Electric SQL User Initialization Script (docker-compose only)
# ============================================================================
# This script is ONLY used when running via docker-compose.
#
# How it works:
# - docker-compose.yml mounts this script into the PostgreSQL container's
# /docker-entrypoint-initdb.d/ directory
# - PostgreSQL automatically executes scripts in that directory on first
# container initialization
#
# For local PostgreSQL users (non-Docker), this script is NOT used.
# Instead, the Electric user is created by Alembic migration 66
# (66_add_notifications_table_and_electric_replication.py).
#
# Both approaches are idempotent (use IF NOT EXISTS), so running both
# will not cause conflicts.
# ============================================================================
# Creates the Electric SQL replication user on first DB initialization.
# Idempotent — safe to run alongside Alembic migration 66.
set -e
# Use environment variables with defaults
ELECTRIC_DB_USER="${ELECTRIC_DB_USER:-electric}"
ELECTRIC_DB_PASSWORD="${ELECTRIC_DB_PASSWORD:-electric_password}"
@ -43,7 +26,6 @@ psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-E
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO $ELECTRIC_DB_USER;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO $ELECTRIC_DB_USER;
-- Create the publication for Electric SQL (if not exists)
DO \$\$
BEGIN
IF NOT EXISTS (SELECT FROM pg_publication WHERE pubname = 'electric_publication_default') THEN

350
docker/scripts/install.ps1 Normal file
View file

@ -0,0 +1,350 @@
# =============================================================================
# SurfSense — One-line Install Script (Windows / PowerShell)
#
#
# Usage: irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.ps1 | iex
#
# To pass flags, save and run locally:
# .\install.ps1 -NoWatchtower
# .\install.ps1 -WatchtowerInterval 3600
#
# Handles two cases automatically:
# 1. Fresh install — no prior SurfSense data detected
# 2. Migration from the legacy all-in-one container (surfsense-data volume)
# Downloads and runs migrate-database.sh --yes, then restores the dump
# into the new PostgreSQL 17 stack. The user runs one command for both.
# =============================================================================
param(
[switch]$NoWatchtower,
[int]$WatchtowerInterval = 86400
)
$ErrorActionPreference = 'Stop'
# ── Configuration ───────────────────────────────────────────────────────────
$RepoRaw = "https://raw.githubusercontent.com/MODSetter/SurfSense/main"
$InstallDir = ".\surfsense"
$OldVolume = "surfsense-data"
$DumpFile = ".\surfsense_migration_backup.sql"
$KeyFile = ".\surfsense_migration_secret.key"
$MigrationDoneFile = "$InstallDir\.migration_done"
$MigrationMode = $false
$SetupWatchtower = -not $NoWatchtower
$WatchtowerContainer = "watchtower"
# ── Output helpers ──────────────────────────────────────────────────────────
function Write-Info { param([string]$Msg) Write-Host "[SurfSense] " -ForegroundColor Cyan -NoNewline; Write-Host $Msg }
function Write-Ok { param([string]$Msg) Write-Host "[SurfSense] " -ForegroundColor Green -NoNewline; Write-Host $Msg }
function Write-Warn { param([string]$Msg) Write-Host "[SurfSense] " -ForegroundColor Yellow -NoNewline; Write-Host $Msg }
function Write-Step { param([string]$Msg) Write-Host "`n-- $Msg" -ForegroundColor Cyan }
function Write-Err { param([string]$Msg) Write-Host "[SurfSense] ERROR: $Msg" -ForegroundColor Red; exit 1 }
function Invoke-NativeSafe {
param([scriptblock]$Command)
$previousErrorActionPreference = $ErrorActionPreference
try {
$ErrorActionPreference = 'Continue'
& $Command
} finally {
$ErrorActionPreference = $previousErrorActionPreference
}
}
# ── Pre-flight checks ──────────────────────────────────────────────────────
Write-Step "Checking prerequisites"
if (-not (Get-Command docker -ErrorAction SilentlyContinue)) {
Write-Err "Docker is not installed. Install Docker Desktop: https://docs.docker.com/desktop/install/windows-install/"
}
Write-Ok "Docker found."
Invoke-NativeSafe { docker info *>$null } | Out-Null
if ($LASTEXITCODE -ne 0) {
Write-Err "Docker daemon is not running. Please start Docker Desktop and try again."
}
Write-Ok "Docker daemon is running."
Invoke-NativeSafe { docker compose version *>$null } | Out-Null
if ($LASTEXITCODE -ne 0) {
Write-Err "Docker Compose is not available. It should be bundled with Docker Desktop."
}
Write-Ok "Docker Compose found."
# ── Wait-for-postgres helper ────────────────────────────────────────────────
function Wait-ForPostgres {
param([string]$DbUser)
$maxAttempts = 45
$attempt = 0
Write-Info "Waiting for PostgreSQL to accept connections..."
do {
$attempt++
if ($attempt -ge $maxAttempts) {
Write-Err "PostgreSQL did not become ready after $($maxAttempts * 2) seconds.`nCheck logs: cd $InstallDir; docker compose logs db"
}
Start-Sleep -Seconds 2
Push-Location $InstallDir
Invoke-NativeSafe { docker compose exec -T db pg_isready -U $DbUser -q *>$null } | Out-Null
$ready = $LASTEXITCODE -eq 0
Pop-Location
} while (-not $ready)
Write-Ok "PostgreSQL is ready."
}
# ── Download files ──────────────────────────────────────────────────────────
Write-Step "Downloading SurfSense files"
Write-Info "Installation directory: $InstallDir"
New-Item -ItemType Directory -Path "$InstallDir\scripts" -Force | Out-Null
$Files = @(
@{ Src = "docker/docker-compose.yml"; Dest = "docker-compose.yml" }
@{ Src = "docker/.env.example"; Dest = ".env.example" }
@{ Src = "docker/postgresql.conf"; Dest = "postgresql.conf" }
@{ Src = "docker/scripts/init-electric-user.sh"; Dest = "scripts/init-electric-user.sh" }
@{ Src = "docker/scripts/migrate-database.ps1"; Dest = "scripts/migrate-database.ps1" }
)
foreach ($f in $Files) {
$destPath = Join-Path $InstallDir $f.Dest
Write-Info "Downloading $($f.Dest)..."
try {
Invoke-WebRequest -Uri "$RepoRaw/$($f.Src)" -OutFile $destPath -UseBasicParsing
} catch {
Write-Err "Failed to download $($f.Dest). Check your internet connection and try again."
}
}
Write-Ok "All files downloaded to $InstallDir/"
# ── Legacy all-in-one detection ─────────────────────────────────────────────
$volumeList = Invoke-NativeSafe { docker volume ls --format '{{.Name}}' 2>$null }
if (($volumeList -split "`n") -contains $OldVolume -and -not (Test-Path $MigrationDoneFile)) {
$MigrationMode = $true
if (Test-Path $DumpFile) {
Write-Step "Migration mode - using existing dump (skipping extraction)"
Write-Info "Found existing dump: $DumpFile"
Write-Info "Skipping data extraction - proceeding directly to restore."
Write-Info "To force a fresh extraction, remove the dump first: Remove-Item $DumpFile"
} else {
Write-Step "Migration mode - legacy all-in-one container detected"
Write-Warn "Volume '$OldVolume' found. Your data will be migrated automatically."
Write-Warn "PostgreSQL is being upgraded from version 14 to 17."
Write-Warn "Your original data will NOT be deleted."
Write-Host ""
Write-Info "Running data extraction (migrate-database.ps1 -Yes)..."
Write-Info "Full extraction log: ./surfsense-migration.log"
Write-Host ""
$migrateScript = Join-Path $InstallDir "scripts/migrate-database.ps1"
& $migrateScript -Yes
if ($LASTEXITCODE -ne 0) {
Write-Err "Data extraction failed. See ./surfsense-migration.log for details.`nYou can also run migrate-database.ps1 manually with custom flags."
}
Write-Host ""
Write-Ok "Data extraction complete. Proceeding with installation and restore."
}
}
# ── Set up .env ─────────────────────────────────────────────────────────────
Write-Step "Configuring environment"
$envPath = Join-Path $InstallDir ".env"
$envExamplePath = Join-Path $InstallDir ".env.example"
if (-not (Test-Path $envPath)) {
Copy-Item $envExamplePath $envPath
if ($MigrationMode -and (Test-Path $KeyFile)) {
$SecretKey = (Get-Content $KeyFile -Raw).Trim()
Write-Ok "Using SECRET_KEY recovered from legacy container."
} else {
$bytes = New-Object byte[] 32
$rng = [System.Security.Cryptography.RNGCryptoServiceProvider]::new()
$rng.GetBytes($bytes)
$rng.Dispose()
$SecretKey = [Convert]::ToBase64String($bytes)
Write-Ok "Generated new random SECRET_KEY."
}
$content = Get-Content $envPath -Raw
$content = $content -replace 'SECRET_KEY=replace_me_with_a_random_string', "SECRET_KEY=$SecretKey"
Set-Content -Path $envPath -Value $content -NoNewline
Write-Info "Created $envPath"
} else {
Write-Warn ".env already exists - keeping your existing configuration."
}
# ── Start containers ────────────────────────────────────────────────────────
if ($MigrationMode) {
$envContent = Get-Content $envPath
$DbUser = ($envContent | Select-String '^DB_USER=' | ForEach-Object { ($_ -split '=',2)[1].Trim('"') }) | Select-Object -First 1
$DbPass = ($envContent | Select-String '^DB_PASSWORD=' | ForEach-Object { ($_ -split '=',2)[1].Trim('"') }) | Select-Object -First 1
$DbName = ($envContent | Select-String '^DB_NAME=' | ForEach-Object { ($_ -split '=',2)[1].Trim('"') }) | Select-Object -First 1
if (-not $DbUser) { $DbUser = "surfsense" }
if (-not $DbPass) { $DbPass = "surfsense" }
if (-not $DbName) { $DbName = "surfsense" }
Write-Step "Starting PostgreSQL 17"
Push-Location $InstallDir
Invoke-NativeSafe { docker compose up -d db } | Out-Null
Pop-Location
Wait-ForPostgres -DbUser $DbUser
Write-Step "Restoring database"
if (-not (Test-Path $DumpFile)) {
Write-Err "Dump file '$DumpFile' not found. The migration script may have failed."
}
$DumpFilePath = (Resolve-Path $DumpFile).Path
Write-Info "Restoring dump into PostgreSQL 17 - this may take a while for large databases..."
$restoreErrFile = Join-Path $env:TEMP "surfsense_restore_err.log"
Push-Location $InstallDir
Invoke-NativeSafe { Get-Content -LiteralPath $DumpFilePath | docker compose exec -T -e "PGPASSWORD=$DbPass" db psql -U $DbUser -d $DbName 2>$restoreErrFile | Out-Null } | Out-Null
Pop-Location
$fatalErrors = @()
if (Test-Path $restoreErrFile) {
$fatalErrors = Get-Content $restoreErrFile |
Where-Object { $_ -match '^ERROR:' } |
Where-Object { $_ -notmatch 'already exists' } |
Where-Object { $_ -notmatch 'multiple primary keys' }
}
if ($fatalErrors.Count -gt 0) {
Write-Warn "Restore completed with errors (may be harmless pg_dump header noise):"
$fatalErrors | ForEach-Object { Write-Host $_ }
Write-Warn "If SurfSense behaves incorrectly, inspect manually."
} else {
Write-Ok "Database restored with no fatal errors."
}
# Smoke test
Push-Location $InstallDir
$tableCount = (Invoke-NativeSafe { docker compose exec -T -e "PGPASSWORD=$DbPass" db psql -U $DbUser -d $DbName -t -c "SELECT count(*) FROM information_schema.tables WHERE table_schema = 'public';" 2>$null }).Trim()
Pop-Location
if (-not $tableCount -or $tableCount -eq "0") {
Write-Warn "Smoke test: no tables found after restore."
Write-Warn "The restore may have failed silently. Check: cd $InstallDir; docker compose logs db"
} else {
Write-Ok "Smoke test passed: $tableCount table(s) restored successfully."
New-Item -Path $MigrationDoneFile -ItemType File -Force | Out-Null
}
Write-Step "Starting all SurfSense services"
Push-Location $InstallDir
Invoke-NativeSafe { docker compose up -d }
Pop-Location
Write-Ok "All services started."
Remove-Item $KeyFile -ErrorAction SilentlyContinue
} else {
Write-Step "Starting SurfSense"
Push-Location $InstallDir
Invoke-NativeSafe { docker compose up -d }
Pop-Location
Write-Ok "All services started."
}
# ── Watchtower (auto-update) ────────────────────────────────────────────────
if ($SetupWatchtower) {
$wtHours = [math]::Floor($WatchtowerInterval / 3600)
Write-Step "Setting up Watchtower (auto-updates every ${wtHours}h)"
$wtState = Invoke-NativeSafe { docker inspect -f '{{.State.Running}}' $WatchtowerContainer 2>$null }
if ($LASTEXITCODE -ne 0) { $wtState = "missing" }
if ($wtState -eq "true") {
Write-Ok "Watchtower is already running - skipping."
} else {
if ($wtState -ne "missing") {
Write-Info "Removing stopped Watchtower container..."
Invoke-NativeSafe { docker rm -f $WatchtowerContainer *>$null } | Out-Null
}
Invoke-NativeSafe {
docker run -d `
--name $WatchtowerContainer `
--restart unless-stopped `
-v /var/run/docker.sock:/var/run/docker.sock `
nickfedor/watchtower `
--label-enable `
--interval $WatchtowerInterval *>$null
} | Out-Null
if ($LASTEXITCODE -eq 0) {
Write-Ok "Watchtower started - labeled SurfSense containers will auto-update."
} else {
Write-Warn "Could not start Watchtower. You can set it up manually or use: docker compose pull; docker compose up -d"
}
}
} else {
Write-Info "Skipping Watchtower setup (-NoWatchtower flag)."
}
# ── Done ────────────────────────────────────────────────────────────────────
Write-Host ""
Write-Host @"
.d8888b. .d888 .d8888b.
d88P Y88b d88P" d88P Y88b
Y88b. 888 Y88b.
"Y888b. 888 888 888d888 888888 "Y888b. .d88b. 88888b. .d8888b .d88b.
"Y88b. 888 888 888P" 888 "Y88b. d8P Y8b 888 "88b 88K d8P Y8b
"888 888 888 888 888 "888 88888888 888 888 "Y8888b. 88888888
Y88b d88P Y88b 888 888 888 Y88b d88P Y8b. 888 888 X88 Y8b.
"Y8888P" "Y88888 888 888 "Y8888P" "Y8888 888 888 88888P' "Y8888
"@ -ForegroundColor White
$versionDisplay = (Get-Content $envPath | Select-String '^SURFSENSE_VERSION=' | ForEach-Object { ($_ -split '=',2)[1].Trim('"') }) | Select-Object -First 1
if (-not $versionDisplay) { $versionDisplay = "latest" }
Write-Host " OSS Alternative to NotebookLM for Teams [$versionDisplay]" -ForegroundColor Yellow
Write-Host ("=" * 62) -ForegroundColor Cyan
Write-Host ""
Write-Info " Frontend: http://localhost:3000"
Write-Info " Backend: http://localhost:8000"
Write-Info " API Docs: http://localhost:8000/docs"
Write-Info ""
Write-Info " Config: $InstallDir\.env"
Write-Info " Logs: cd $InstallDir; docker compose logs -f"
Write-Info " Stop: cd $InstallDir; docker compose down"
Write-Info " Update: cd $InstallDir; docker compose pull; docker compose up -d"
Write-Info ""
if ($SetupWatchtower) {
Write-Info " Watchtower: auto-updates every ${wtHours}h (stop: docker rm -f $WatchtowerContainer)"
} else {
Write-Warn " Watchtower skipped. For auto-updates, re-run without -NoWatchtower."
}
Write-Info ""
if ($MigrationMode) {
Write-Warn " Migration complete! Open frontend and verify your data."
Write-Warn " Once verified, clean up the legacy volume and migration files:"
Write-Warn " docker volume rm $OldVolume"
Write-Warn " Remove-Item $DumpFile"
Write-Warn " Remove-Item $MigrationDoneFile"
} else {
Write-Warn " First startup may take a few minutes while images are pulled."
Write-Warn " Edit $InstallDir\.env to configure API keys, OAuth, etc."
}

337
docker/scripts/install.sh Normal file
View file

@ -0,0 +1,337 @@
#!/usr/bin/env bash
# =============================================================================
# SurfSense — One-line Install Script
#
#
# Usage: curl -fsSL https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.sh | bash
#
# Flags:
# --no-watchtower Skip automatic Watchtower setup
# --watchtower-interval=SECS Check interval in seconds (default: 86400 = 24h)
#
# Handles two cases automatically:
# 1. Fresh install — no prior SurfSense data detected
# 2. Migration from the legacy all-in-one container (surfsense-data volume)
# Downloads and runs migrate-database.sh --yes, then restores the dump
# into the new PostgreSQL 17 stack. The user runs one command for both.
#
# If you used custom database credentials in the old all-in-one container, run
# migrate-database.sh manually first (with --db-user / --db-password flags),
# then re-run this script:
# curl -fsSL .../docker/scripts/migrate-database.sh | bash -s -- --db-user X --db-password Y
# =============================================================================
set -euo pipefail
main() {
REPO_RAW="https://raw.githubusercontent.com/MODSetter/SurfSense/main"
INSTALL_DIR="./surfsense"
OLD_VOLUME="surfsense-data"
DUMP_FILE="./surfsense_migration_backup.sql"
KEY_FILE="./surfsense_migration_secret.key"
MIGRATION_DONE_FILE="${INSTALL_DIR}/.migration_done"
MIGRATION_MODE=false
SETUP_WATCHTOWER=true
WATCHTOWER_INTERVAL=86400
WATCHTOWER_CONTAINER="watchtower"
# ── Parse flags ─────────────────────────────────────────────────────────────
for arg in "$@"; do
case "$arg" in
--no-watchtower) SETUP_WATCHTOWER=false ;;
--watchtower-interval=*) WATCHTOWER_INTERVAL="${arg#*=}" ;;
esac
done
CYAN='\033[1;36m'
YELLOW='\033[1;33m'
GREEN='\033[0;32m'
RED='\033[0;31m'
BOLD='\033[1m'
NC='\033[0m'
info() { printf "${CYAN}[SurfSense]${NC} %s\n" "$1"; }
success() { printf "${GREEN}[SurfSense]${NC} %s\n" "$1"; }
warn() { printf "${YELLOW}[SurfSense]${NC} %s\n" "$1"; }
error() { printf "${RED}[SurfSense]${NC} ERROR: %s\n" "$1" >&2; exit 1; }
step() { printf "\n${BOLD}${CYAN}── %s${NC}\n" "$1"; }
# ── Pre-flight checks ────────────────────────────────────────────────────────
step "Checking prerequisites"
command -v docker >/dev/null 2>&1 \
|| error "Docker is not installed. Install it at: https://docs.docker.com/get-docker/"
success "Docker found."
docker info >/dev/null 2>&1 < /dev/null \
|| error "Docker daemon is not running. Please start Docker and try again."
success "Docker daemon is running."
if docker compose version >/dev/null 2>&1 < /dev/null; then
DC="docker compose"
elif command -v docker-compose >/dev/null 2>&1; then
DC="docker-compose"
else
error "Docker Compose is not installed. Install it at: https://docs.docker.com/compose/install/"
fi
success "Docker Compose found ($DC)."
# ── Wait-for-postgres helper ─────────────────────────────────────────────────
wait_for_pg() {
local db_user="$1"
local max_attempts=45
local attempt=0
info "Waiting for PostgreSQL to accept connections..."
until (cd "${INSTALL_DIR}" && ${DC} exec -T db pg_isready -U "${db_user}" -q 2>/dev/null) < /dev/null; do
attempt=$((attempt + 1))
if [[ $attempt -ge $max_attempts ]]; then
error "PostgreSQL did not become ready after $((max_attempts * 2)) seconds.\nCheck logs: cd ${INSTALL_DIR} && ${DC} logs db"
fi
printf "."
sleep 2
done
printf "\n"
success "PostgreSQL is ready."
}
# ── Download files ───────────────────────────────────────────────────────────
step "Downloading SurfSense files"
info "Installation directory: ${INSTALL_DIR}"
mkdir -p "${INSTALL_DIR}/scripts"
FILES=(
"docker/docker-compose.yml:docker-compose.yml"
"docker/.env.example:.env.example"
"docker/postgresql.conf:postgresql.conf"
"docker/scripts/init-electric-user.sh:scripts/init-electric-user.sh"
"docker/scripts/migrate-database.sh:scripts/migrate-database.sh"
)
for entry in "${FILES[@]}"; do
src="${entry%%:*}"
dest="${entry##*:}"
info "Downloading ${dest}..."
curl -fsSL "${REPO_RAW}/${src}" -o "${INSTALL_DIR}/${dest}" \
|| error "Failed to download ${dest}. Check your internet connection and try again."
done
chmod +x "${INSTALL_DIR}/scripts/init-electric-user.sh"
chmod +x "${INSTALL_DIR}/scripts/migrate-database.sh"
success "All files downloaded to ${INSTALL_DIR}/"
# ── Legacy all-in-one detection ──────────────────────────────────────────────
# Detect surfsense-data volume → migration mode.
# If a dump already exists (from a previous partial run) skip extraction and
# go straight to restore — this makes re-runs safe and idempotent.
if docker volume ls --format '{{.Name}}' 2>/dev/null < /dev/null | grep -q "^${OLD_VOLUME}$" \
&& [[ ! -f "${MIGRATION_DONE_FILE}" ]]; then
MIGRATION_MODE=true
if [[ -f "${DUMP_FILE}" ]]; then
step "Migration mode — using existing dump (skipping extraction)"
info "Found existing dump: ${DUMP_FILE}"
info "Skipping data extraction — proceeding directly to restore."
info "To force a fresh extraction, remove the dump first: rm ${DUMP_FILE}"
else
step "Migration mode — legacy all-in-one container detected"
warn "Volume '${OLD_VOLUME}' found. Your data will be migrated automatically."
warn "PostgreSQL is being upgraded from version 14 to 17."
warn "Your original data will NOT be deleted."
printf "\n"
info "Running data extraction (migrate-database.sh --yes)..."
info "Full extraction log: ./surfsense-migration.log"
printf "\n"
# Run extraction non-interactively. On failure the error from
# migrate-database.sh is printed and install.sh exits here.
bash "${INSTALL_DIR}/scripts/migrate-database.sh" --yes < /dev/null \
|| error "Data extraction failed. See ./surfsense-migration.log for details.\nYou can also run migrate-database.sh manually with custom flags:\n bash ${INSTALL_DIR}/scripts/migrate-database.sh --db-user X --db-password Y"
printf "\n"
success "Data extraction complete. Proceeding with installation and restore."
fi
fi
# ── Set up .env ──────────────────────────────────────────────────────────────
step "Configuring environment"
if [ ! -f "${INSTALL_DIR}/.env" ]; then
cp "${INSTALL_DIR}/.env.example" "${INSTALL_DIR}/.env"
if $MIGRATION_MODE && [[ -f "${KEY_FILE}" ]]; then
SECRET_KEY=$(cat "${KEY_FILE}" | tr -d '[:space:]')
success "Using SECRET_KEY recovered from legacy container."
else
SECRET_KEY=$(openssl rand -base64 32 2>/dev/null \
|| head -c 32 /dev/urandom | base64 | tr -d '\n')
success "Generated new random SECRET_KEY."
fi
if [[ "$OSTYPE" == "darwin"* ]]; then
sed -i '' "s|SECRET_KEY=replace_me_with_a_random_string|SECRET_KEY=${SECRET_KEY}|" "${INSTALL_DIR}/.env"
else
sed -i "s|SECRET_KEY=replace_me_with_a_random_string|SECRET_KEY=${SECRET_KEY}|" "${INSTALL_DIR}/.env"
fi
info "Created ${INSTALL_DIR}/.env"
else
warn ".env already exists — keeping your existing configuration."
fi
# ── Start containers ─────────────────────────────────────────────────────────
if $MIGRATION_MODE; then
# Read DB credentials from .env (fall back to defaults from docker-compose.yml)
DB_USER=$(grep '^DB_USER=' "${INSTALL_DIR}/.env" 2>/dev/null | cut -d= -f2 | tr -d '"' | head -1 || true)
DB_PASS=$(grep '^DB_PASSWORD=' "${INSTALL_DIR}/.env" 2>/dev/null | cut -d= -f2 | tr -d '"' | head -1 || true)
DB_NAME=$(grep '^DB_NAME=' "${INSTALL_DIR}/.env" 2>/dev/null | cut -d= -f2 | tr -d '"' | head -1 || true)
DB_USER="${DB_USER:-surfsense}"
DB_PASS="${DB_PASS:-surfsense}"
DB_NAME="${DB_NAME:-surfsense}"
step "Starting PostgreSQL 17"
(cd "${INSTALL_DIR}" && ${DC} up -d db) < /dev/null
wait_for_pg "${DB_USER}"
step "Restoring database"
[[ -f "${DUMP_FILE}" ]] \
|| error "Dump file '${DUMP_FILE}' not found. The migration script may have failed.\n Check: ./surfsense-migration.log\n Or run manually: bash ${INSTALL_DIR}/scripts/migrate-database.sh --yes"
info "Restoring dump into PostgreSQL 17 — this may take a while for large databases..."
RESTORE_ERR="/tmp/surfsense_restore_err.log"
(cd "${INSTALL_DIR}" && ${DC} exec -T \
-e PGPASSWORD="${DB_PASS}" \
db psql -U "${DB_USER}" -d "${DB_NAME}" \
>/dev/null 2>"${RESTORE_ERR}") < "${DUMP_FILE}" || true
# Surface real errors; ignore benign "already exists" noise from pg_dump headers
FATAL_ERRORS=$(grep -i "^ERROR:" "${RESTORE_ERR}" \
| grep -iv "already exists" \
| grep -iv "multiple primary keys" \
|| true)
if [[ -n "${FATAL_ERRORS}" ]]; then
warn "Restore completed with errors (may be harmless pg_dump header noise):"
printf "%s\n" "${FATAL_ERRORS}"
warn "If SurfSense behaves incorrectly, inspect manually:"
warn " cd ${INSTALL_DIR} && ${DC} exec db psql -U ${DB_USER} -d ${DB_NAME} < ${DUMP_FILE}"
else
success "Database restored with no fatal errors."
fi
# Smoke test — verify tables are present
TABLE_COUNT=$(
cd "${INSTALL_DIR}" && ${DC} exec -T \
-e PGPASSWORD="${DB_PASS}" \
db psql -U "${DB_USER}" -d "${DB_NAME}" -t \
-c "SELECT count(*) FROM information_schema.tables WHERE table_schema = 'public';" \
2>/dev/null < /dev/null | tr -d ' \n' || echo "0"
)
if [[ "${TABLE_COUNT}" == "0" || -z "${TABLE_COUNT}" ]]; then
warn "Smoke test: no tables found after restore."
warn "The restore may have failed silently. Check: cd ${INSTALL_DIR} && ${DC} logs db"
else
success "Smoke test passed: ${TABLE_COUNT} table(s) restored successfully."
touch "${MIGRATION_DONE_FILE}"
fi
step "Starting all SurfSense services"
(cd "${INSTALL_DIR}" && ${DC} up -d) < /dev/null
success "All services started."
# Key file is no longer needed — SECRET_KEY is now in .env
rm -f "${KEY_FILE}"
else
step "Starting SurfSense"
(cd "${INSTALL_DIR}" && ${DC} up -d) < /dev/null
success "All services started."
fi
# ── Watchtower (auto-update) ─────────────────────────────────────────────────
if $SETUP_WATCHTOWER; then
step "Setting up Watchtower (auto-updates every $((WATCHTOWER_INTERVAL / 3600))h)"
WT_STATE=$(docker inspect -f '{{.State.Running}}' "${WATCHTOWER_CONTAINER}" 2>/dev/null < /dev/null || echo "missing")
if [[ "${WT_STATE}" == "true" ]]; then
success "Watchtower is already running — skipping."
else
if [[ "${WT_STATE}" != "missing" ]]; then
info "Removing stopped Watchtower container..."
docker rm -f "${WATCHTOWER_CONTAINER}" >/dev/null 2>&1 < /dev/null || true
fi
docker run -d \
--name "${WATCHTOWER_CONTAINER}" \
--restart unless-stopped \
-v /var/run/docker.sock:/var/run/docker.sock \
nickfedor/watchtower \
--label-enable \
--interval "${WATCHTOWER_INTERVAL}" >/dev/null 2>&1 < /dev/null \
&& success "Watchtower started — labeled SurfSense containers will auto-update." \
|| warn "Could not start Watchtower. You can set it up manually or use: docker compose pull && docker compose up -d"
fi
else
info "Skipping Watchtower setup (--no-watchtower flag)."
fi
# ── Done ─────────────────────────────────────────────────────────────────────
echo ""
printf '\033[1;37m'
cat << 'EOF'
.d8888b. .d888 .d8888b.
d88P Y88b d88P" d88P Y88b
Y88b. 888 Y88b.
"Y888b. 888 888 888d888 888888 "Y888b. .d88b. 88888b. .d8888b .d88b.
"Y88b. 888 888 888P" 888 "Y88b. d8P Y8b 888 "88b 88K d8P Y8b
"888 888 888 888 888 "888 88888888 888 888 "Y8888b. 88888888
Y88b d88P Y88b 888 888 888 Y88b d88P Y8b. 888 888 X88 Y8b.
"Y8888P" "Y88888 888 888 "Y8888P" "Y8888 888 888 88888P' "Y8888
EOF
_version_display=$(grep '^SURFSENSE_VERSION=' "${INSTALL_DIR}/.env" 2>/dev/null | cut -d= -f2 | tr -d '"' | head -1 || true)
_version_display="${_version_display:-latest}"
printf " OSS Alternative to NotebookLM for Teams ${YELLOW}[%s]${NC}\n" "${_version_display}"
printf "${CYAN}══════════════════════════════════════════════════════════════${NC}\n\n"
info " Frontend: http://localhost:3000"
info " Backend: http://localhost:8000"
info " API Docs: http://localhost:8000/docs"
info ""
info " Config: ${INSTALL_DIR}/.env"
info " Logs: cd ${INSTALL_DIR} && ${DC} logs -f"
info " Stop: cd ${INSTALL_DIR} && ${DC} down"
info " Update: cd ${INSTALL_DIR} && ${DC} pull && ${DC} up -d"
info ""
if $SETUP_WATCHTOWER; then
info " Watchtower: auto-updates every $((WATCHTOWER_INTERVAL / 3600))h (stop: docker rm -f ${WATCHTOWER_CONTAINER})"
else
warn " Watchtower skipped. For auto-updates, re-run without --no-watchtower."
fi
info ""
if $MIGRATION_MODE; then
warn " Migration complete! Open frontend and verify your data."
warn " Once verified, clean up the legacy volume and migration files:"
warn " docker volume rm ${OLD_VOLUME}"
warn " rm ${DUMP_FILE}"
warn " rm ${MIGRATION_DONE_FILE}"
else
warn " First startup may take a few minutes while images are pulled."
warn " Edit ${INSTALL_DIR}/.env to configure API keys, OAuth, etc."
fi
} # end main()
main "$@"

View file

@ -0,0 +1,343 @@
# =============================================================================
# SurfSense — Database Migration Script (Windows / PowerShell)
#
# Extracts data from the legacy all-in-one surfsense-data volume (PostgreSQL 14)
# and saves it as a SQL dump + SECRET_KEY file ready for install.ps1 to restore.
#
# Usage:
# .\migrate-database.ps1 [options]
#
# Options:
# -DbUser USER Old PostgreSQL username (default: surfsense)
# -DbPassword PASS Old PostgreSQL password (default: surfsense)
# -DbName NAME Old PostgreSQL database (default: surfsense)
# -Yes Skip all confirmation prompts
#
# Prerequisites:
# - Docker Desktop installed and running
# - The legacy surfsense-data volume must exist
# - ~500 MB free disk space for the dump file
#
# What this script does:
# 1. Stops any container using surfsense-data (to prevent corruption)
# 2. Starts a temporary PG14 container against the old volume
# 3. Dumps the database to .\surfsense_migration_backup.sql
# 4. Recovers the SECRET_KEY to .\surfsense_migration_secret.key
# 5. Exits — leaving installation to install.ps1
#
# What this script does NOT do:
# - Delete the original surfsense-data volume (do this manually after verifying)
# - Install the new SurfSense stack (install.ps1 handles that automatically)
#
# Note:
# install.ps1 downloads and runs this script automatically when it detects the
# legacy surfsense-data volume. You only need to run this script manually if
# you have custom database credentials (-DbUser / -DbPassword / -DbName)
# or if the automatic migration inside install.ps1 fails at the extraction step.
# =============================================================================
param(
[string]$DbUser = "surfsense",
[string]$DbPassword = "surfsense",
[string]$DbName = "surfsense",
[switch]$Yes
)
$ErrorActionPreference = 'Stop'
# ── Constants ────────────────────────────────────────────────────────────────
$OldVolume = "surfsense-data"
$TempContainer = "surfsense-pg14-migration"
$DumpFile = ".\surfsense_migration_backup.sql"
$KeyFile = ".\surfsense_migration_secret.key"
$PG14Image = "pgvector/pgvector:pg14"
$LogFile = ".\surfsense-migration.log"
# ── Output helpers ───────────────────────────────────────────────────────────
function Write-Info { param([string]$Msg) Write-Host "[SurfSense] " -ForegroundColor Cyan -NoNewline; Write-Host $Msg }
function Write-Ok { param([string]$Msg) Write-Host "[SurfSense] " -ForegroundColor Green -NoNewline; Write-Host $Msg }
function Write-Warn { param([string]$Msg) Write-Host "[SurfSense] " -ForegroundColor Yellow -NoNewline; Write-Host $Msg }
function Write-Step { param([string]$Step, [string]$Msg) Write-Host "`n-- Step ${Step}: $Msg" -ForegroundColor Cyan }
function Write-Err { param([string]$Msg) Write-Host "[SurfSense] ERROR: $Msg" -ForegroundColor Red; exit 1 }
function Log { param([string]$Msg) Add-Content -Path $LogFile -Value $Msg }
function Invoke-NativeSafe {
param([scriptblock]$Command)
$previousErrorActionPreference = $ErrorActionPreference
try {
$ErrorActionPreference = 'Continue'
& $Command
} finally {
$ErrorActionPreference = $previousErrorActionPreference
}
}
function Confirm-Action {
param([string]$Prompt)
if ($Yes) { return }
$reply = Read-Host "[SurfSense] $Prompt [y/N]"
if ($reply -notmatch '^[Yy]$') {
Write-Warn "Aborted."
exit 0
}
}
# ── Cleanup helper ───────────────────────────────────────────────────────────
function Remove-TempContainer {
$containers = Invoke-NativeSafe { docker ps -a --format '{{.Names}}' 2>$null }
if ($containers -and ($containers -split "`n") -contains $TempContainer) {
Write-Info "Cleaning up temporary container '$TempContainer'..."
Invoke-NativeSafe { docker stop $TempContainer *>$null } | Out-Null
Invoke-NativeSafe { docker rm $TempContainer *>$null } | Out-Null
}
}
# Register cleanup on script exit
Register-EngineEvent PowerShell.Exiting -Action {
$containers = Invoke-NativeSafe { docker ps -a --format '{{.Names}}' 2>$null }
if ($containers -and ($containers -split "`n") -contains "surfsense-pg14-migration") {
Invoke-NativeSafe { docker stop "surfsense-pg14-migration" *>$null } | Out-Null
Invoke-NativeSafe { docker rm "surfsense-pg14-migration" *>$null } | Out-Null
}
} | Out-Null
# ── Wait-for-postgres helper ────────────────────────────────────────────────
function Wait-ForPostgres {
param(
[string]$Container,
[string]$User,
[string]$Label = "PostgreSQL"
)
$maxAttempts = 45
$attempt = 0
Write-Info "Waiting for $Label to accept connections..."
do {
$attempt++
if ($attempt -ge $maxAttempts) {
Write-Err "$Label did not become ready after $($maxAttempts * 2) seconds. Check: docker logs $Container"
}
Start-Sleep -Seconds 2
Invoke-NativeSafe { docker exec $Container pg_isready -U $User -q 2>$null } | Out-Null
} while ($LASTEXITCODE -ne 0)
Write-Ok "$Label is ready."
}
Write-Info "Migrating data from legacy database (PostgreSQL 14 -> 17)"
"Migration started at $(Get-Date)" | Out-File $LogFile
# ── Step 0: Pre-flight checks ───────────────────────────────────────────────
Write-Step "0" "Pre-flight checks"
if (-not (Get-Command docker -ErrorAction SilentlyContinue)) {
Write-Err "Docker is not installed. Install Docker Desktop: https://docs.docker.com/desktop/install/windows-install/"
}
Invoke-NativeSafe { docker info *>$null } | Out-Null
if ($LASTEXITCODE -ne 0) {
Write-Err "Docker daemon is not running. Please start Docker Desktop and try again."
}
$volumeList = Invoke-NativeSafe { docker volume ls --format '{{.Name}}' 2>$null }
if (-not (($volumeList -split "`n") -contains $OldVolume)) {
Write-Err "Legacy volume '$OldVolume' not found. Are you sure you ran the old all-in-one SurfSense container?"
}
Write-Ok "Found legacy volume: $OldVolume"
$oldContainer = (Invoke-NativeSafe { docker ps --filter "volume=$OldVolume" --format '{{.Names}}' 2>$null } | Select-Object -First 1)
if ($oldContainer) {
Write-Warn "Container '$oldContainer' is running and using the '$OldVolume' volume."
Write-Warn "It must be stopped before migration to prevent data file corruption."
Confirm-Action "Stop '$oldContainer' now and proceed with data extraction?"
Invoke-NativeSafe { docker stop $oldContainer *>$null } | Out-Null
if ($LASTEXITCODE -ne 0) {
Write-Err "Failed to stop '$oldContainer'. Try: docker stop $oldContainer"
}
Write-Ok "Container '$oldContainer' stopped."
}
if (Test-Path $DumpFile) {
Write-Warn "Dump file '$DumpFile' already exists."
Write-Warn "If a previous extraction succeeded, just run install.ps1 now."
Write-Warn "To re-extract, remove the file first: Remove-Item $DumpFile"
Write-Err "Aborting to avoid overwriting an existing dump."
}
$staleContainers = Invoke-NativeSafe { docker ps -a --format '{{.Names}}' 2>$null }
if ($staleContainers -and ($staleContainers -split "`n") -contains $TempContainer) {
Write-Warn "Stale migration container '$TempContainer' found - removing it."
Invoke-NativeSafe { docker stop $TempContainer *>$null } | Out-Null
Invoke-NativeSafe { docker rm $TempContainer *>$null } | Out-Null
}
$drive = (Get-Item .).PSDrive
$freeMB = [math]::Floor($drive.Free / 1MB)
if ($freeMB -lt 500) {
Write-Warn "Low disk space: $freeMB MB free. At least 500 MB recommended for the dump."
Confirm-Action "Continue anyway?"
} else {
Write-Ok "Disk space: $freeMB MB free."
}
Write-Ok "All pre-flight checks passed."
# ── Confirmation prompt ──────────────────────────────────────────────────────
Write-Host ""
Write-Host "Extraction plan:" -ForegroundColor White
Write-Host " Source volume : " -NoNewline; Write-Host "$OldVolume" -ForegroundColor Yellow -NoNewline; Write-Host " (PG14 data at /data/postgres)"
Write-Host " Old credentials : user=" -NoNewline; Write-Host "$DbUser" -ForegroundColor Yellow -NoNewline; Write-Host " db=" -NoNewline; Write-Host "$DbName" -ForegroundColor Yellow
Write-Host " Dump saved to : " -NoNewline; Write-Host "$DumpFile" -ForegroundColor Yellow
Write-Host " SECRET_KEY to : " -NoNewline; Write-Host "$KeyFile" -ForegroundColor Yellow
Write-Host " Log file : " -NoNewline; Write-Host "$LogFile" -ForegroundColor Yellow
Write-Host ""
Confirm-Action "Start data extraction? (Your original data will not be deleted or modified.)"
# ── Step 1: Start temporary PostgreSQL 14 container ──────────────────────────
Write-Step "1" "Starting temporary PostgreSQL 14 container"
Write-Info "Pulling $PG14Image..."
Invoke-NativeSafe { docker pull $PG14Image *>$null } | Out-Null
if ($LASTEXITCODE -ne 0) {
Write-Warn "Could not pull $PG14Image - using cached image if available."
}
$dataUid = Invoke-NativeSafe { docker run --rm -v "${OldVolume}:/data" alpine stat -c '%u' /data/postgres 2>$null }
if (-not $dataUid -or $dataUid -eq "0") {
Write-Warn "Could not detect data directory UID - falling back to default (may chown files)."
$userFlag = @()
} else {
Write-Info "Data directory owned by UID $dataUid - starting temp container as that user."
$userFlag = @("--user", $dataUid)
}
$dockerRunArgs = @(
"run", "-d",
"--name", $TempContainer,
"-v", "${OldVolume}:/data",
"-e", "PGDATA=/data/postgres",
"-e", "POSTGRES_USER=$DbUser",
"-e", "POSTGRES_PASSWORD=$DbPassword",
"-e", "POSTGRES_DB=$DbName"
) + $userFlag + @($PG14Image)
Invoke-NativeSafe { docker @dockerRunArgs *>$null } | Out-Null
if ($LASTEXITCODE -ne 0) {
Write-Err "Failed to start temporary PostgreSQL 14 container."
}
Write-Ok "Temporary container '$TempContainer' started."
Wait-ForPostgres -Container $TempContainer -User $DbUser -Label "PostgreSQL 14"
# ── Step 2: Dump the database ────────────────────────────────────────────────
Write-Step "2" "Dumping PostgreSQL 14 database"
Write-Info "Running pg_dump - this may take a while for large databases..."
$pgDumpErrFile = Join-Path $env:TEMP "surfsense_pgdump_err.log"
Invoke-NativeSafe { docker exec -e "PGPASSWORD=$DbPassword" $TempContainer pg_dump -U $DbUser --no-password $DbName > $DumpFile 2>$pgDumpErrFile } | Out-Null
if ($LASTEXITCODE -ne 0) {
if (Test-Path $pgDumpErrFile) { Get-Content $pgDumpErrFile | Write-Host -ForegroundColor Red }
Remove-TempContainer
Write-Err "pg_dump failed. See above for details."
}
if (-not (Test-Path $DumpFile) -or (Get-Item $DumpFile).Length -eq 0) {
Remove-TempContainer
Write-Err "Dump file '$DumpFile' is empty. Something went wrong with pg_dump."
}
$dumpContent = (Get-Content $DumpFile -TotalCount 5) -join "`n"
if ($dumpContent -notmatch "PostgreSQL database dump") {
Remove-TempContainer
Write-Err "Dump file does not contain a valid PostgreSQL dump header - the file may be corrupt."
}
$dumpLines = (Get-Content $DumpFile | Measure-Object -Line).Lines
if ($dumpLines -lt 10) {
Remove-TempContainer
Write-Err "Dump has only $dumpLines lines - suspiciously small. Aborting."
}
$dumpSize = "{0:N1} MB" -f ((Get-Item $DumpFile).Length / 1MB)
Write-Ok "Dump complete: $dumpSize ($dumpLines lines) -> $DumpFile"
Write-Info "Stopping temporary PostgreSQL 14 container..."
Invoke-NativeSafe { docker stop $TempContainer *>$null } | Out-Null
Invoke-NativeSafe { docker rm $TempContainer *>$null } | Out-Null
Write-Ok "Temporary container removed."
# ── Step 3: Recover SECRET_KEY ───────────────────────────────────────────────
Write-Step "3" "Recovering SECRET_KEY"
$recoveredKey = ""
$keyCheck = Invoke-NativeSafe { docker run --rm -v "${OldVolume}:/data" alpine sh -c 'test -f /data/.secret_key && cat /data/.secret_key' 2>$null }
if ($LASTEXITCODE -eq 0 -and $keyCheck) {
$recoveredKey = $keyCheck.Trim()
Write-Ok "Recovered SECRET_KEY from '$OldVolume'."
} else {
Write-Warn "No SECRET_KEY file found at /data/.secret_key in '$OldVolume'."
Write-Warn "This means the all-in-one container was launched with SECRET_KEY set as an explicit env var."
if ($Yes) {
$bytes = New-Object byte[] 32
$rng = [System.Security.Cryptography.RNGCryptoServiceProvider]::new()
$rng.GetBytes($bytes)
$rng.Dispose()
$recoveredKey = [Convert]::ToBase64String($bytes)
Write-Warn "Non-interactive mode: generated a new SECRET_KEY automatically."
Write-Warn "All active browser sessions will be logged out after migration."
Write-Warn "To restore your original key, update SECRET_KEY in .\surfsense\.env afterwards."
} else {
Write-Warn "Enter the SECRET_KEY from your old container's environment"
$recoveredKey = Read-Host "[SurfSense] (press Enter to generate a new one - existing sessions will be invalidated)"
if (-not $recoveredKey) {
$bytes = New-Object byte[] 32
$rng = [System.Security.Cryptography.RNGCryptoServiceProvider]::new()
$rng.GetBytes($bytes)
$rng.Dispose()
$recoveredKey = [Convert]::ToBase64String($bytes)
Write-Warn "Generated a new SECRET_KEY. All active browser sessions will be logged out after migration."
}
}
}
Set-Content -Path $KeyFile -Value $recoveredKey -NoNewline
Write-Ok "SECRET_KEY saved to $KeyFile"
# ── Done ─────────────────────────────────────────────────────────────────────
Write-Host ""
Write-Host ("=" * 62) -ForegroundColor Green
Write-Host " Data extraction complete!" -ForegroundColor Green
Write-Host ("=" * 62) -ForegroundColor Green
Write-Host ""
Write-Ok "Dump file : $DumpFile ($dumpSize)"
Write-Ok "Secret key: $KeyFile"
Write-Host ""
Write-Info "Next step - run install.ps1 from this same directory:"
Write-Host ""
Write-Host " irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.ps1 | iex" -ForegroundColor Cyan
Write-Host ""
Write-Info "install.ps1 will detect the dump, restore your data into PostgreSQL 17,"
Write-Info "and start the full SurfSense stack automatically."
Write-Host ""
Write-Warn "Keep both files until you have verified the migration:"
Write-Warn " $DumpFile"
Write-Warn " $KeyFile"
Write-Warn "Full log saved to: $LogFile"
Write-Host ""
Log "Migration extraction completed successfully at $(Get-Date)"

View file

@ -0,0 +1,335 @@
#!/usr/bin/env bash
# =============================================================================
# SurfSense — Database Migration Script
#
# Extracts data from the legacy all-in-one surfsense-data volume (PostgreSQL 14)
# and saves it as a SQL dump + SECRET_KEY file ready for install.sh to restore.
#
# Usage:
# bash migrate-database.sh [options]
#
# Options:
# --db-user USER Old PostgreSQL username (default: surfsense)
# --db-password PASS Old PostgreSQL password (default: surfsense)
# --db-name NAME Old PostgreSQL database (default: surfsense)
# --yes / -y Skip all confirmation prompts
# --help / -h Show this help
#
# Prerequisites:
# - Docker installed and running
# - The legacy surfsense-data volume must exist
# - ~500 MB free disk space for the dump file
#
# What this script does:
# 1. Stops any container using surfsense-data (to prevent corruption)
# 2. Starts a temporary PG14 container against the old volume
# 3. Dumps the database to ./surfsense_migration_backup.sql
# 4. Recovers the SECRET_KEY to ./surfsense_migration_secret.key
# 5. Exits — leaving installation to install.sh
#
# What this script does NOT do:
# - Delete the original surfsense-data volume (do this manually after verifying)
# - Install the new SurfSense stack (install.sh handles that automatically)
#
# Note:
# install.sh downloads and runs this script automatically when it detects the
# legacy surfsense-data volume. You only need to run this script manually if
# you have custom database credentials (--db-user / --db-password / --db-name)
# or if the automatic migration inside install.sh fails at the extraction step.
# =============================================================================
set -euo pipefail
# ── Colours ──────────────────────────────────────────────────────────────────
CYAN='\033[1;36m'
YELLOW='\033[1;33m'
GREEN='\033[0;32m'
RED='\033[0;31m'
BOLD='\033[1m'
NC='\033[0m'
# ── Logging — tee everything to a log file ───────────────────────────────────
LOG_FILE="./surfsense-migration.log"
exec > >(tee -a "${LOG_FILE}") 2>&1
# ── Output helpers ────────────────────────────────────────────────────────────
info() { printf "${CYAN}[SurfSense]${NC} %s\n" "$1"; }
success() { printf "${GREEN}[SurfSense]${NC} %s\n" "$1"; }
warn() { printf "${YELLOW}[SurfSense]${NC} %s\n" "$1"; }
error() { printf "${RED}[SurfSense]${NC} ERROR: %s\n" "$1" >&2; exit 1; }
step() { printf "\n${BOLD}${CYAN}── Step %s: %s${NC}\n" "$1" "$2"; }
# ── Constants ─────────────────────────────────────────────────────────────────
OLD_VOLUME="surfsense-data"
TEMP_CONTAINER="surfsense-pg14-migration"
DUMP_FILE="./surfsense_migration_backup.sql"
KEY_FILE="./surfsense_migration_secret.key"
PG14_IMAGE="pgvector/pgvector:pg14"
# ── Defaults ──────────────────────────────────────────────────────────────────
OLD_DB_USER="surfsense"
OLD_DB_PASSWORD="surfsense"
OLD_DB_NAME="surfsense"
AUTO_YES=false
# ── Argument parsing ──────────────────────────────────────────────────────────
while [[ $# -gt 0 ]]; do
case "$1" in
--db-user) OLD_DB_USER="$2"; shift 2 ;;
--db-password) OLD_DB_PASSWORD="$2"; shift 2 ;;
--db-name) OLD_DB_NAME="$2"; shift 2 ;;
--yes|-y) AUTO_YES=true; shift ;;
--help|-h)
grep '^#' "$0" | grep -v '^#!/' | sed 's/^# \{0,1\}//'
exit 0
;;
*) error "Unknown option: $1 — run with --help for usage." ;;
esac
done
# ── Confirmation helper ───────────────────────────────────────────────────────
confirm() {
if $AUTO_YES; then return 0; fi
printf "${YELLOW}[SurfSense]${NC} %s [y/N] " "$1"
read -r reply
[[ "$reply" =~ ^[Yy]$ ]] || { warn "Aborted."; exit 0; }
}
# ── Cleanup trap — always remove the temp container ──────────────────────────
cleanup() {
local exit_code=$?
if docker ps -a --format '{{.Names}}' 2>/dev/null < /dev/null | grep -q "^${TEMP_CONTAINER}$"; then
info "Cleaning up temporary container '${TEMP_CONTAINER}'..."
docker stop "${TEMP_CONTAINER}" >/dev/null 2>&1 < /dev/null || true
docker rm "${TEMP_CONTAINER}" >/dev/null 2>&1 < /dev/null || true
fi
if [[ $exit_code -ne 0 ]]; then
printf "\n${RED}[SurfSense]${NC} Migration data extraction failed (exit code %s).\n" "${exit_code}" >&2
printf "${RED}[SurfSense]${NC} Full log: %s\n" "${LOG_FILE}" >&2
printf "${YELLOW}[SurfSense]${NC} Your original data in '${OLD_VOLUME}' is untouched.\n" >&2
fi
}
trap cleanup EXIT
# ── Wait-for-postgres helper ──────────────────────────────────────────────────
wait_for_pg() {
local container="$1"
local user="$2"
local label="${3:-PostgreSQL}"
local max_attempts=45
local attempt=0
info "Waiting for ${label} to accept connections..."
until docker exec "${container}" pg_isready -U "${user}" -q 2>/dev/null < /dev/null; do
attempt=$((attempt + 1))
if [[ $attempt -ge $max_attempts ]]; then
error "${label} did not become ready after $((max_attempts * 2)) seconds. Check: docker logs ${container}"
fi
printf "."
sleep 2
done
printf "\n"
success "${label} is ready."
}
info "Migrating data from legacy database (PostgreSQL 14 → 17)"
# ── Step 0: Pre-flight checks ─────────────────────────────────────────────────
step "0" "Pre-flight checks"
# Docker CLI
command -v docker >/dev/null 2>&1 \
|| error "Docker is not installed. Install it at: https://docs.docker.com/get-docker/"
# Docker daemon
docker info >/dev/null 2>&1 < /dev/null \
|| error "Docker daemon is not running. Please start Docker and try again."
# Old volume must exist
docker volume ls --format '{{.Name}}' < /dev/null | grep -q "^${OLD_VOLUME}$" \
|| error "Legacy volume '${OLD_VOLUME}' not found.\n Are you sure you ran the old all-in-one SurfSense container?"
success "Found legacy volume: ${OLD_VOLUME}"
# Detect and stop any container currently using the old volume
# (mounting a live PG volume into a second container causes the new container's
# entrypoint to chown the data files, breaking the running container's access)
OLD_CONTAINER=$(docker ps --filter "volume=${OLD_VOLUME}" --format '{{.Names}}' < /dev/null | head -n1 || true)
if [[ -n "${OLD_CONTAINER}" ]]; then
warn "Container '${OLD_CONTAINER}' is running and using the '${OLD_VOLUME}' volume."
warn "It must be stopped before migration to prevent data file corruption."
confirm "Stop '${OLD_CONTAINER}' now and proceed with data extraction?"
docker stop "${OLD_CONTAINER}" >/dev/null 2>&1 < /dev/null \
|| error "Failed to stop '${OLD_CONTAINER}'. Try: docker stop ${OLD_CONTAINER}"
success "Container '${OLD_CONTAINER}' stopped."
fi
# Bail out if a dump already exists — don't overwrite a previous successful run
if [[ -f "${DUMP_FILE}" ]]; then
warn "Dump file '${DUMP_FILE}' already exists."
warn "If a previous extraction succeeded, just run install.sh now."
warn "To re-extract, remove the file first: rm ${DUMP_FILE}"
error "Aborting to avoid overwriting an existing dump."
fi
# Clean up any stale temp container from a previous failed run
if docker ps -a --format '{{.Names}}' < /dev/null | grep -q "^${TEMP_CONTAINER}$"; then
warn "Stale migration container '${TEMP_CONTAINER}' found — removing it."
docker stop "${TEMP_CONTAINER}" >/dev/null 2>&1 < /dev/null || true
docker rm "${TEMP_CONTAINER}" >/dev/null 2>&1 < /dev/null || true
fi
# Disk space (warn if < 500 MB free)
if command -v df >/dev/null 2>&1; then
FREE_KB=$(df -k . | awk 'NR==2 {print $4}')
FREE_MB=$(( FREE_KB / 1024 ))
if [[ $FREE_MB -lt 500 ]]; then
warn "Low disk space: ${FREE_MB} MB free. At least 500 MB recommended for the dump."
confirm "Continue anyway?"
else
success "Disk space: ${FREE_MB} MB free."
fi
fi
success "All pre-flight checks passed."
# ── Confirmation prompt ───────────────────────────────────────────────────────
printf "\n${BOLD}Extraction plan:${NC}\n"
printf " Source volume : ${YELLOW}%s${NC} (PG14 data at /data/postgres)\n" "${OLD_VOLUME}"
printf " Old credentials : user=${YELLOW}%s${NC} db=${YELLOW}%s${NC}\n" "${OLD_DB_USER}" "${OLD_DB_NAME}"
printf " Dump saved to : ${YELLOW}%s${NC}\n" "${DUMP_FILE}"
printf " SECRET_KEY to : ${YELLOW}%s${NC}\n" "${KEY_FILE}"
printf " Log file : ${YELLOW}%s${NC}\n\n" "${LOG_FILE}"
confirm "Start data extraction? (Your original data will not be deleted or modified.)"
# ── Step 1: Start temporary PostgreSQL 14 container ──────────────────────────
step "1" "Starting temporary PostgreSQL 14 container"
info "Pulling ${PG14_IMAGE}..."
docker pull "${PG14_IMAGE}" >/dev/null 2>&1 < /dev/null \
|| warn "Could not pull ${PG14_IMAGE} — using cached image if available."
# Detect the UID that owns the existing data files and run the temp container
# as that user. This prevents the official postgres image entrypoint from
# running as root and doing `chown -R postgres /data/postgres`, which would
# re-own the files to UID 999 and break any subsequent access by the original
# container's postgres process (which may run as a different UID).
DATA_UID=$(docker run --rm -v "${OLD_VOLUME}:/data" alpine \
stat -c '%u' /data/postgres 2>/dev/null < /dev/null || echo "")
if [[ -z "${DATA_UID}" || "${DATA_UID}" == "0" ]]; then
warn "Could not detect data directory UID — falling back to default (may chown files)."
USER_FLAG=""
else
info "Data directory owned by UID ${DATA_UID} — starting temp container as that user."
USER_FLAG="--user ${DATA_UID}"
fi
docker run -d \
--name "${TEMP_CONTAINER}" \
-v "${OLD_VOLUME}:/data" \
-e PGDATA=/data/postgres \
-e POSTGRES_USER="${OLD_DB_USER}" \
-e POSTGRES_PASSWORD="${OLD_DB_PASSWORD}" \
-e POSTGRES_DB="${OLD_DB_NAME}" \
${USER_FLAG} \
"${PG14_IMAGE}" >/dev/null < /dev/null
success "Temporary container '${TEMP_CONTAINER}' started."
wait_for_pg "${TEMP_CONTAINER}" "${OLD_DB_USER}" "PostgreSQL 14"
# ── Step 2: Dump the database ─────────────────────────────────────────────────
step "2" "Dumping PostgreSQL 14 database"
info "Running pg_dump — this may take a while for large databases..."
if ! docker exec \
-e PGPASSWORD="${OLD_DB_PASSWORD}" \
"${TEMP_CONTAINER}" \
pg_dump -U "${OLD_DB_USER}" --no-password "${OLD_DB_NAME}" \
> "${DUMP_FILE}" 2>/tmp/surfsense_pgdump_err < /dev/null; then
cat /tmp/surfsense_pgdump_err >&2
error "pg_dump failed. See above for details."
fi
# Validate: non-empty file
[[ -s "${DUMP_FILE}" ]] \
|| error "Dump file '${DUMP_FILE}' is empty. Something went wrong with pg_dump."
# Validate: looks like a real PG dump
grep -q "PostgreSQL database dump" "${DUMP_FILE}" \
|| error "Dump file does not contain a valid PostgreSQL dump header — the file may be corrupt."
# Validate: sanity-check line count
DUMP_LINES=$(wc -l < "${DUMP_FILE}" | tr -d ' ')
[[ $DUMP_LINES -ge 10 ]] \
|| error "Dump has only ${DUMP_LINES} lines — suspiciously small. Aborting."
DUMP_SIZE=$(du -sh "${DUMP_FILE}" 2>/dev/null | cut -f1)
success "Dump complete: ${DUMP_SIZE} (${DUMP_LINES} lines) → ${DUMP_FILE}"
# Stop the temp container (trap will also handle it on unexpected exit)
info "Stopping temporary PostgreSQL 14 container..."
docker stop "${TEMP_CONTAINER}" >/dev/null 2>&1 < /dev/null || true
docker rm "${TEMP_CONTAINER}" >/dev/null 2>&1 < /dev/null || true
success "Temporary container removed."
# ── Step 3: Recover SECRET_KEY ────────────────────────────────────────────────
step "3" "Recovering SECRET_KEY"
RECOVERED_KEY=""
if docker run --rm -v "${OLD_VOLUME}:/data" alpine \
sh -c 'test -f /data/.secret_key && cat /data/.secret_key' \
2>/dev/null < /dev/null | grep -q .; then
RECOVERED_KEY=$(
docker run --rm -v "${OLD_VOLUME}:/data" alpine \
cat /data/.secret_key 2>/dev/null < /dev/null | tr -d '[:space:]'
)
success "Recovered SECRET_KEY from '${OLD_VOLUME}'."
else
warn "No SECRET_KEY file found at /data/.secret_key in '${OLD_VOLUME}'."
warn "This means the all-in-one container was launched with SECRET_KEY set as an explicit env var."
if $AUTO_YES; then
# Non-interactive (called from install.sh) — auto-generate rather than hanging on read
RECOVERED_KEY=$(openssl rand -base64 32 2>/dev/null \
|| head -c 32 /dev/urandom | base64 | tr -d '\n')
warn "Non-interactive mode: generated a new SECRET_KEY automatically."
warn "All active browser sessions will be logged out after migration."
warn "To restore your original key, update SECRET_KEY in ./surfsense/.env afterwards."
else
printf "${YELLOW}[SurfSense]${NC} Enter the SECRET_KEY from your old container's environment\n"
printf "${YELLOW}[SurfSense]${NC} (press Enter to generate a new one — existing sessions will be invalidated): "
read -r RECOVERED_KEY
if [[ -z "${RECOVERED_KEY}" ]]; then
RECOVERED_KEY=$(openssl rand -base64 32 2>/dev/null \
|| head -c 32 /dev/urandom | base64 | tr -d '\n')
warn "Generated a new SECRET_KEY. All active browser sessions will be logged out after migration."
fi
fi
fi
# Save SECRET_KEY to a file for install.sh to pick up
printf '%s' "${RECOVERED_KEY}" > "${KEY_FILE}"
success "SECRET_KEY saved to ${KEY_FILE}"
# ── Done ──────────────────────────────────────────────────────────────────────
printf "\n${GREEN}${BOLD}"
printf "══════════════════════════════════════════════════════════════\n"
printf " Data extraction complete!\n"
printf "══════════════════════════════════════════════════════════════\n"
printf "${NC}\n"
success "Dump file : ${DUMP_FILE} (${DUMP_SIZE})"
success "Secret key: ${KEY_FILE}"
printf "\n"
info "Next step — run install.sh from this same directory:"
printf "\n"
printf "${CYAN} curl -fsSL https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.sh | bash${NC}\n"
printf "\n"
info "install.sh will detect the dump, restore your data into PostgreSQL 17,"
info "and start the full SurfSense stack automatically."
printf "\n"
warn "Keep both files until you have verified the migration:"
warn " ${DUMP_FILE}"
warn " ${KEY_FILE}"
warn "Full log saved to: ${LOG_FILE}"
printf "\n"

View file

@ -1,243 +0,0 @@
#!/bin/bash
set -e
echo "==========================================="
echo " 🏄 SurfSense All-in-One Container"
echo "==========================================="
# Create log directory
mkdir -p /var/log/supervisor
# ================================================
# Ensure data directory exists
# ================================================
mkdir -p /data
# ================================================
# Generate SECRET_KEY if not provided
# ================================================
if [ -z "$SECRET_KEY" ]; then
# Generate a random secret key and persist it
if [ -f /data/.secret_key ]; then
export SECRET_KEY=$(cat /data/.secret_key)
echo "✅ Using existing SECRET_KEY from persistent storage"
else
export SECRET_KEY=$(python3 -c "import secrets; print(secrets.token_urlsafe(32))")
echo "$SECRET_KEY" > /data/.secret_key
chmod 600 /data/.secret_key
echo "✅ Generated new SECRET_KEY (saved for persistence)"
fi
fi
# ================================================
# Set default TTS/STT services if not provided
# ================================================
if [ -z "$TTS_SERVICE" ]; then
export TTS_SERVICE="local/kokoro"
echo "✅ Using default TTS_SERVICE: local/kokoro"
fi
if [ -z "$STT_SERVICE" ]; then
export STT_SERVICE="local/base"
echo "✅ Using default STT_SERVICE: local/base"
fi
# ================================================
# Set Electric SQL configuration
# ================================================
export ELECTRIC_DB_USER="${ELECTRIC_DB_USER:-electric}"
export ELECTRIC_DB_PASSWORD="${ELECTRIC_DB_PASSWORD:-electric_password}"
if [ -z "$ELECTRIC_DATABASE_URL" ]; then
export ELECTRIC_DATABASE_URL="postgresql://${ELECTRIC_DB_USER}:${ELECTRIC_DB_PASSWORD}@localhost:5432/${POSTGRES_DB:-surfsense}?sslmode=disable"
echo "✅ Electric SQL URL configured dynamically"
else
# Ensure sslmode=disable is in the URL if not already present
if [[ "$ELECTRIC_DATABASE_URL" != *"sslmode="* ]]; then
# Add sslmode=disable (handle both cases: with or without existing query params)
if [[ "$ELECTRIC_DATABASE_URL" == *"?"* ]]; then
export ELECTRIC_DATABASE_URL="${ELECTRIC_DATABASE_URL}&sslmode=disable"
else
export ELECTRIC_DATABASE_URL="${ELECTRIC_DATABASE_URL}?sslmode=disable"
fi
fi
echo "✅ Electric SQL URL configured from environment"
fi
# Set Electric SQL port
export ELECTRIC_PORT="${ELECTRIC_PORT:-5133}"
export PORT="${ELECTRIC_PORT}"
# ================================================
# Initialize PostgreSQL if needed
# ================================================
if [ ! -f /data/postgres/PG_VERSION ]; then
echo "📦 Initializing PostgreSQL database..."
# Initialize PostgreSQL data directory
chown -R postgres:postgres /data/postgres
chmod 700 /data/postgres
# Initialize with UTF8 encoding (required for proper text handling)
su - postgres -c "/usr/lib/postgresql/14/bin/initdb -D /data/postgres --encoding=UTF8 --locale=C.UTF-8"
# Configure PostgreSQL for connections
echo "host all all 0.0.0.0/0 md5" >> /data/postgres/pg_hba.conf
echo "local all all trust" >> /data/postgres/pg_hba.conf
echo "listen_addresses='*'" >> /data/postgres/postgresql.conf
# Enable logical replication for Electric SQL
echo "wal_level = logical" >> /data/postgres/postgresql.conf
echo "max_replication_slots = 10" >> /data/postgres/postgresql.conf
echo "max_wal_senders = 10" >> /data/postgres/postgresql.conf
# Start PostgreSQL temporarily to create database and user
su - postgres -c "/usr/lib/postgresql/14/bin/pg_ctl -D /data/postgres -l /tmp/postgres_init.log start"
# Wait for PostgreSQL to be ready
sleep 5
# Create user and database
su - postgres -c "psql -c \"CREATE USER ${POSTGRES_USER:-surfsense} WITH PASSWORD '${POSTGRES_PASSWORD:-surfsense}' SUPERUSER;\""
su - postgres -c "psql -c \"CREATE DATABASE ${POSTGRES_DB:-surfsense} OWNER ${POSTGRES_USER:-surfsense};\""
# Enable pgvector extension
su - postgres -c "psql -d ${POSTGRES_DB:-surfsense} -c 'CREATE EXTENSION IF NOT EXISTS vector;'"
# Create Electric SQL replication user (idempotent - uses IF NOT EXISTS)
echo "📡 Creating Electric SQL replication user..."
su - postgres -c "psql -d ${POSTGRES_DB:-surfsense} <<-EOSQL
DO \\\$\\\$
BEGIN
IF NOT EXISTS (SELECT FROM pg_user WHERE usename = '${ELECTRIC_DB_USER}') THEN
CREATE USER ${ELECTRIC_DB_USER} WITH REPLICATION PASSWORD '${ELECTRIC_DB_PASSWORD}';
END IF;
END
\\\$\\\$;
GRANT CONNECT ON DATABASE ${POSTGRES_DB:-surfsense} TO ${ELECTRIC_DB_USER};
GRANT USAGE ON SCHEMA public TO ${ELECTRIC_DB_USER};
GRANT SELECT ON ALL TABLES IN SCHEMA public TO ${ELECTRIC_DB_USER};
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO ${ELECTRIC_DB_USER};
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO ${ELECTRIC_DB_USER};
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO ${ELECTRIC_DB_USER};
-- Create the publication for Electric SQL (if not exists)
DO \\\$\\\$
BEGIN
IF NOT EXISTS (SELECT FROM pg_publication WHERE pubname = 'electric_publication_default') THEN
CREATE PUBLICATION electric_publication_default;
END IF;
END
\\\$\\\$;
EOSQL"
echo "✅ Electric SQL user '${ELECTRIC_DB_USER}' created"
# Stop temporary PostgreSQL
su - postgres -c "/usr/lib/postgresql/14/bin/pg_ctl -D /data/postgres stop"
echo "✅ PostgreSQL initialized successfully"
else
echo "✅ PostgreSQL data directory already exists"
fi
# ================================================
# Initialize Redis data directory
# ================================================
mkdir -p /data/redis
chmod 755 /data/redis
echo "✅ Redis data directory ready"
# ================================================
# Copy frontend build to runtime location
# ================================================
if [ -d /app/frontend/.next/standalone ]; then
cp -r /app/frontend/.next/standalone/* /app/frontend/ 2>/dev/null || true
cp -r /app/frontend/.next/static /app/frontend/.next/static 2>/dev/null || true
fi
# ================================================
# Runtime Environment Variable Replacement
# ================================================
# Next.js NEXT_PUBLIC_* vars are baked in at build time.
# This replaces placeholder values with actual runtime env vars.
echo "🔧 Applying runtime environment configuration..."
# Set defaults if not provided
NEXT_PUBLIC_FASTAPI_BACKEND_URL="${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000}"
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE="${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL}"
NEXT_PUBLIC_ETL_SERVICE="${NEXT_PUBLIC_ETL_SERVICE:-DOCLING}"
NEXT_PUBLIC_ELECTRIC_URL="${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:5133}"
NEXT_PUBLIC_ELECTRIC_AUTH_MODE="${NEXT_PUBLIC_ELECTRIC_AUTH_MODE:-insecure}"
NEXT_PUBLIC_DEPLOYMENT_MODE="${NEXT_PUBLIC_DEPLOYMENT_MODE:-self-hosted}"
# Replace placeholders in all JS files
find /app/frontend -type f \( -name "*.js" -o -name "*.json" \) -exec sed -i \
-e "s|__NEXT_PUBLIC_FASTAPI_BACKEND_URL__|${NEXT_PUBLIC_FASTAPI_BACKEND_URL}|g" \
-e "s|__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__|${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE}|g" \
-e "s|__NEXT_PUBLIC_ETL_SERVICE__|${NEXT_PUBLIC_ETL_SERVICE}|g" \
-e "s|__NEXT_PUBLIC_ELECTRIC_URL__|${NEXT_PUBLIC_ELECTRIC_URL}|g" \
-e "s|__NEXT_PUBLIC_ELECTRIC_AUTH_MODE__|${NEXT_PUBLIC_ELECTRIC_AUTH_MODE}|g" \
-e "s|__NEXT_PUBLIC_DEPLOYMENT_MODE__|${NEXT_PUBLIC_DEPLOYMENT_MODE}|g" \
{} +
echo "✅ Environment configuration applied"
echo " Backend URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL}"
echo " Auth Type: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE}"
echo " ETL Service: ${NEXT_PUBLIC_ETL_SERVICE}"
echo " Electric URL: ${NEXT_PUBLIC_ELECTRIC_URL}"
echo " Deployment Mode: ${NEXT_PUBLIC_DEPLOYMENT_MODE}"
# ================================================
# Run database migrations
# ================================================
run_migrations() {
echo "🔄 Running database migrations..."
# Start PostgreSQL temporarily for migrations
su - postgres -c "/usr/lib/postgresql/14/bin/pg_ctl -D /data/postgres -l /tmp/postgres_migrate.log start"
sleep 5
# Start Redis temporarily for migrations (some might need it)
redis-server --dir /data/redis --daemonize yes
sleep 2
# Run alembic migrations
cd /app/backend
alembic upgrade head || echo "⚠️ Migrations may have already been applied"
# Stop temporary services
redis-cli shutdown || true
su - postgres -c "/usr/lib/postgresql/14/bin/pg_ctl -D /data/postgres stop"
echo "✅ Database migrations complete"
}
# Always run migrations on startup - alembic upgrade head is safe to run
# every time. It only applies pending migrations (never re-runs applied ones,
# never calls downgrade). This ensures updates are applied automatically.
run_migrations
# ================================================
# Environment Variables Info
# ================================================
echo ""
echo "==========================================="
echo " 📋 Configuration"
echo "==========================================="
echo " Frontend URL: http://localhost:3000"
echo " Backend API: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL}"
echo " API Docs: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL}/docs"
echo " Electric URL: ${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:5133}"
echo " Auth Type: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE}"
echo " ETL Service: ${NEXT_PUBLIC_ETL_SERVICE}"
echo " TTS Service: ${TTS_SERVICE}"
echo " STT Service: ${STT_SERVICE}"
echo "==========================================="
echo ""
# ================================================
# Start Supervisor (manages all services)
# ================================================
echo "🚀 Starting all services..."
exec /usr/local/bin/supervisord -c /etc/supervisor/conf.d/surfsense.conf

View file

@ -1,77 +0,0 @@
#!/bin/bash
# PostgreSQL initialization script for SurfSense
# This script is called during container startup if the database needs initialization
set -e
PGDATA=${PGDATA:-/data/postgres}
POSTGRES_USER=${POSTGRES_USER:-surfsense}
POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-surfsense}
POSTGRES_DB=${POSTGRES_DB:-surfsense}
# Electric SQL user credentials (configurable)
ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
echo "Initializing PostgreSQL..."
# Check if PostgreSQL is already initialized
if [ -f "$PGDATA/PG_VERSION" ]; then
echo "PostgreSQL data directory already exists. Skipping initialization."
exit 0
fi
# Initialize the database cluster
/usr/lib/postgresql/14/bin/initdb -D "$PGDATA" --username=postgres
# Configure PostgreSQL
cat >> "$PGDATA/postgresql.conf" << EOF
listen_addresses = '*'
max_connections = 200
shared_buffers = 256MB
# Enable logical replication (required for Electric SQL)
wal_level = logical
max_replication_slots = 10
max_wal_senders = 10
# Performance settings
checkpoint_timeout = 10min
max_wal_size = 1GB
min_wal_size = 80MB
EOF
cat >> "$PGDATA/pg_hba.conf" << EOF
# Allow connections from anywhere with password
host all all 0.0.0.0/0 md5
host all all ::0/0 md5
EOF
# Start PostgreSQL temporarily
/usr/lib/postgresql/14/bin/pg_ctl -D "$PGDATA" -l /tmp/postgres_init.log start
# Wait for PostgreSQL to start
sleep 3
# Create user and database
psql -U postgres << EOF
CREATE USER $POSTGRES_USER WITH PASSWORD '$POSTGRES_PASSWORD' SUPERUSER;
CREATE DATABASE $POSTGRES_DB OWNER $POSTGRES_USER;
\c $POSTGRES_DB
CREATE EXTENSION IF NOT EXISTS vector;
-- Create Electric SQL replication user
CREATE USER $ELECTRIC_DB_USER WITH REPLICATION PASSWORD '$ELECTRIC_DB_PASSWORD';
GRANT CONNECT ON DATABASE $POSTGRES_DB TO $ELECTRIC_DB_USER;
GRANT USAGE ON SCHEMA public TO $ELECTRIC_DB_USER;
GRANT SELECT ON ALL TABLES IN SCHEMA public TO $ELECTRIC_DB_USER;
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO $ELECTRIC_DB_USER;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO $ELECTRIC_DB_USER;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO $ELECTRIC_DB_USER;
EOF
echo "PostgreSQL initialized successfully."
# Stop PostgreSQL (supervisor will start it)
/usr/lib/postgresql/14/bin/pg_ctl -D "$PGDATA" stop

View file

@ -1,121 +0,0 @@
[supervisord]
nodaemon=true
logfile=/dev/stdout
logfile_maxbytes=0
pidfile=/var/run/supervisord.pid
loglevel=info
user=root
[unix_http_server]
file=/var/run/supervisor.sock
chmod=0700
[rpcinterface:supervisor]
supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface
[supervisorctl]
serverurl=unix:///var/run/supervisor.sock
# PostgreSQL
[program:postgresql]
command=/usr/lib/postgresql/14/bin/postgres -D /data/postgres
user=postgres
autostart=true
autorestart=true
priority=10
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
stderr_logfile=/dev/stderr
stderr_logfile_maxbytes=0
environment=PGDATA="/data/postgres"
# Redis
[program:redis]
command=/usr/bin/redis-server --dir /data/redis --appendonly yes
autostart=true
autorestart=true
priority=20
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
stderr_logfile=/dev/stderr
stderr_logfile_maxbytes=0
# Backend API
[program:backend]
command=python main.py
directory=/app/backend
autostart=true
autorestart=true
priority=30
startsecs=10
startretries=3
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
stderr_logfile=/dev/stderr
stderr_logfile_maxbytes=0
environment=PYTHONPATH="/app/backend",UVICORN_LOOP="asyncio",UNSTRUCTURED_HAS_PATCHED_LOOP="1"
# Celery Worker
[program:celery-worker]
command=celery -A app.celery_app worker --loglevel=info --concurrency=2 --pool=solo --queues=surfsense,surfsense.connectors
directory=/app/backend
autostart=true
autorestart=true
priority=40
startsecs=15
startretries=3
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
stderr_logfile=/dev/stderr
stderr_logfile_maxbytes=0
environment=PYTHONPATH="/app/backend"
# Celery Beat (scheduler)
[program:celery-beat]
command=celery -A app.celery_app beat --loglevel=info
directory=/app/backend
autostart=true
autorestart=true
priority=50
startsecs=20
startretries=3
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
stderr_logfile=/dev/stderr
stderr_logfile_maxbytes=0
environment=PYTHONPATH="/app/backend"
# Electric SQL (real-time sync)
[program:electric]
command=/app/electric-release/bin/entrypoint start
autostart=true
autorestart=true
priority=25
startsecs=10
startretries=3
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
stderr_logfile=/dev/stderr
stderr_logfile_maxbytes=0
environment=DATABASE_URL="%(ENV_ELECTRIC_DATABASE_URL)s",ELECTRIC_INSECURE="%(ENV_ELECTRIC_INSECURE)s",ELECTRIC_WRITE_TO_PG_MODE="%(ENV_ELECTRIC_WRITE_TO_PG_MODE)s",RELEASE_COOKIE="surfsense_electric_cookie",PORT="%(ENV_ELECTRIC_PORT)s"
# Frontend
[program:frontend]
command=node server.js
directory=/app/frontend
autostart=true
autorestart=true
priority=60
startsecs=5
startretries=3
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
stderr_logfile=/dev/stderr
stderr_logfile_maxbytes=0
environment=NODE_ENV="production",PORT="3000",HOSTNAME="0.0.0.0"
# Process Groups
[group:surfsense]
programs=postgresql,redis,electric,backend,celery-worker,celery-beat,frontend
priority=999

View file

@ -167,37 +167,12 @@ LANGSMITH_ENDPOINT=https://api.smith.langchain.com
LANGSMITH_API_KEY=lsv2_pt_.....
LANGSMITH_PROJECT=surfsense
# Uvicorn Server Configuration
# Full documentation for Uvicorn options can be found at: https://www.uvicorn.org/#command-line-options
UVICORN_HOST="0.0.0.0"
UVICORN_PORT=8000
UVICORN_LOG_LEVEL=info
# OPTIONAL: Advanced Uvicorn Options (uncomment to use)
# UVICORN_PROXY_HEADERS=false
# UVICORN_FORWARDED_ALLOW_IPS="127.0.0.1"
# UVICORN_WORKERS=1
# UVICORN_ACCESS_LOG=true
# UVICORN_LOOP="auto"
# UVICORN_HTTP="auto"
# UVICORN_WS="auto"
# UVICORN_LIFESPAN="auto"
# UVICORN_LOG_CONFIG=""
# UVICORN_SERVER_HEADER=true
# UVICORN_DATE_HEADER=true
# UVICORN_LIMIT_CONCURRENCY=
# UVICORN_LIMIT_MAX_REQUESTS=
# UVICORN_TIMEOUT_KEEP_ALIVE=5
# UVICORN_TIMEOUT_NOTIFY=30
# UVICORN_SSL_KEYFILE=""
# UVICORN_SSL_CERTFILE=""
# UVICORN_SSL_KEYFILE_PASSWORD=""
# UVICORN_SSL_VERSION=""
# UVICORN_SSL_CERT_REQS=""
# UVICORN_SSL_CA_CERTS=""
# UVICORN_SSL_CIPHERS=""
# UVICORN_HEADERS=""
# UVICORN_USE_COLORS=true
# UVICORN_UDS=""
# UVICORN_FD=""
# UVICORN_ROOT_PATH=""
# Agent Specific Configuration
# Daytona Sandbox (secure cloud code execution for deep agent)
# Set DAYTONA_SANDBOX_ENABLED=TRUE to give the agent an isolated execute tool
DAYTONA_SANDBOX_ENABLED=TRUE
DAYTONA_API_KEY=dtn_asdasfasfafas
DAYTONA_API_URL=https://app.daytona.io/api
DAYTONA_TARGET=us
# Directory for locally-persisted sandbox files (after sandbox deletion)
SANDBOX_FILES_DIR=sandbox_files

View file

@ -6,6 +6,7 @@ __pycache__/
.flashrank_cache
surf_new_backend.egg-info/
podcasts/
sandbox_files/
temp_audio/
celerybeat-schedule*
celerybeat-schedule.*

View file

@ -88,6 +88,13 @@ ENV TMPDIR=/shared_tmp
ENV PYTHONPATH=/app
ENV UVICORN_LOOP=asyncio
# Tune glibc malloc to return freed memory to the OS more aggressively.
# Without these, Python's gc.collect() frees objects but the underlying
# C heap pages stay mapped (RSS never drops) due to sbrk fragmentation.
ENV MALLOC_MMAP_THRESHOLD_=65536
ENV MALLOC_TRIM_THRESHOLD_=131072
ENV MALLOC_MMAP_MAX_=65536
# SERVICE_ROLE controls which process this container runs:
# api FastAPI backend only (runs migrations on startup)
# worker Celery worker only

View file

@ -0,0 +1,46 @@
"""102_add_enable_summary_to_connectors
Revision ID: 102
Revises: 101
Create Date: 2026-02-26
Adds enable_summary boolean column to search_source_connectors.
Defaults to False for all existing and new connectors so LLM-based
summary generation is opt-in.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "102"
down_revision: str | None = "101"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
conn = op.get_bind()
existing_columns = [
col["name"] for col in sa.inspect(conn).get_columns("search_source_connectors")
]
if "enable_summary" not in existing_columns:
op.add_column(
"search_source_connectors",
sa.Column(
"enable_summary",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)
def downgrade() -> None:
op.drop_column("search_source_connectors", "enable_summary")

View file

@ -8,7 +8,7 @@ Creates notifications table and sets up Electric SQL replication
search_source_connectors, and documents tables.
NOTE: Electric SQL user creation is idempotent (uses IF NOT EXISTS).
- Docker deployments: user is pre-created by scripts/docker/init-electric-user.sh
- Docker deployments: user is pre-created by docker/scripts/init-electric-user.sh
- Local PostgreSQL: user is created here during migration
Both approaches are safe to run together without conflicts as this migraiton is idempotent
"""

View file

@ -6,10 +6,14 @@ with configurable tools via the tools registry and configurable prompts
via NewLLMConfig.
"""
import asyncio
import logging
import time
from collections.abc import Sequence
from typing import Any
from deepagents import create_deep_agent
from deepagents.backends.protocol import SandboxBackendProtocol
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
@ -24,6 +28,9 @@ from app.agents.new_chat.system_prompt import (
from app.agents.new_chat.tools.registry import build_tools_async
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
# =============================================================================
# Connector Type Mapping
@ -128,6 +135,7 @@ async def create_surfsense_deep_agent(
additional_tools: Sequence[BaseTool] | None = None,
firecrawl_api_key: str | None = None,
thread_visibility: ChatVisibility | None = None,
sandbox_backend: SandboxBackendProtocol | None = None,
):
"""
Create a SurfSense deep agent with configurable tools and prompts.
@ -167,6 +175,9 @@ async def create_surfsense_deep_agent(
These are always added regardless of enabled/disabled settings.
firecrawl_api_key: Optional Firecrawl API key for premium web scraping.
Falls back to Chromium/Trafilatura if not provided.
sandbox_backend: Optional sandbox backend (e.g. DaytonaSandbox) for
secure code execution. When provided, the agent gets an
isolated ``execute`` tool for running shell commands.
Returns:
CompiledStateGraph: The configured deep agent
@ -205,32 +216,41 @@ async def create_surfsense_deep_agent(
additional_tools=[my_custom_tool]
)
"""
_t_agent_total = time.perf_counter()
# Discover available connectors and document types for this search space
# This enables dynamic tool docstrings that inform the LLM about what's actually available
available_connectors: list[str] | None = None
available_document_types: list[str] | None = None
_t0 = time.perf_counter()
try:
# Get enabled search source connectors for this search space
connector_types = await connector_service.get_available_connectors(
search_space_id
)
if connector_types:
# Convert enum values to strings and also include mapped document types
available_connectors = _map_connectors_to_searchable_types(connector_types)
# Get document types that have at least one document indexed
available_document_types = await connector_service.get_available_document_types(
search_space_id
)
except Exception as e:
# Log but don't fail - fall back to all connectors if discovery fails
import logging
logging.warning(f"Failed to discover available connectors/document types: {e}")
_perf_log.info(
"[create_agent] Connector/doc-type discovery in %.3fs",
time.perf_counter() - _t0,
)
# Build dependencies dict for the tools registry
visibility = thread_visibility or ChatVisibility.PRIVATE
# Extract the model's context window so tools can size their output.
_model_profile = getattr(llm, "profile", None)
_max_input_tokens: int | None = (
_model_profile.get("max_input_tokens")
if isinstance(_model_profile, dict)
else None
)
dependencies = {
"search_space_id": search_space_id,
"db_session": db_session,
@ -241,6 +261,7 @@ async def create_surfsense_deep_agent(
"thread_visibility": visibility,
"available_connectors": available_connectors,
"available_document_types": available_document_types,
"max_input_tokens": _max_input_tokens,
}
# Disable Notion action tools if no Notion connector is configured
@ -269,35 +290,61 @@ async def create_surfsense_deep_agent(
modified_disabled_tools.extend(linear_tools)
# Build tools using the async registry (includes MCP tools)
_t0 = time.perf_counter()
tools = await build_tools_async(
dependencies=dependencies,
enabled_tools=enabled_tools,
disabled_tools=modified_disabled_tools,
additional_tools=list(additional_tools) if additional_tools else None,
)
_perf_log.info(
"[create_agent] build_tools_async in %.3fs (%d tools)",
time.perf_counter() - _t0,
len(tools),
)
# Build system prompt based on agent_config
_t0 = time.perf_counter()
_sandbox_enabled = sandbox_backend is not None
if agent_config is not None:
# Use configurable prompt with settings from NewLLMConfig
system_prompt = build_configurable_system_prompt(
custom_system_instructions=agent_config.system_instructions,
use_default_system_instructions=agent_config.use_default_system_instructions,
citations_enabled=agent_config.citations_enabled,
thread_visibility=thread_visibility,
sandbox_enabled=_sandbox_enabled,
)
else:
system_prompt = build_surfsense_system_prompt(
thread_visibility=thread_visibility,
sandbox_enabled=_sandbox_enabled,
)
_perf_log.info(
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
)
# Create the deep agent with system prompt and checkpointer
# Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent
agent = create_deep_agent(
# Build optional kwargs for the deep agent
deep_agent_kwargs: dict[str, Any] = {}
if sandbox_backend is not None:
deep_agent_kwargs["backend"] = sandbox_backend
_t0 = time.perf_counter()
agent = await asyncio.to_thread(
create_deep_agent,
model=llm,
tools=tools,
system_prompt=system_prompt,
context_schema=SurfSenseContextSchema,
checkpointer=checkpointer,
**deep_agent_kwargs,
)
_perf_log.info(
"[create_agent] Graph compiled (create_deep_agent) in %.3fs",
time.perf_counter() - _t0,
)
_perf_log.info(
"[create_agent] Total agent creation in %.3fs",
time.perf_counter() - _t_agent_total,
)
return agent

View file

@ -22,6 +22,7 @@ from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
LLMRouterService,
get_auto_mode_llm,
is_auto_mode,
)
@ -389,7 +390,7 @@ def create_chat_litellm_from_agent_config(
print("Error: Auto mode requested but LLM Router not initialized")
return None
try:
return ChatLiteLLMRouter()
return get_auto_mode_llm()
except Exception as e:
print(f"Error creating ChatLiteLLMRouter: {e}")
return None

View file

@ -0,0 +1,282 @@
"""
Daytona sandbox provider for SurfSense deep agent.
Manages the lifecycle of sandboxed code execution environments.
Each conversation thread gets its own isolated sandbox instance
via the Daytona cloud API, identified by labels.
Files created during a session are persisted to local storage before
the sandbox is deleted so they remain downloadable after cleanup.
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
import os
import shutil
from pathlib import Path
from daytona import (
CreateSandboxFromSnapshotParams,
Daytona,
DaytonaConfig,
SandboxState,
)
from daytona.common.errors import DaytonaError
from deepagents.backends.protocol import ExecuteResponse
from langchain_daytona import DaytonaSandbox
logger = logging.getLogger(__name__)
class _TimeoutAwareSandbox(DaytonaSandbox):
"""DaytonaSandbox subclass that accepts the per-command *timeout*
kwarg required by the deepagents middleware.
The upstream ``langchain-daytona`` ``execute()`` ignores timeout,
so deepagents raises *"This sandbox backend does not support
per-command timeout overrides"* on every first call. This thin
wrapper forwards the parameter to the Daytona SDK.
"""
def execute(self, command: str, *, timeout: int | None = None) -> ExecuteResponse:
t = timeout if timeout is not None else self._timeout
result = self._sandbox.process.exec(command, timeout=t)
return ExecuteResponse(
output=result.result,
exit_code=result.exit_code,
truncated=False,
)
async def aexecute(
self, command: str, *, timeout: int | None = None
) -> ExecuteResponse: # type: ignore[override]
return await asyncio.to_thread(self.execute, command, timeout=timeout)
_daytona_client: Daytona | None = None
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
_SANDBOX_CACHE_MAX_SIZE = 20
THREAD_LABEL_KEY = "surfsense_thread"
def is_sandbox_enabled() -> bool:
return os.environ.get("DAYTONA_SANDBOX_ENABLED", "FALSE").upper() == "TRUE"
def _get_client() -> Daytona:
global _daytona_client
if _daytona_client is None:
config = DaytonaConfig(
api_key=os.environ.get("DAYTONA_API_KEY", ""),
api_url=os.environ.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
target=os.environ.get("DAYTONA_TARGET", "us"),
)
_daytona_client = Daytona(config)
return _daytona_client
def _find_or_create(thread_id: str) -> _TimeoutAwareSandbox:
"""Find an existing sandbox for *thread_id*, or create a new one.
If an existing sandbox is found but is stopped/archived, it will be
restarted automatically before returning.
"""
client = _get_client()
labels = {THREAD_LABEL_KEY: thread_id}
try:
sandbox = client.find_one(labels=labels)
logger.info("Found existing sandbox %s (state=%s)", sandbox.id, sandbox.state)
if sandbox.state in (
SandboxState.STOPPED,
SandboxState.STOPPING,
SandboxState.ARCHIVED,
):
logger.info("Starting stopped sandbox %s", sandbox.id)
sandbox.start(timeout=60)
logger.info("Sandbox %s is now started", sandbox.id)
elif sandbox.state in (
SandboxState.ERROR,
SandboxState.BUILD_FAILED,
SandboxState.DESTROYED,
):
logger.warning(
"Sandbox %s in unrecoverable state %s — creating a new one",
sandbox.id,
sandbox.state,
)
sandbox = client.create(
CreateSandboxFromSnapshotParams(language="python", labels=labels)
)
logger.info("Created replacement sandbox: %s", sandbox.id)
elif sandbox.state != SandboxState.STARTED:
sandbox.wait_for_sandbox_start(timeout=60)
except Exception:
logger.info("No existing sandbox for thread %s — creating one", thread_id)
sandbox = client.create(
CreateSandboxFromSnapshotParams(language="python", labels=labels)
)
logger.info("Created new sandbox: %s", sandbox.id)
return _TimeoutAwareSandbox(sandbox=sandbox)
async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
"""Get or create a sandbox for a conversation thread.
Uses an in-process cache keyed by thread_id so subsequent messages
in the same conversation reuse the sandbox object without an API call.
Args:
thread_id: The conversation thread identifier.
Returns:
DaytonaSandbox connected to the sandbox.
"""
key = str(thread_id)
cached = _sandbox_cache.get(key)
if cached is not None:
logger.info("Reusing cached sandbox for thread %s", key)
return cached
sandbox = await asyncio.to_thread(_find_or_create, key)
_sandbox_cache[key] = sandbox
if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
oldest_key = next(iter(_sandbox_cache))
_sandbox_cache.pop(oldest_key, None)
logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key)
return sandbox
async def delete_sandbox(thread_id: int | str) -> None:
"""Delete the sandbox for a conversation thread."""
_sandbox_cache.pop(str(thread_id), None)
def _delete() -> None:
client = _get_client()
labels = {THREAD_LABEL_KEY: str(thread_id)}
try:
sandbox = client.find_one(labels=labels)
except DaytonaError:
logger.debug(
"No sandbox to delete for thread %s (already removed)", thread_id
)
return
try:
client.delete(sandbox)
logger.info("Sandbox deleted: %s", sandbox.id)
except Exception:
logger.warning(
"Failed to delete sandbox for thread %s",
thread_id,
exc_info=True,
)
await asyncio.to_thread(_delete)
# ---------------------------------------------------------------------------
# Local file persistence
# ---------------------------------------------------------------------------
def _get_sandbox_files_dir() -> Path:
return Path(os.environ.get("SANDBOX_FILES_DIR", "sandbox_files"))
def _local_path_for(thread_id: int | str, sandbox_path: str) -> Path:
"""Map a sandbox-internal absolute path to a local filesystem path."""
relative = sandbox_path.lstrip("/")
return _get_sandbox_files_dir() / str(thread_id) / relative
def get_local_sandbox_file(thread_id: int | str, sandbox_path: str) -> bytes | None:
"""Read a previously-persisted sandbox file from local storage.
Returns the file bytes, or *None* if the file does not exist locally.
"""
local = _local_path_for(thread_id, sandbox_path)
if local.is_file():
return local.read_bytes()
return None
def delete_local_sandbox_files(thread_id: int | str) -> None:
"""Remove all locally-persisted sandbox files for a thread."""
thread_dir = _get_sandbox_files_dir() / str(thread_id)
if thread_dir.is_dir():
shutil.rmtree(thread_dir, ignore_errors=True)
logger.info("Deleted local sandbox files for thread %s", thread_id)
async def persist_and_delete_sandbox(
thread_id: int | str,
sandbox_file_paths: list[str],
) -> None:
"""Download sandbox files to local storage, then delete the sandbox.
Each file in *sandbox_file_paths* is downloaded from the Daytona
sandbox and saved under ``{SANDBOX_FILES_DIR}/{thread_id}/``.
Per-file errors are logged but do **not** prevent the sandbox from
being deleted freeing Daytona storage is the priority.
"""
_sandbox_cache.pop(str(thread_id), None)
def _persist_and_delete() -> None:
client = _get_client()
labels = {THREAD_LABEL_KEY: str(thread_id)}
try:
sandbox = client.find_one(labels=labels)
except Exception:
logger.info(
"No sandbox found for thread %s — nothing to persist", thread_id
)
return
# Ensure the sandbox is running so we can download files
if sandbox.state != SandboxState.STARTED:
try:
sandbox.start(timeout=60)
except Exception:
logger.warning(
"Could not start sandbox %s for file download — deleting anyway",
sandbox.id,
exc_info=True,
)
with contextlib.suppress(Exception):
client.delete(sandbox)
return
for path in sandbox_file_paths:
try:
content: bytes = sandbox.fs.download_file(path)
local = _local_path_for(thread_id, path)
local.parent.mkdir(parents=True, exist_ok=True)
local.write_bytes(content)
logger.info("Persisted sandbox file %s%s", path, local)
except Exception:
logger.warning(
"Failed to persist sandbox file %s for thread %s",
path,
thread_id,
exc_info=True,
)
try:
client.delete(sandbox)
logger.info("Sandbox deleted after file persistence: %s", sandbox.id)
except Exception:
logger.warning(
"Failed to delete sandbox %s after persistence",
sandbox.id,
exc_info=True,
)
await asyncio.to_thread(_persist_and_delete)

View file

@ -645,6 +645,87 @@ However, from your video learning, it's important to note that asyncio is not su
</citation_instructions>
"""
# Sandbox / code execution instructions — appended when sandbox backend is enabled.
# Inspired by Claude's computer-use prompt, scoped to code execution & data analytics.
SANDBOX_EXECUTION_INSTRUCTIONS = """
<code_execution>
You have access to a secure, isolated Linux sandbox environment for running code and shell commands.
This gives you the `execute` tool alongside the standard filesystem tools (`ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`).
## CRITICAL — CODE-FIRST RULE
ALWAYS prefer executing code over giving a text-only response when the user's request involves ANY of the following:
- **Creating a chart, plot, graph, or visualization** Write Python code and generate the actual file. NEVER describe percentages or data in text and offer to "paste into Excel". Just produce the chart.
- **Data analysis, statistics, or computation** Write code to compute the answer. Do not do math by hand in text.
- **Generating or transforming files** (CSV, PDF, images, etc.) Write code to create the file.
- **Running, testing, or debugging code** Execute it in the sandbox.
This applies even when you first retrieve data from the knowledge base. After `search_knowledge_base` returns relevant data, **immediately proceed to write and execute code** if the user's request matches any of the categories above. Do NOT stop at a text summary and wait for the user to ask you to "use Python" — that extra round-trip is a poor experience.
Example (CORRECT):
User: "Create a pie chart of my benefits"
1. search_knowledge_base retrieve benefits data
2. Immediately execute Python code (matplotlib) to generate the pie chart
3. Return the downloadable file + brief description
Example (WRONG):
User: "Create a pie chart of my benefits"
1. search_knowledge_base retrieve benefits data
2. Print a text table with percentages and ask the user if they want a chart NEVER do this
## When to Use Code Execution
Use the sandbox when the task benefits from actually running code rather than just describing it:
- **Data analysis**: Load CSVs/JSON, compute statistics, filter/aggregate data, pivot tables
- **Visualization**: Generate charts and plots (matplotlib, plotly, seaborn)
- **Calculations**: Math, financial modeling, unit conversions, simulations
- **Code validation**: Run and test code snippets the user provides or asks about
- **File processing**: Parse, transform, or convert data files
- **Quick prototyping**: Demonstrate working code for the user's problem
- **Package exploration**: Install and test libraries the user is evaluating
## When NOT to Use Code Execution
Do not use the sandbox for:
- Answering factual questions from your own knowledge
- Summarizing or explaining concepts
- Simple formatting or text generation tasks
- Tasks that don't require running code to answer
## Package Management
- Use `pip install <package>` to install Python packages as needed
- Common data/analytics packages (pandas, numpy, matplotlib, scipy, scikit-learn) may need to be installed on first use
- Always verify a package installed successfully before using it
## Working Guidelines
- **Working directory**: The shell starts in the sandbox user's home directory (e.g. `/home/daytona`). Use **relative paths** or `/tmp/` for all files you create. NEVER write directly to `/home/` — that is the parent directory and is not writable. Use `pwd` if you need to discover the current working directory.
- **Iterative approach**: For complex tasks, break work into steps write code, run it, check output, refine
- **Error handling**: If code fails, read the error, fix the issue, and retry. Don't just report the error without attempting a fix.
- **Show results**: When generating plots or outputs, present the key findings directly in your response. For plots, save to a file and describe the results.
- **Be efficient**: Install packages once per session. Combine related commands when possible.
- **Large outputs**: If command output is very large, use `head`, `tail`, or save to a file and read selectively.
## Sharing Generated Files
When your code creates output files (images, CSVs, PDFs, etc.) in the sandbox:
- **Print the absolute path** at the end of your script so the user can download the file. Example: `print("SANDBOX_FILE: /tmp/chart.png")`
- **DO NOT call `display_image`** for files created inside the sandbox. Sandbox files are not accessible via public URLs, so `display_image` will always show "Image not available". The frontend automatically renders a download button from the `SANDBOX_FILE:` marker.
- You can output multiple files, one per line: `print("SANDBOX_FILE: /tmp/report.csv")`, `print("SANDBOX_FILE: /tmp/chart.png")`
- Always describe what the file contains in your response text so the user knows what they are downloading.
- IMPORTANT: Every `execute` call that saves a file MUST print the `SANDBOX_FILE: <path>` marker. Without it the user cannot download the file.
## Data Analytics Best Practices
When the user asks you to analyze data:
1. First, inspect the data structure (`head`, `shape`, `dtypes`, `describe()`)
2. Clean and validate before computing (handle nulls, check types)
3. Perform the analysis and present results clearly
4. Offer follow-up insights or visualizations when appropriate
</code_execution>
"""
# Anti-citation prompt - used when citations are disabled
# This explicitly tells the model NOT to include citations
SURFSENSE_NO_CITATION_INSTRUCTIONS = """
@ -670,6 +751,7 @@ Your goal is to provide helpful, informative answers in a clean, readable format
def build_surfsense_system_prompt(
today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
sandbox_enabled: bool = False,
) -> str:
"""
Build the SurfSense system prompt with default settings.
@ -678,10 +760,12 @@ def build_surfsense_system_prompt(
- Default system instructions
- Tools instructions (always included)
- Citation instructions enabled
- Sandbox execution instructions (when sandbox_enabled=True)
Args:
today: Optional datetime for today's date (defaults to current UTC date)
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).
Returns:
Complete system prompt string
@ -691,7 +775,13 @@ def build_surfsense_system_prompt(
system_instructions = _get_system_instructions(visibility, today)
tools_instructions = _get_tools_instructions(visibility)
citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS
return system_instructions + tools_instructions + citation_instructions
sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
return (
system_instructions
+ tools_instructions
+ citation_instructions
+ sandbox_instructions
)
def build_configurable_system_prompt(
@ -700,14 +790,16 @@ def build_configurable_system_prompt(
citations_enabled: bool = True,
today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
sandbox_enabled: bool = False,
) -> str:
"""
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
The prompt is composed of three parts:
The prompt is composed of up to four parts:
1. System Instructions - either custom or default SURFSENSE_SYSTEM_INSTRUCTIONS
2. Tools Instructions - always included (SURFSENSE_TOOLS_INSTRUCTIONS)
3. Citation Instructions - either SURFSENSE_CITATION_INSTRUCTIONS or SURFSENSE_NO_CITATION_INSTRUCTIONS
4. Sandbox Execution Instructions - when sandbox_enabled=True
Args:
custom_system_instructions: Custom system instructions to use. If empty/None and
@ -719,6 +811,7 @@ def build_configurable_system_prompt(
anti-citation instructions (False).
today: Optional datetime for today's date (defaults to current UTC date)
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).
Returns:
Complete system prompt string
@ -727,7 +820,6 @@ def build_configurable_system_prompt(
# Determine system instructions
if custom_system_instructions and custom_system_instructions.strip():
# Use custom instructions, injecting the date placeholder if present
system_instructions = custom_system_instructions.format(
resolved_today=resolved_today
)
@ -735,7 +827,6 @@ def build_configurable_system_prompt(
visibility = thread_visibility or ChatVisibility.PRIVATE
system_instructions = _get_system_instructions(visibility, today)
else:
# No system instructions (edge case)
system_instructions = ""
# Tools instructions: conditional on thread_visibility (private vs shared memory wording)
@ -748,7 +839,14 @@ def build_configurable_system_prompt(
else SURFSENSE_NO_CITATION_INSTRUCTIONS
)
return system_instructions + tools_instructions + citation_instructions
sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
return (
system_instructions
+ tools_instructions
+ citation_instructions
+ sandbox_instructions
)
def get_default_system_instructions() -> str:

View file

@ -58,7 +58,9 @@ def create_create_google_drive_file_tool(
- "Create a Google Doc called 'Meeting Notes'"
- "Create a spreadsheet named 'Budget 2026' with some sample data"
"""
logger.info(f"create_google_drive_file called: name='{name}', type='{file_type}'")
logger.info(
f"create_google_drive_file called: name='{name}', type='{file_type}'"
)
if db_session is None or search_space_id is None or user_id is None:
return {
@ -74,7 +76,9 @@ def create_create_google_drive_file_tool(
try:
metadata_service = GoogleDriveToolMetadataService(db_session)
context = await metadata_service.get_creation_context(search_space_id, user_id)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
@ -100,8 +104,12 @@ def create_create_google_drive_file_tool(
}
)
decisions_raw = approval.get("decisions", []) if isinstance(approval, dict) else []
decisions = decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
logger.warning("No approval decision received")
@ -183,7 +191,9 @@ def create_create_google_drive_file_tool(
logger.info(
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
)
client = GoogleDriveClient(session=db_session, connector_id=actual_connector_id)
client = GoogleDriveClient(
session=db_session, connector_id=actual_connector_id
)
try:
created = await client.create_file(
name=final_name,
@ -203,7 +213,9 @@ def create_create_google_drive_file_tool(
}
raise
logger.info(f"Google Drive file created: id={created.get('id')}, name={created.get('name')}")
logger.info(
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
)
return {
"status": "success",
"file_id": created.get("id"),

View file

@ -52,7 +52,9 @@ def create_delete_google_drive_file_tool(
- "Delete the 'Meeting Notes' file from Google Drive"
- "Trash the 'Old Budget' spreadsheet"
"""
logger.info(f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}")
logger.info(
f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
)
if db_session is None or search_space_id is None or user_id is None:
return {
@ -103,8 +105,12 @@ def create_delete_google_drive_file_tool(
}
)
decisions_raw = approval.get("decisions", []) if isinstance(approval, dict) else []
decisions = decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
logger.warning("No approval decision received")
@ -130,11 +136,16 @@ def create_delete_google_drive_file_tool(
final_params = decision["args"]
final_file_id = final_params.get("file_id", file_id)
final_connector_id = final_params.get("connector_id", connector_id_from_context)
final_connector_id = final_params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
if not final_connector_id:
return {"status": "error", "message": "No connector found for this file."}
return {
"status": "error",
"message": "No connector found for this file.",
}
from sqlalchemy.future import select
@ -174,7 +185,9 @@ def create_delete_google_drive_file_tool(
}
raise
logger.info(f"Google Drive file deleted (moved to trash): file_id={final_file_id}")
logger.info(
f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
)
trash_result: dict[str, Any] = {
"status": "success",
@ -195,7 +208,9 @@ def create_delete_google_drive_file_tool(
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(f"Deleted document {document_id} from knowledge base")
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:

View file

@ -10,6 +10,8 @@ This module provides:
import asyncio
import json
import re
import time
from datetime import datetime
from typing import Any
@ -17,8 +19,152 @@ from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from app.db import shielded_async_session
from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger
# Connectors that call external live-search APIs (no local DB / embedding needed).
# These are never filtered by available_document_types.
_LIVE_SEARCH_CONNECTORS: set[str] = {
"TAVILY_API",
"SEARXNG_API",
"LINKUP_API",
"BAIDU_SEARCH_API",
}
# Patterns that indicate the query has no meaningful search signal.
# plainto_tsquery('english', '*') produces an empty tsquery and an embedding
# of '*' is random noise, so both keyword and semantic search degrade to
# arbitrary ordering — large documents (many chunks) dominate by chance.
_DEGENERATE_QUERY_RE = re.compile(
r"^[\s*?_.#@!\-/\\]+$" # only wildcards, punctuation, whitespace
)
# Max chunks per document when doing a recency-based browse instead of
# a real search. We want breadth (many docs) over depth (many chunks).
_BROWSE_MAX_CHUNKS_PER_DOC = 5
def _is_degenerate_query(query: str) -> bool:
"""Return True when the query carries no meaningful search signal.
Catches wildcard patterns (``*``, ``**``), empty / whitespace-only
strings, and single-character non-word tokens. These queries cause
both keyword search (empty tsquery) and semantic search (meaningless
embedding) to return effectively random results.
"""
stripped = query.strip()
if not stripped:
return True
return bool(_DEGENERATE_QUERY_RE.match(stripped))
async def _browse_recent_documents(
search_space_id: int,
document_type: str | None,
top_k: int,
start_date: datetime | None,
end_date: datetime | None,
) -> list[dict[str, Any]]:
"""Return the most-recent documents (recency-ordered, no search ranking).
Used as a fallback when the search query is degenerate (e.g. ``*``) and
semantic / keyword search would produce arbitrary results. Returns
document-grouped dicts in the same shape as ``_combined_rrf_search``
so the rest of the pipeline works unchanged.
"""
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from app.db import Chunk, Document, DocumentType
perf = get_perf_logger()
t0 = time.perf_counter()
base_conditions = [Document.search_space_id == search_space_id]
if document_type is not None:
if isinstance(document_type, str):
try:
doc_type_enum = DocumentType[document_type]
base_conditions.append(Document.document_type == doc_type_enum)
except KeyError:
return []
else:
base_conditions.append(Document.document_type == document_type)
if start_date is not None:
base_conditions.append(Document.updated_at >= start_date)
if end_date is not None:
base_conditions.append(Document.updated_at <= end_date)
async with shielded_async_session() as session:
doc_query = (
select(Document)
.options(joinedload(Document.search_space))
.where(*base_conditions)
.order_by(Document.updated_at.desc())
.limit(top_k)
)
result = await session.execute(doc_query)
documents = result.scalars().unique().all()
if not documents:
return []
doc_ids = [d.id for d in documents]
chunk_query = (
select(Chunk)
.where(Chunk.document_id.in_(doc_ids))
.order_by(Chunk.document_id, Chunk.id)
)
chunk_result = await session.execute(chunk_query)
raw_chunks = chunk_result.scalars().all()
doc_chunk_counts: dict[int, int] = {}
doc_chunks: dict[int, list[dict]] = {d.id: [] for d in documents}
for chunk in raw_chunks:
did = chunk.document_id
count = doc_chunk_counts.get(did, 0)
if count < _BROWSE_MAX_CHUNKS_PER_DOC:
doc_chunks[did].append({"chunk_id": chunk.id, "content": chunk.content})
doc_chunk_counts[did] = count + 1
results: list[dict[str, Any]] = []
for doc in documents:
chunks_list = doc_chunks.get(doc.id, [])
results.append(
{
"document_id": doc.id,
"content": "\n\n".join(
c["content"] for c in chunks_list if c.get("content")
),
"score": 0.0,
"chunks": chunks_list,
"document": {
"id": doc.id,
"title": doc.title,
"document_type": doc.document_type.value
if getattr(doc, "document_type", None)
else None,
"metadata": doc.document_metadata or {},
},
"source": doc.document_type.value
if getattr(doc, "document_type", None)
else None,
}
)
perf.info(
"[kb_browse] recency browse in %.3fs docs=%d space=%d type=%s",
time.perf_counter() - t0,
len(results),
search_space_id,
document_type,
)
return results
# =============================================================================
# Connector Constants and Normalization
@ -172,12 +318,72 @@ def _normalize_connectors(
# =============================================================================
def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
# Fraction of the model's context window (in characters) that a single tool
# result is allowed to occupy. The remainder is reserved for system prompt,
# conversation history, and model output. With ~4 chars/token this gives a
# tool result ≈ 25 % of the context budget in tokens.
_TOOL_OUTPUT_CONTEXT_FRACTION = 0.25
_CHARS_PER_TOKEN = 4
# Hard-floor / ceiling so the budget is always sensible regardless of what
# the model reports.
_MIN_TOOL_OUTPUT_CHARS = 20_000 # ~5K tokens
_MAX_TOOL_OUTPUT_CHARS = 200_000 # ~50K tokens
_MAX_CHUNK_CHARS = 8_000
# Rank-adaptive per-document budget allocation.
# Top-ranked (most relevant) documents get a larger share of the budget so
# we pack as much high-quality context as possible.
#
# fraction(rank) = _TOP_DOC_BUDGET_FRACTION / (1 + rank * _RANK_DECAY)
#
# Examples (128K budget, 8K chunk cap):
# rank 0 → 40% → 6 chunks | rank 3 → 19% → 3 chunks
# rank 1 → 30% → 4 chunks | rank 10 → 10% → 3 chunks (floor)
# rank 2 → 24% → 3 chunks |
_TOP_DOC_BUDGET_FRACTION = 0.40
_RANK_DECAY = 0.35
_MIN_CHUNKS_PER_DOC = 3
def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
"""Derive a character budget from the model's context window.
Uses ``litellm.get_model_info`` via the value already resolved by
``ChatLiteLLMRouter`` / ``ChatLiteLLM`` and passed through the dependency
chain as ``max_input_tokens``. Falls back to a conservative default when
the value is unavailable.
"""
if max_input_tokens is None or max_input_tokens <= 0:
return _MIN_TOOL_OUTPUT_CHARS # conservative fallback
budget = int(max_input_tokens * _CHARS_PER_TOKEN * _TOOL_OUTPUT_CONTEXT_FRACTION)
return max(_MIN_TOOL_OUTPUT_CHARS, min(budget, _MAX_TOOL_OUTPUT_CHARS))
def format_documents_for_context(
documents: list[dict[str, Any]],
*,
max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
max_chunk_chars: int = _MAX_CHUNK_CHARS,
max_chunks_per_doc: int = 0,
) -> str:
"""
Format retrieved documents into a readable context string for the LLM.
Documents are added in order (highest relevance first) until the character
budget is reached. Individual chunks are capped at ``max_chunk_chars`` and
each document is limited to a dynamically computed chunk cap so a single
large document cannot monopolize the output while still maximising the use
of available context space.
Args:
documents: List of document dictionaries from connector search
max_chars: Approximate character budget for the entire output.
max_chunk_chars: Per-chunk character cap (content is tail-truncated).
max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means
auto-compute per document using a rank-adaptive formula so
higher-ranked documents receive more chunks.
Returns:
Formatted string with document contents and metadata
@ -278,39 +484,85 @@ def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
"BAIDU_SEARCH_API",
}
# Render XML expected by citation instructions
# Render XML expected by citation instructions, respecting the char budget.
parts: list[str] = []
for g in grouped.values():
total_chars = 0
total_docs = len(grouped)
for doc_idx, g in enumerate(grouped.values()):
metadata_json = json.dumps(g["metadata"], ensure_ascii=False)
is_live_search = g["document_type"] in live_search_connectors
parts.append("<document>")
parts.append("<document_metadata>")
parts.append(f" <document_id>{g['document_id']}</document_id>")
parts.append(f" <document_type>{g['document_type']}</document_type>")
parts.append(f" <title><![CDATA[{g['title']}]]></title>")
parts.append(f" <url><![CDATA[{g['url']}]]></url>")
parts.append(f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>")
parts.append("</document_metadata>")
parts.append("")
parts.append("<document_content>")
doc_lines: list[str] = [
"<document>",
"<document_metadata>",
f" <document_id>{g['document_id']}</document_id>",
f" <document_type>{g['document_type']}</document_type>",
f" <title><![CDATA[{g['title']}]]></title>",
f" <url><![CDATA[{g['url']}]]></url>",
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
"</document_metadata>",
"",
"<document_content>",
]
for ch in g["chunks"]:
# Rank-adaptive per-document chunk cap: top results get more chunks.
if max_chunks_per_doc > 0:
chunks_allowed = max_chunks_per_doc
else:
doc_fraction = _TOP_DOC_BUDGET_FRACTION / (1 + doc_idx * _RANK_DECAY)
max_doc_chars = int(max_chars * doc_fraction)
xml_overhead = 500
chunks_allowed = max(
(max_doc_chars - xml_overhead) // max(max_chunk_chars, 1),
_MIN_CHUNKS_PER_DOC,
)
chunks = g["chunks"]
if len(chunks) > chunks_allowed:
chunks = chunks[:chunks_allowed]
for ch in chunks:
ch_content = ch["content"]
# For live search connectors, use the document URL as the chunk id
# so the LLM outputs [citation:https://...] which the frontend
# renders as a clickable link.
if max_chunk_chars and len(ch_content) > max_chunk_chars:
ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
ch_id = g["url"] if (is_live_search and g["url"]) else ch["chunk_id"]
if ch_id is None:
parts.append(f" <chunk><![CDATA[{ch_content}]]></chunk>")
doc_lines.append(f" <chunk><![CDATA[{ch_content}]]></chunk>")
else:
parts.append(f" <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>")
doc_lines.append(
f" <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>"
)
parts.append("</document_content>")
parts.append("</document>")
parts.append("")
doc_lines.extend(["</document_content>", "</document>", ""])
return "\n".join(parts).strip()
doc_xml = "\n".join(doc_lines)
doc_len = len(doc_xml)
if total_chars + doc_len > max_chars:
remaining = total_docs - doc_idx
if doc_idx == 0:
parts.append(doc_xml)
total_chars += doc_len
parts.append(
f"<!-- Output truncated: {remaining} more document(s) omitted "
f"(budget {max_chars} chars). Refine your query or reduce top_k "
f"to retrieve different results. -->"
)
break
parts.append(doc_xml)
total_chars += doc_len
result = "\n".join(parts).strip()
# Hard safety net: if the result is still over budget (e.g. a single massive
# first document), forcibly truncate with a closing comment.
if len(result) > max_chars:
truncation_msg = "\n<!-- ...output forcibly truncated to fit context window -->"
result = result[: max_chars - len(truncation_msg)] + truncation_msg
return result
# =============================================================================
@ -328,6 +580,8 @@ async def search_knowledge_base_async(
start_date: datetime | None = None,
end_date: datetime | None = None,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
max_input_tokens: int | None = None,
) -> str:
"""
Search the user's knowledge base for relevant documents.
@ -345,10 +599,18 @@ async def search_knowledge_base_async(
end_date: Optional end datetime (UTC) for filtering documents
available_connectors: Optional list of connectors actually available in the search space.
If provided, only these connectors will be searched.
available_document_types: Optional list of document types that actually have indexed
data. When provided, local connectors whose document type is
absent are skipped entirely (no embedding / DB round-trip).
max_input_tokens: Model context window size (tokens). Used to dynamically
size the output so it fits within the model's limits.
Returns:
Formatted string with search results
"""
perf = get_perf_logger()
t0 = time.perf_counter()
all_documents: list[dict[str, Any]] = []
# Resolve date range (default last 2 years)
@ -361,88 +623,169 @@ async def search_knowledge_base_async(
connectors = _normalize_connectors(connectors_to_search, available_connectors)
connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
"YOUTUBE_VIDEO": ("search_youtube", True, True, {}),
"EXTENSION": ("search_extension", True, True, {}),
"CRAWLED_URL": ("search_crawled_urls", True, True, {}),
"FILE": ("search_files", True, True, {}),
"SLACK_CONNECTOR": ("search_slack", True, True, {}),
"TEAMS_CONNECTOR": ("search_teams", True, True, {}),
"NOTION_CONNECTOR": ("search_notion", True, True, {}),
"GITHUB_CONNECTOR": ("search_github", True, True, {}),
"LINEAR_CONNECTOR": ("search_linear", True, True, {}),
# --- Optimization 1: skip local connectors that have zero indexed documents ---
if available_document_types:
doc_types_set = set(available_document_types)
before_count = len(connectors)
connectors = [
c for c in connectors if c in _LIVE_SEARCH_CONNECTORS or c in doc_types_set
]
skipped = before_count - len(connectors)
if skipped:
perf.info(
"[kb_search] skipped %d empty connectors (had %d, now %d)",
skipped,
before_count,
len(connectors),
)
perf.info(
"[kb_search] searching %d connectors: %s (space=%d, top_k=%d)",
len(connectors),
connectors[:5],
search_space_id,
top_k,
)
# --- Fast-path: degenerate queries (*, **, empty, etc.) ---
# Semantic embedding of '*' is noise and plainto_tsquery('english', '*')
# yields an empty tsquery, so both retrieval signals are useless.
# Fall back to a recency-ordered browse that returns diverse results.
if _is_degenerate_query(query):
perf.info(
"[kb_search] degenerate query %r detected - falling back to recency browse",
query,
)
local_connectors = [c for c in connectors if c not in _LIVE_SEARCH_CONNECTORS]
if not local_connectors:
local_connectors = [None] # type: ignore[list-item]
browse_results = await asyncio.gather(
*[
_browse_recent_documents(
search_space_id=search_space_id,
document_type=c,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
for c in local_connectors
]
)
for docs in browse_results:
all_documents.extend(docs)
# Skip dedup + formatting below (browse already returns unique docs)
# but still cap output budget.
output_budget = _compute_tool_output_budget(max_input_tokens)
result = format_documents_for_context(
all_documents,
max_chars=output_budget,
max_chunks_per_doc=_BROWSE_MAX_CHUNKS_PER_DOC,
)
perf.info(
"[kb_search] TOTAL (browse) in %.3fs total_docs=%d output_chars=%d "
"budget=%d space=%d",
time.perf_counter() - t0,
len(all_documents),
len(result),
output_budget,
search_space_id,
)
return result
# Specs for live-search connectors (external APIs, no local DB/embedding).
live_connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
"TAVILY_API": ("search_tavily", False, True, {}),
"SEARXNG_API": ("search_searxng", False, True, {}),
"LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
"BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
"DISCORD_CONNECTOR": ("search_discord", True, True, {}),
"JIRA_CONNECTOR": ("search_jira", True, True, {}),
"GOOGLE_CALENDAR_CONNECTOR": ("search_google_calendar", True, True, {}),
"AIRTABLE_CONNECTOR": ("search_airtable", True, True, {}),
"GOOGLE_GMAIL_CONNECTOR": ("search_google_gmail", True, True, {}),
"GOOGLE_DRIVE_FILE": ("search_google_drive", True, True, {}),
"CONFLUENCE_CONNECTOR": ("search_confluence", True, True, {}),
"CLICKUP_CONNECTOR": ("search_clickup", True, True, {}),
"LUMA_CONNECTOR": ("search_luma", True, True, {}),
"ELASTICSEARCH_CONNECTOR": ("search_elasticsearch", True, True, {}),
"NOTE": ("search_notes", True, True, {}),
"BOOKSTACK_CONNECTOR": ("search_bookstack", True, True, {}),
"CIRCLEBACK": ("search_circleback", True, True, {}),
"OBSIDIAN_CONNECTOR": ("search_obsidian", True, True, {}),
# Composio connectors
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": (
"search_composio_google_drive",
True,
True,
{},
),
"COMPOSIO_GMAIL_CONNECTOR": ("search_composio_gmail", True, True, {}),
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": (
"search_composio_google_calendar",
True,
True,
{},
),
}
# Keep a conservative cap to avoid overloading DB/external services.
# --- Optimization 2: compute the query embedding once, share across all local searches ---
precomputed_embedding: list[float] | None = None
has_local_connectors = any(c not in _LIVE_SEARCH_CONNECTORS for c in connectors)
if has_local_connectors:
from app.config import config as app_config
t_embed = time.perf_counter()
precomputed_embedding = app_config.embedding_model_instance.embed(query)
perf.info(
"[kb_search] shared embedding computed in %.3fs",
time.perf_counter() - t_embed,
)
max_parallel_searches = 4
semaphore = asyncio.Semaphore(max_parallel_searches)
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
spec = connector_specs.get(connector)
if spec is None:
return []
is_live = connector in _LIVE_SEARCH_CONNECTORS
method_name, includes_date_range, includes_top_k, extra_kwargs = spec
kwargs: dict[str, Any] = {
"user_query": query,
"search_space_id": search_space_id,
**extra_kwargs,
}
if includes_top_k:
kwargs["top_k"] = top_k
if includes_date_range:
kwargs["start_date"] = resolved_start_date
kwargs["end_date"] = resolved_end_date
if is_live:
spec = live_connector_specs.get(connector)
if spec is None:
return []
method_name, includes_date_range, includes_top_k, extra_kwargs = spec
kwargs: dict[str, Any] = {
"user_query": query,
"search_space_id": search_space_id,
**extra_kwargs,
}
if includes_top_k:
kwargs["top_k"] = top_k
if includes_date_range:
kwargs["start_date"] = resolved_start_date
kwargs["end_date"] = resolved_end_date
try:
t_conn = time.perf_counter()
async with semaphore, shielded_async_session() as isolated_session:
svc = ConnectorService(isolated_session, search_space_id)
_, chunks = await getattr(svc, method_name)(**kwargs)
perf.info(
"[kb_search] connector=%s results=%d in %.3fs",
connector,
len(chunks),
time.perf_counter() - t_conn,
)
return chunks
except Exception as e:
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
return []
# --- Optimization 3: call _combined_rrf_search directly with shared embedding ---
try:
# Use isolated session per connector. Shared AsyncSession cannot safely
# run concurrent DB operations.
async with semaphore, async_session_maker() as isolated_session:
isolated_connector_service = ConnectorService(
isolated_session, search_space_id
t_conn = time.perf_counter()
async with semaphore, shielded_async_session() as isolated_session:
svc = ConnectorService(isolated_session, search_space_id)
chunks = await svc._combined_rrf_search(
query_text=query,
search_space_id=search_space_id,
document_type=connector,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
query_embedding=precomputed_embedding,
)
perf.info(
"[kb_search] connector=%s results=%d in %.3fs",
connector,
len(chunks),
time.perf_counter() - t_conn,
)
connector_method = getattr(isolated_connector_service, method_name)
_, chunks = await connector_method(**kwargs)
return chunks
except Exception as e:
print(f"Error searching connector {connector}: {e}")
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
return []
t_gather = time.perf_counter()
connector_results = await asyncio.gather(
*[_search_one_connector(connector) for connector in connectors]
)
perf.info(
"[kb_search] all connectors gathered in %.3fs",
time.perf_counter() - t_gather,
)
for chunks in connector_results:
all_documents.extend(chunks)
@ -488,7 +831,29 @@ async def search_knowledge_base_async(
deduplicated.append(doc)
return format_documents_for_context(deduplicated)
output_budget = _compute_tool_output_budget(max_input_tokens)
result = format_documents_for_context(deduplicated, max_chars=output_budget)
if len(result) > output_budget:
perf.warning(
"[kb_search] output STILL exceeds budget after format (%d > %d), "
"hard truncation should have fired",
len(result),
output_budget,
)
perf.info(
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
"budget=%d max_input_tokens=%s space=%d",
time.perf_counter() - t0,
len(all_documents),
len(deduplicated),
len(result),
output_budget,
max_input_tokens,
search_space_id,
)
return result
def _build_connector_docstring(available_connectors: list[str] | None) -> str:
@ -526,11 +891,15 @@ class SearchKnowledgeBaseInput(BaseModel):
"""Input schema for the search_knowledge_base tool."""
query: str = Field(
description="The search query - be specific and include key terms"
description=(
"The search query - use specific natural language terms. "
"NEVER use wildcards like '*' or '**'; instead describe what you want "
"(e.g. 'recent meeting notes' or 'project architecture overview')."
),
)
top_k: int = Field(
default=10,
description="Number of results to retrieve (default: 10)",
description="Number of results to retrieve (default: 10). Keep ≤20 for focused searches.",
)
start_date: str | None = Field(
default=None,
@ -552,6 +921,7 @@ def create_search_knowledge_base_tool(
connector_service: ConnectorService,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
max_input_tokens: int | None = None,
) -> StructuredTool:
"""
Factory function to create the search_knowledge_base tool with injected dependencies.
@ -564,6 +934,8 @@ def create_search_knowledge_base_tool(
Used to dynamically generate the tool docstring.
available_document_types: Optional list of document types that have data in the search space.
Used to inform the LLM about what data exists.
max_input_tokens: Model context window (tokens) from litellm model info.
Used to dynamically size tool output.
Returns:
A configured StructuredTool instance
@ -590,6 +962,10 @@ Focus searches on these types for best results."""
Use this tool to find documents, notes, files, web pages, and other content that may help answer the user's question.
IMPORTANT:
- Always craft specific, descriptive search queries using natural language keywords.
Good: "quarterly sales report Q3", "Python API authentication design".
Bad: "*", "**", "everything", single characters. Wildcard/empty queries yield poor results.
- Prefer multiple focused searches over a single broad one with high top_k.
- If the user requests a specific source type (e.g. "my notes", "Slack messages"), pass `connectors_to_search=[...]` using the enums below.
- If `connectors_to_search` is omitted/empty, the system will search broadly.
- Only connectors that are enabled/configured for this search space are available.{doc_types_info}
@ -605,6 +981,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
# Capture for closure
_available_connectors = available_connectors
_available_document_types = available_document_types
async def _search_knowledge_base_impl(
query: str,
@ -634,6 +1011,8 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
start_date=parsed_start,
end_date=parsed_end,
available_connectors=_available_connectors,
available_document_types=_available_document_types,
max_input_tokens=max_input_tokens,
)
# Create StructuredTool with dynamic description

View file

@ -11,6 +11,7 @@ This implements real MCP protocol support similar to Cursor's implementation.
"""
import logging
import time
from typing import Any
from langchain_core.tools import StructuredTool
@ -25,6 +26,24 @@ from app.db import SearchSourceConnector, SearchSourceConnectorType
logger = logging.getLogger(__name__)
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
_MCP_CACHE_MAX_SIZE = 50
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
def _evict_expired_mcp_cache() -> None:
"""Remove expired entries from the MCP tools cache to prevent unbounded growth."""
now = time.monotonic()
expired = [
k
for k, (ts, _) in _mcp_tools_cache.items()
if now - ts >= _MCP_CACHE_TTL_SECONDS
]
for k in expired:
del _mcp_tools_cache[k]
if expired:
logger.debug("Evicted %d expired MCP cache entries", len(expired))
def _create_dynamic_input_model_from_schema(
tool_name: str,
@ -355,6 +374,19 @@ async def _load_http_mcp_tools(
return tools
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
"""Invalidate cached MCP tools.
Args:
search_space_id: If provided, only invalidate for this search space.
If None, invalidate all cached MCP tools.
"""
if search_space_id is not None:
_mcp_tools_cache.pop(search_space_id, None)
else:
_mcp_tools_cache.clear()
async def load_mcp_tools(
session: AsyncSession,
search_space_id: int,
@ -364,6 +396,9 @@ async def load_mcp_tools(
This discovers tools dynamically from MCP servers using the protocol.
Supports both stdio (local process) and HTTP (remote server) transports.
Results are cached per search space for up to 5 minutes to avoid
re-spawning MCP server processes on every chat message.
Args:
session: Database session
search_space_id: User's search space ID
@ -372,8 +407,22 @@ async def load_mcp_tools(
List of LangChain StructuredTool instances
"""
_evict_expired_mcp_cache()
now = time.monotonic()
cached = _mcp_tools_cache.get(search_space_id)
if cached is not None:
cached_at, cached_tools = cached
if now - cached_at < _MCP_CACHE_TTL_SECONDS:
logger.info(
"Using cached MCP tools for search space %s (%d tools, age=%.0fs)",
search_space_id,
len(cached_tools),
now - cached_at,
)
return list(cached_tools)
try:
# Fetch all MCP connectors for this search space
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.connector_type
@ -385,27 +434,22 @@ async def load_mcp_tools(
tools: list[StructuredTool] = []
for connector in result.scalars():
try:
# Early validation: Extract and validate connector config
config = connector.config or {}
server_config = config.get("server_config", {})
# Validate server_config exists and is a dict
if not server_config or not isinstance(server_config, dict):
logger.warning(
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
)
continue
# Determine transport type
transport = server_config.get("transport", "stdio")
if transport in ("streamable-http", "http", "sse"):
# HTTP-based MCP server
connector_tools = await _load_http_mcp_tools(
connector.id, connector.name, server_config
)
else:
# stdio-based MCP server (default)
connector_tools = await _load_stdio_mcp_tools(
connector.id, connector.name, server_config
)
@ -417,6 +461,12 @@ async def load_mcp_tools(
f"Failed to load tools from MCP connector {connector.id}: {e!s}"
)
_mcp_tools_cache[search_space_id] = (now, tools)
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
del _mcp_tools_cache[oldest_key]
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
return tools

View file

@ -47,6 +47,10 @@ from app.db import ChatVisibility
from .display_image import create_display_image_tool
from .generate_image import create_generate_image_tool
from .google_drive import (
create_create_google_drive_file_tool,
create_delete_google_drive_file_tool,
)
from .knowledge_base import create_search_knowledge_base_tool
from .linear import (
create_create_linear_issue_tool,
@ -55,10 +59,6 @@ from .linear import (
)
from .link_preview import create_link_preview_tool
from .mcp_tool import load_mcp_tools
from .google_drive import (
create_create_google_drive_file_tool,
create_delete_google_drive_file_tool,
)
from .notion import (
create_create_notion_page_tool,
create_delete_notion_page_tool,
@ -118,6 +118,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
# Optional: dynamically discovered connectors/document types
available_connectors=deps.get("available_connectors"),
available_document_types=deps.get("available_document_types"),
max_input_tokens=deps.get("max_input_tokens"),
),
requires=["search_space_id", "db_session", "connector_service"],
# Note: available_connectors and available_document_types are optional
@ -144,10 +145,12 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
thread_id=deps["thread_id"],
connector_service=deps.get("connector_service"),
available_connectors=deps.get("available_connectors"),
available_document_types=deps.get("available_document_types"),
),
requires=["search_space_id", "thread_id"],
# connector_service and available_connectors are optional —
# when missing, source_strategy="kb_search" degrades gracefully to "provided"
# connector_service, available_connectors, and available_document_types
# are optional — when missing, source_strategy="kb_search" degrades
# gracefully to "provided"
),
# Link preview tool - fetches Open Graph metadata for URLs
ToolDefinition(
@ -444,8 +447,18 @@ async def build_tools_async(
List of configured tool instances ready for the agent, including MCP tools.
"""
# Build standard tools
import time
_perf_log = logging.getLogger("surfsense.perf")
_perf_log.setLevel(logging.DEBUG)
_t0 = time.perf_counter()
tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)
_perf_log.info(
"[build_tools_async] Built-in tools in %.3fs (%d tools)",
time.perf_counter() - _t0,
len(tools),
)
# Load MCP tools if requested and dependencies are available
if (
@ -454,10 +467,16 @@ async def build_tools_async(
and "search_space_id" in dependencies
):
try:
_t0 = time.perf_counter()
mcp_tools = await load_mcp_tools(
dependencies["db_session"],
dependencies["search_space_id"],
)
_perf_log.info(
"[build_tools_async] MCP tools loaded in %.3fs (%d tools)",
time.perf_counter() - _t0,
len(mcp_tools),
)
tools.extend(mcp_tools)
logging.info(
f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",

View file

@ -33,7 +33,7 @@ from langchain_core.callbacks import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from app.db import Report, async_session_maker
from app.db import Report, shielded_async_session
from app.services.connector_service import ConnectorService
from app.services.llm_service import get_document_summary_llm
@ -559,6 +559,7 @@ def create_generate_report_tool(
thread_id: int | None = None,
connector_service: ConnectorService | None = None,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
):
"""
Factory function to create the generate_report tool with injected dependencies.
@ -716,7 +717,7 @@ def create_generate_report_tool(
async def _save_failed_report(error_msg: str) -> int | None:
"""Persist a failed report row using a short-lived session."""
try:
async with async_session_maker() as session:
async with shielded_async_session() as session:
failed_report = Report(
title=topic,
content=None,
@ -750,7 +751,7 @@ def create_generate_report_tool(
# ── Phase 1: READ (short-lived session) ──────────────────────
# Fetch parent report and LLM config, then close the session
# so no DB connection is held during the long LLM call.
async with async_session_maker() as read_session:
async with shielded_async_session() as read_session:
if parent_report_id:
parent_report = await read_session.get(Report, parent_report_id)
if parent_report:
@ -827,7 +828,7 @@ def create_generate_report_tool(
# Run all queries in parallel, each with its own session
async def _run_single_query(q: str) -> str:
async with async_session_maker() as kb_session:
async with shielded_async_session() as kb_session:
kb_connector_svc = ConnectorService(
kb_session, search_space_id
)
@ -838,6 +839,7 @@ def create_generate_report_tool(
connector_service=kb_connector_svc,
top_k=10,
available_connectors=available_connectors,
available_document_types=available_document_types,
)
kb_results = await asyncio.gather(
@ -1026,7 +1028,7 @@ def create_generate_report_tool(
# ── Phase 3: WRITE (short-lived session) ─────────────────────
# Save the report to the database, then close the session.
async with async_session_maker() as write_session:
async with shielded_async_session() as write_session:
report = Report(
title=topic,
content=report_content,

View file

@ -14,8 +14,8 @@ from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument
from app.utils.document_converters import embed_text
def format_surfsense_docs_results(results: list[tuple]) -> str:
@ -100,7 +100,7 @@ async def search_surfsense_docs_async(
Formatted string with relevant documentation content
"""
# Get embedding for the query
query_embedding = config.embedding_model_instance.embed(query)
query_embedding = embed_text(query)
# Vector similarity search on chunks, joining with documents
stmt = (

View file

@ -8,8 +8,8 @@ from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import MemoryCategory, SharedMemory, User
from app.utils.document_converters import embed_text
logger = logging.getLogger(__name__)
@ -64,7 +64,7 @@ async def save_shared_memory(
count = await get_shared_memory_count(db_session, search_space_id)
if count >= MAX_MEMORIES_PER_SEARCH_SPACE:
await delete_oldest_shared_memory(db_session, search_space_id)
embedding = config.embedding_model_instance.embed(content)
embedding = embed_text(content)
row = SharedMemory(
search_space_id=search_space_id,
created_by_id=_to_uuid(created_by_id),
@ -108,7 +108,7 @@ async def recall_shared_memory(
if category and category in valid_categories:
stmt = stmt.where(SharedMemory.category == MemoryCategory(category))
if query:
query_embedding = config.embedding_model_instance.embed(query)
query_embedding = embed_text(query)
stmt = stmt.order_by(
SharedMemory.embedding.op("<=>")(query_embedding)
).limit(top_k)

View file

@ -17,8 +17,8 @@ from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import MemoryCategory, UserMemory
from app.utils.document_converters import embed_text
logger = logging.getLogger(__name__)
@ -178,7 +178,7 @@ def create_save_memory_tool(
await delete_oldest_memory(db_session, user_id, search_space_id)
# Generate embedding for the memory
embedding = config.embedding_model_instance.embed(content)
embedding = embed_text(content)
# Create new memory using ORM
# The pgvector Vector column type handles embedding conversion automatically
@ -268,7 +268,7 @@ def create_recall_memory_tool(
if query:
# Semantic search using embeddings
query_embedding = config.embedding_model_instance.embed(query)
query_embedding = embed_text(query)
# Build query with vector similarity
stmt = (

View file

@ -12,7 +12,7 @@ from litellm import aspeech
from app.config import config as app_config
from app.services.kokoro_tts_service import get_kokoro_tts_service
from app.services.llm_service import get_document_summary_llm
from app.services.llm_service import get_agent_llm
from .configuration import Configuration
from .prompts import get_podcast_generation_prompt
@ -31,7 +31,7 @@ async def create_podcast_transcript(
user_prompt = configuration.user_prompt
# Get search space's document summary LLM
llm = await get_document_summary_llm(state.db_session, search_space_id)
llm = await get_agent_llm(state.db_session, search_space_id)
if not llm:
error_message = (
f"No document summary LLM configured for search space {search_space_id}"

View file

@ -1,4 +1,5 @@
import asyncio
import gc
import logging
import time
from collections import defaultdict
@ -15,6 +16,9 @@ from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request as StarletteRequest
from starlette.responses import Response as StarletteResponse
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
from app.agents.new_chat.checkpointer import (
@ -28,6 +32,7 @@ from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
from app.utils.perf import get_perf_logger, log_system_snapshot
rate_limit_logger = logging.getLogger("surfsense.rate_limit")
@ -99,22 +104,24 @@ def _check_rate_limit_memory(
now = time.monotonic()
with _memory_lock:
# Evict timestamps outside the current window
_memory_rate_limits[key] = [
t for t in _memory_rate_limits[key] if now - t < window_seconds
]
timestamps = [t for t in _memory_rate_limits[key] if now - t < window_seconds]
if len(_memory_rate_limits[key]) >= max_requests:
if not timestamps:
_memory_rate_limits.pop(key, None)
else:
_memory_rate_limits[key] = timestamps
if len(timestamps) >= max_requests:
rate_limit_logger.warning(
f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} "
f"({len(_memory_rate_limits[key])}/{max_requests} in {window_seconds}s)"
f"({len(timestamps)}/{max_requests} in {window_seconds}s)"
)
raise HTTPException(
status_code=429,
detail="RATE_LIMIT_EXCEEDED",
)
_memory_rate_limits[key].append(now)
_memory_rate_limits[key] = [*timestamps, now]
def _check_rate_limit(
@ -175,18 +182,47 @@ def rate_limit_password_reset(request: Request):
)
def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None:
"""Monkey-patch the event loop to warn whenever a callback blocks longer than *threshold_sec*.
This helps pinpoint synchronous code that freezes the entire FastAPI server.
Only active when the PERF_DEBUG env var is set (to avoid overhead in production).
"""
import os
if not os.environ.get("PERF_DEBUG"):
return
_slow_log = logging.getLogger("surfsense.perf.slow")
_slow_log.setLevel(logging.WARNING)
if not _slow_log.handlers:
_h = logging.StreamHandler()
_h.setFormatter(logging.Formatter("%(asctime)s [SLOW-CALLBACK] %(message)s"))
_slow_log.addHandler(_h)
_slow_log.propagate = False
loop = asyncio.get_running_loop()
loop.slow_callback_duration = threshold_sec # type: ignore[attr-defined]
loop.set_debug(True)
_slow_log.warning(
"Event-loop slow-callback detector ENABLED (threshold=%.1fs). "
"Set PERF_DEBUG='' to disable.",
threshold_sec,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Not needed if you setup a migration system like Alembic
# Tune GC: lower gen-2 threshold so long-lived garbage is collected
# sooner (default 700/10/10 → 700/10/5). This reduces peak RSS
# with minimal CPU overhead.
gc.set_threshold(700, 10, 5)
_enable_slow_callback_logging(threshold_sec=0.5)
await create_db_and_tables()
# Setup LangGraph checkpointer tables for conversation persistence
await setup_checkpointer_tables()
# Initialize LLM Router for Auto mode load balancing
initialize_llm_router()
# Initialize Image Generation Router for Auto mode load balancing
initialize_image_gen_router()
# Seed Surfsense documentation (with timeout so a slow embedding API
# doesn't block startup indefinitely and make the container unresponsive)
try:
await asyncio.wait_for(seed_surfsense_docs(), timeout=120)
except TimeoutError:
@ -194,8 +230,11 @@ async def lifespan(app: FastAPI):
"Surfsense docs seeding timed out after 120s — skipping. "
"Docs will be indexed on the next restart."
)
log_system_snapshot("startup_complete")
yield
# Cleanup: close checkpointer connection on shutdown
await close_checkpointer()
@ -213,6 +252,63 @@ app = FastAPI(lifespan=lifespan)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# ---------------------------------------------------------------------------
# Request-level performance middleware
# ---------------------------------------------------------------------------
# Logs wall-clock time, method, path, and status for every request so we can
# spot slow endpoints in production logs.
_PERF_SLOW_REQUEST_THRESHOLD = float(
__import__("os").environ.get("PERF_SLOW_REQUEST_MS", "2000")
)
class RequestPerfMiddleware(BaseHTTPMiddleware):
"""Middleware that logs per-request wall-clock time.
- ALL requests are logged at DEBUG level.
- Requests exceeding PERF_SLOW_REQUEST_MS (default 2000ms) are logged at
WARNING level with a system snapshot so we can correlate slow responses
with CPU/memory usage at that moment.
"""
async def dispatch(
self, request: StarletteRequest, call_next: RequestResponseEndpoint
) -> StarletteResponse:
perf = get_perf_logger()
t0 = time.perf_counter()
response = await call_next(request)
elapsed_ms = (time.perf_counter() - t0) * 1000
path = request.url.path
method = request.method
status = response.status_code
perf.debug(
"[request] %s %s -> %d in %.1fms",
method,
path,
status,
elapsed_ms,
)
if elapsed_ms > _PERF_SLOW_REQUEST_THRESHOLD:
perf.warning(
"[SLOW_REQUEST] %s %s -> %d in %.1fms (threshold=%.0fms)",
method,
path,
status,
elapsed_ms,
_PERF_SLOW_REQUEST_THRESHOLD,
)
log_system_snapshot("slow_request")
return response
app.add_middleware(RequestPerfMiddleware)
# Add SlowAPI middleware for automatic rate limiting
# Uses Starlette BaseHTTPMiddleware (not the raw ASGI variant) to avoid
# corrupting StreamingResponse — SlowAPIASGIMiddleware re-sends

View file

@ -14,7 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.config import config
from app.connectors.composio_connector import ComposioConnector
from app.db import Document, DocumentStatus, DocumentType
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
@ -27,6 +26,7 @@ from app.tasks.connector_indexers.base import (
)
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -383,6 +383,7 @@ async def _process_gmail_messages_phase2(
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool = False,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, int]:
"""
@ -415,7 +416,7 @@ async def _process_gmail_messages_phase2(
session, user_id, search_space_id
)
if user_llm:
if user_llm and enable_summary:
document_metadata_for_summary = {
"message_id": item["message_id"],
"thread_id": item["thread_id"],
@ -427,10 +428,8 @@ async def _process_gmail_messages_phase2(
item["markdown_content"], user_llm, document_metadata_for_summary
)
else:
summary_content = f"Gmail: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Gmail: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])
@ -646,6 +645,7 @@ async def index_composio_gmail(
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=getattr(connector, "enable_summary", False),
on_heartbeat_callback=on_heartbeat_callback,
)

View file

@ -14,7 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.config import config
from app.connectors.composio_connector import ComposioConnector
from app.db import Document, DocumentStatus, DocumentType
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
@ -27,6 +26,7 @@ from app.tasks.connector_indexers.base import (
)
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -440,7 +440,7 @@ async def index_composio_google_calendar(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"event_id": item["event_id"],
"summary": item["summary"],
@ -456,12 +456,10 @@ async def index_composio_google_calendar(
document_metadata_for_summary,
)
else:
summary_content = f"Calendar: {item['summary']}\n\nStart: {item['start_time']}\nEnd: {item['end_time']}"
if item["location"]:
summary_content += f"\nLocation: {item['location']}"
summary_embedding = config.embedding_model_instance.embed(
summary_content
summary_content = (
f"Calendar: {item['summary']}\n\n{item['markdown_content']}"
)
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])

View file

@ -31,6 +31,7 @@ from app.tasks.connector_indexers.base import (
)
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -714,6 +715,7 @@ async def index_composio_google_drive(
max_items=max_items,
task_logger=task_logger,
log_entry=log_entry,
enable_summary=getattr(connector, "enable_summary", False),
on_heartbeat_callback=on_heartbeat_callback,
)
else:
@ -747,6 +749,7 @@ async def index_composio_google_drive(
max_items=max_items,
task_logger=task_logger,
log_entry=log_entry,
enable_summary=getattr(connector, "enable_summary", False),
on_heartbeat_callback=on_heartbeat_callback,
)
@ -829,6 +832,7 @@ async def _index_composio_drive_delta_sync(
max_items: int,
task_logger: TaskLoggingService,
log_entry,
enable_summary: bool = False,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, int, list[str]]:
"""Index Google Drive files using delta sync with real-time document status updates.
@ -1079,7 +1083,7 @@ async def _index_composio_drive_delta_sync(
session, user_id, search_space_id
)
if user_llm:
if user_llm and enable_summary:
document_metadata_for_summary = {
"file_id": item["file_id"],
"file_name": item["file_name"],
@ -1090,10 +1094,8 @@ async def _index_composio_drive_delta_sync(
markdown_content, user_llm, document_metadata_for_summary
)
else:
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}\n\n{markdown_content}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(markdown_content)
@ -1155,6 +1157,7 @@ async def _index_composio_drive_full_scan(
max_items: int,
task_logger: TaskLoggingService,
log_entry,
enable_summary: bool = False,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, int, list[str]]:
"""Index Google Drive files using full scan with real-time document status updates."""
@ -1488,7 +1491,7 @@ async def _index_composio_drive_full_scan(
session, user_id, search_space_id
)
if user_llm:
if user_llm and enable_summary:
document_metadata_for_summary = {
"file_id": item["file_id"],
"file_name": item["file_name"],
@ -1499,10 +1502,8 @@ async def _index_composio_drive_full_scan(
markdown_content, user_llm, document_metadata_for_summary
)
else:
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}\n\n{markdown_content}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(markdown_content)

View file

@ -1,7 +1,9 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from enum import Enum
from enum import StrEnum
import anyio
from fastapi import Depends
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
from pgvector.sqlalchemy import Vector
@ -31,7 +33,7 @@ if config.AUTH_TYPE == "GOOGLE":
DATABASE_URL = config.DATABASE_URL
class DocumentType(str, Enum):
class DocumentType(StrEnum):
EXTENSION = "EXTENSION"
CRAWLED_URL = "CRAWLED_URL"
FILE = "FILE"
@ -60,7 +62,7 @@ class DocumentType(str, Enum):
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
class SearchSourceConnectorType(str, Enum):
class SearchSourceConnectorType(StrEnum):
SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
TAVILY_API = "TAVILY_API"
SEARXNG_API = "SEARXNG_API"
@ -93,7 +95,7 @@ class SearchSourceConnectorType(str, Enum):
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
class PodcastStatus(str, Enum):
class PodcastStatus(StrEnum):
PENDING = "pending"
GENERATING = "generating"
READY = "ready"
@ -177,7 +179,7 @@ class DocumentStatus:
return None
class LiteLLMProvider(str, Enum):
class LiteLLMProvider(StrEnum):
"""
Enum for LLM providers supported by LiteLLM.
"""
@ -215,7 +217,7 @@ class LiteLLMProvider(str, Enum):
CUSTOM = "CUSTOM"
class ImageGenProvider(str, Enum):
class ImageGenProvider(StrEnum):
"""
Enum for image generation providers supported by LiteLLM.
This is a subset of LLM providers only those that support image generation.
@ -233,7 +235,7 @@ class ImageGenProvider(str, Enum):
NSCALE = "NSCALE"
class LogLevel(str, Enum):
class LogLevel(StrEnum):
DEBUG = "DEBUG"
INFO = "INFO"
WARNING = "WARNING"
@ -241,13 +243,13 @@ class LogLevel(str, Enum):
CRITICAL = "CRITICAL"
class LogStatus(str, Enum):
class LogStatus(StrEnum):
IN_PROGRESS = "IN_PROGRESS"
SUCCESS = "SUCCESS"
FAILED = "FAILED"
class IncentiveTaskType(str, Enum):
class IncentiveTaskType(StrEnum):
"""
Enum for incentive task types that users can complete to earn free pages.
Each task can only be completed once per user.
@ -298,7 +300,7 @@ INCENTIVE_TASKS_CONFIG = {
}
class Permission(str, Enum):
class Permission(StrEnum):
"""
Granular permissions for search space resources.
Use '*' (FULL_ACCESS) to grant all permissions.
@ -471,7 +473,7 @@ class BaseModel(Base):
id = Column(Integer, primary_key=True, index=True)
class NewChatMessageRole(str, Enum):
class NewChatMessageRole(StrEnum):
"""Role enum for new chat messages."""
USER = "user"
@ -479,7 +481,7 @@ class NewChatMessageRole(str, Enum):
SYSTEM = "system"
class ChatVisibility(str, Enum):
class ChatVisibility(StrEnum):
"""
Visibility/sharing level for chat threads.
@ -788,7 +790,7 @@ class ChatSessionState(BaseModel):
ai_responding_to_user = relationship("User")
class MemoryCategory(str, Enum):
class MemoryCategory(StrEnum):
"""Categories for user memories."""
# Using lowercase keys to match PostgreSQL enum values
@ -1317,6 +1319,12 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True)
config = Column(JSON, nullable=False)
# Summary generation (LLM-based) - disabled by default to save resources.
# When enabled, improves hybrid search quality at the cost of LLM calls.
enable_summary = Column(
Boolean, nullable=False, default=False, server_default="false"
)
# Periodic indexing fields
periodic_indexing_enabled = Column(Boolean, nullable=False, default=False)
indexing_frequency_minutes = Column(Integer, nullable=True)
@ -1850,10 +1858,37 @@ class RefreshToken(Base, TimestampMixin):
return not self.is_expired and not self.is_revoked
engine = create_async_engine(DATABASE_URL)
engine = create_async_engine(
DATABASE_URL,
pool_size=30,
max_overflow=150,
pool_recycle=1800,
pool_pre_ping=True,
pool_timeout=30,
)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
@asynccontextmanager
async def shielded_async_session():
"""Cancellation-safe async session context manager.
Starlette's BaseHTTPMiddleware cancels the task via an anyio cancel
scope when a client disconnects. A plain ``async with async_session_maker()``
has its ``__aexit__`` (which awaits ``session.close()``) cancelled by the
scope, orphaning the underlying database connection.
This wrapper ensures ``session.close()`` always completes by running it
inside ``anyio.CancelScope(shield=True)``.
"""
session = async_session_maker()
try:
yield session
finally:
with anyio.CancelScope(shield=True):
await session.close()
async def setup_indexes():
async with engine.begin() as conn:
# Create indexes

View file

@ -0,0 +1,83 @@
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
class UploadDocumentAdapter:
def __init__(self, session: AsyncSession) -> None:
self._session = session
self._service = IndexingPipelineService(session)
async def index(
self,
markdown_content: str,
filename: str,
etl_service: str,
search_space_id: int,
user_id: str,
llm,
should_summarize: bool = False,
) -> None:
connector_doc = ConnectorDocument(
title=filename,
source_markdown=markdown_content,
unique_id=filename,
document_type=DocumentType.FILE,
search_space_id=search_space_id,
created_by_id=user_id,
connector_id=None,
should_summarize=should_summarize,
should_use_code_chunker=False,
fallback_summary=markdown_content[:4000],
metadata={
"FILE_NAME": filename,
"ETL_SERVICE": etl_service,
},
)
documents = await self._service.prepare_for_indexing([connector_doc])
if not documents:
raise RuntimeError("prepare_for_indexing returned no documents")
indexed = await self._service.index(documents[0], connector_doc, llm)
if not DocumentStatus.is_state(indexed.status, DocumentStatus.READY):
raise RuntimeError(indexed.status.get("reason", "Indexing failed"))
indexed.content_needs_reindexing = False
await self._session.commit()
async def reindex(self, document: Document, llm) -> None:
"""Re-index an existing document after its source_markdown has been updated."""
if not document.source_markdown:
raise RuntimeError("Document has no source_markdown to reindex")
metadata = document.document_metadata or {}
connector_doc = ConnectorDocument(
title=document.title,
source_markdown=document.source_markdown,
unique_id=document.title,
document_type=document.document_type,
search_space_id=document.search_space_id,
created_by_id=str(document.created_by_id),
connector_id=document.connector_id,
should_summarize=True,
should_use_code_chunker=False,
fallback_summary=document.source_markdown[:4000],
metadata=metadata,
)
document.content_hash = compute_content_hash(connector_doc)
indexed = await self._service.index(document, connector_doc, llm)
if not DocumentStatus.is_state(indexed.status, DocumentStatus.READY):
raise RuntimeError(indexed.status.get("reason", "Reindexing failed"))
indexed.content_needs_reindexing = False
await self._session.commit()

View file

@ -0,0 +1,26 @@
from pydantic import BaseModel, Field, field_validator
from app.db import DocumentType
class ConnectorDocument(BaseModel):
"""Canonical data transfer object produced by connector adapters and consumed by the indexing pipeline."""
title: str
source_markdown: str
unique_id: str
document_type: DocumentType
search_space_id: int = Field(gt=0)
should_summarize: bool = True
should_use_code_chunker: bool = False
fallback_summary: str | None = None
metadata: dict = {}
connector_id: int | None = None
created_by_id: str
@field_validator("title", "source_markdown", "unique_id", "created_by_id")
@classmethod
def not_empty(cls, v: str, info) -> str:
if not v.strip():
raise ValueError(f"{info.field_name} must not be empty or whitespace")
return v

View file

@ -0,0 +1,9 @@
from app.config import config
def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
"""Chunk a text string using the configured chunker and return the chunk texts."""
chunker = (
config.code_chunker_instance if use_code_chunker else config.chunker_instance
)
return [c.text for c in chunker.chunk(text)]

View file

@ -0,0 +1,3 @@
from app.utils.document_converters import embed_text
__all__ = ["embed_text"]

View file

@ -0,0 +1,15 @@
import hashlib
from app.indexing_pipeline.connector_document import ConnectorDocument
def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
"""Return a stable SHA-256 hash identifying a document by its source identity."""
combined = f"{doc.document_type.value}:{doc.unique_id}:{doc.search_space_id}"
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
def compute_content_hash(doc: ConnectorDocument) -> str:
"""Return a SHA-256 hash of the document's content scoped to its search space."""
combined = f"{doc.search_space_id}:{doc.source_markdown}"
return hashlib.sha256(combined.encode("utf-8")).hexdigest()

View file

@ -0,0 +1,39 @@
from datetime import UTC, datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import object_session
from sqlalchemy.orm.attributes import set_committed_value
from app.db import Document, DocumentStatus
async def rollback_and_persist_failure(
session: AsyncSession, document: Document, message: str
) -> None:
"""Roll back the current transaction and best-effort persist a failed status.
Called exclusively from except blocks must never raise, or the new exception
would chain with the original and mask it entirely.
"""
try:
await session.rollback()
except Exception:
return # Session is completely dead; nothing further we can do.
try:
await session.refresh(document)
document.updated_at = datetime.now(UTC)
document.status = DocumentStatus.failed(message)
await session.commit()
except Exception:
pass # Best-effort; document will be retried on the next sync.
def attach_chunks_to_document(document: Document, chunks: list) -> None:
"""Assign chunks to a document without triggering SQLAlchemy async lazy loading."""
set_committed_value(document, "chunks", chunks)
session = object_session(document)
if session is not None:
if document.id is not None:
for chunk in chunks:
chunk.document_id = document.id
session.add_all(chunks)

View file

@ -0,0 +1,30 @@
from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.utils.document_converters import optimize_content_for_context_window
async def summarize_document(
source_markdown: str, llm, metadata: dict | None = None
) -> str:
"""Generate a text summary of a document using an LLM, prefixed with metadata when provided."""
model_name = getattr(llm, "model", "gpt-3.5-turbo")
optimized_content = optimize_content_for_context_window(
source_markdown, metadata, model_name
)
summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
content_with_metadata = (
f"<DOCUMENT><DOCUMENT_METADATA>\n\n{metadata}\n\n</DOCUMENT_METADATA>"
f"\n\n<DOCUMENT_CONTENT>\n\n{optimized_content}\n\n</DOCUMENT_CONTENT></DOCUMENT>"
)
summary_result = await summary_chain.ainvoke({"document": content_with_metadata})
summary_content = summary_result.content
if metadata:
metadata_parts = ["# DOCUMENT METADATA"]
for key, value in metadata.items():
if value:
metadata_parts.append(f"**{key.replace('_', ' ').title()}:** {value}")
metadata_section = "\n".join(metadata_parts)
return f"{metadata_section}\n\n# DOCUMENT SUMMARY\n\n{summary_content}"
return summary_content

View file

@ -0,0 +1,146 @@
from litellm.exceptions import (
APIConnectionError,
APIResponseValidationError,
AuthenticationError,
BadGatewayError,
BadRequestError,
InternalServerError,
NotFoundError,
PermissionDeniedError,
RateLimitError,
ServiceUnavailableError,
Timeout,
UnprocessableEntityError,
)
from sqlalchemy.exc import IntegrityError as IntegrityError
# Tuples for use directly in except clauses.
RETRYABLE_LLM_ERRORS = (
RateLimitError,
Timeout,
ServiceUnavailableError,
BadGatewayError,
InternalServerError,
APIConnectionError,
)
PERMANENT_LLM_ERRORS = (
AuthenticationError,
PermissionDeniedError,
NotFoundError,
BadRequestError,
UnprocessableEntityError,
APIResponseValidationError,
)
# (LiteLLMEmbeddings, CohereEmbeddings, GeminiEmbeddings all normalize to RuntimeError).
EMBEDDING_ERRORS = (
RuntimeError, # local device failure or API backend normalization
OSError, # model files missing or corrupted (local backends)
MemoryError, # document too large for available RAM
OSError, # model files missing or corrupted (local backends)
MemoryError, # document too large for available RAM
)
class PipelineMessages:
RATE_LIMIT = "LLM rate limit exceeded. Will retry on next sync."
LLM_TIMEOUT = "LLM request timed out. Will retry on next sync."
LLM_UNAVAILABLE = "LLM service temporarily unavailable. Will retry on next sync."
LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync."
LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync."
LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity."
RATE_LIMIT = "LLM rate limit exceeded. Will retry on next sync."
LLM_TIMEOUT = "LLM request timed out. Will retry on next sync."
LLM_UNAVAILABLE = "LLM service temporarily unavailable. Will retry on next sync."
LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync."
LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync."
LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity."
LLM_AUTH = "LLM authentication failed. Check your API key."
LLM_PERMISSION = "LLM request denied. Check your account permissions."
LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
LLM_UNPROCESSABLE = (
"Document exceeds the LLM context window even after optimization."
)
LLM_RESPONSE = "LLM returned an invalid response."
LLM_AUTH = "LLM authentication failed. Check your API key."
LLM_PERMISSION = "LLM request denied. Check your account permissions."
LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
LLM_UNPROCESSABLE = (
"Document exceeds the LLM context window even after optimization."
)
LLM_RESPONSE = "LLM returned an invalid response."
EMBEDDING_FAILED = (
"Embedding failed. Check your embedding model configuration or service."
)
EMBEDDING_MODEL = "Embedding model files are missing or corrupted."
EMBEDDING_MEMORY = "Not enough memory to embed this document."
EMBEDDING_FAILED = (
"Embedding failed. Check your embedding model configuration or service."
)
EMBEDDING_MODEL = "Embedding model files are missing or corrupted."
EMBEDDING_MEMORY = "Not enough memory to embed this document."
CHUNKING_OVERFLOW = "Document structure is too deeply nested to chunk."
def safe_exception_message(exc: Exception) -> str:
try:
return str(exc)
except Exception:
return "Something went wrong during indexing. Error details could not be retrieved."
def llm_retryable_message(exc: Exception) -> str:
try:
if isinstance(exc, RateLimitError):
return PipelineMessages.RATE_LIMIT
if isinstance(exc, Timeout):
return PipelineMessages.LLM_TIMEOUT
if isinstance(exc, ServiceUnavailableError):
return PipelineMessages.LLM_UNAVAILABLE
if isinstance(exc, BadGatewayError):
return PipelineMessages.LLM_BAD_GATEWAY
if isinstance(exc, InternalServerError):
return PipelineMessages.LLM_SERVER_ERROR
if isinstance(exc, APIConnectionError):
return PipelineMessages.LLM_CONNECTION
return safe_exception_message(exc)
except Exception:
return "Something went wrong when calling the LLM."
def llm_permanent_message(exc: Exception) -> str:
try:
if isinstance(exc, AuthenticationError):
return PipelineMessages.LLM_AUTH
if isinstance(exc, PermissionDeniedError):
return PipelineMessages.LLM_PERMISSION
if isinstance(exc, NotFoundError):
return PipelineMessages.LLM_NOT_FOUND
if isinstance(exc, BadRequestError):
return PipelineMessages.LLM_BAD_REQUEST
if isinstance(exc, UnprocessableEntityError):
return PipelineMessages.LLM_UNPROCESSABLE
if isinstance(exc, APIResponseValidationError):
return PipelineMessages.LLM_RESPONSE
return safe_exception_message(exc)
except Exception:
return "Something went wrong when calling the LLM."
def embedding_message(exc: Exception) -> str:
try:
if isinstance(exc, RuntimeError):
return PipelineMessages.EMBEDDING_FAILED
if isinstance(exc, OSError):
return PipelineMessages.EMBEDDING_MODEL
if isinstance(exc, MemoryError):
return PipelineMessages.EMBEDDING_MEMORY
return safe_exception_message(exc)
except Exception:
return "Something went wrong when generating the embedding."

View file

@ -0,0 +1,272 @@
import contextlib
import time
from datetime import UTC, datetime
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Chunk, Document, DocumentStatus
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_chunker import chunk_text
from app.indexing_pipeline.document_embedder import embed_text
from app.indexing_pipeline.document_hashing import (
compute_content_hash,
compute_unique_identifier_hash,
)
from app.indexing_pipeline.document_persistence import (
attach_chunks_to_document,
rollback_and_persist_failure,
)
from app.indexing_pipeline.document_summarizer import summarize_document
from app.indexing_pipeline.exceptions import (
EMBEDDING_ERRORS,
PERMANENT_LLM_ERRORS,
RETRYABLE_LLM_ERRORS,
PipelineMessages,
embedding_message,
llm_permanent_message,
llm_retryable_message,
safe_exception_message,
)
from app.indexing_pipeline.pipeline_logger import (
PipelineLogContext,
log_batch_aborted,
log_chunking_overflow,
log_doc_skipped_unknown,
log_document_queued,
log_document_requeued,
log_document_updated,
log_embedding_error,
log_index_started,
log_index_success,
log_permanent_llm_error,
log_race_condition,
log_retryable_llm_error,
log_unexpected_error,
)
from app.utils.perf import get_perf_logger
class IndexingPipelineService:
"""Single pipeline for indexing connector documents. All connectors use this service."""
def __init__(self, session: AsyncSession) -> None:
self.session = session
async def prepare_for_indexing(
self, connector_docs: list[ConnectorDocument]
) -> list[Document]:
"""
Persist new documents and detect changes, returning only those that need indexing.
"""
perf = get_perf_logger()
t0 = time.perf_counter()
documents = []
seen_hashes: set[str] = set()
batch_ctx = PipelineLogContext(
connector_id=connector_docs[0].connector_id if connector_docs else 0,
search_space_id=connector_docs[0].search_space_id if connector_docs else 0,
unique_id="batch",
)
for connector_doc in connector_docs:
ctx = PipelineLogContext(
connector_id=connector_doc.connector_id,
search_space_id=connector_doc.search_space_id,
unique_id=connector_doc.unique_id,
)
try:
unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
content_hash = compute_content_hash(connector_doc)
if unique_identifier_hash in seen_hashes:
continue
seen_hashes.add(unique_identifier_hash)
result = await self.session.execute(
select(Document).filter(
Document.unique_identifier_hash == unique_identifier_hash
)
)
existing = result.scalars().first()
if existing is not None:
if existing.content_hash == content_hash:
if existing.title != connector_doc.title:
existing.title = connector_doc.title
existing.updated_at = datetime.now(UTC)
if not DocumentStatus.is_state(
existing.status, DocumentStatus.READY
):
existing.status = DocumentStatus.pending()
existing.updated_at = datetime.now(UTC)
documents.append(existing)
log_document_requeued(ctx)
continue
existing.title = connector_doc.title
existing.content_hash = content_hash
existing.source_markdown = connector_doc.source_markdown
existing.document_metadata = connector_doc.metadata
existing.updated_at = datetime.now(UTC)
existing.status = DocumentStatus.pending()
documents.append(existing)
log_document_updated(ctx)
continue
duplicate = await self.session.execute(
select(Document).filter(Document.content_hash == content_hash)
)
if duplicate.scalars().first() is not None:
continue
document = Document(
title=connector_doc.title,
document_type=connector_doc.document_type,
content="Pending...",
content_hash=content_hash,
unique_identifier_hash=unique_identifier_hash,
source_markdown=connector_doc.source_markdown,
document_metadata=connector_doc.metadata,
search_space_id=connector_doc.search_space_id,
connector_id=connector_doc.connector_id,
created_by_id=connector_doc.created_by_id,
updated_at=datetime.now(UTC),
status=DocumentStatus.pending(),
)
self.session.add(document)
documents.append(document)
log_document_queued(ctx)
except Exception as e:
log_doc_skipped_unknown(ctx, e)
try:
await self.session.commit()
perf.info(
"[indexing] prepare_for_indexing in %.3fs input=%d output=%d",
time.perf_counter() - t0,
len(connector_docs),
len(documents),
)
return documents
except IntegrityError:
log_race_condition(batch_ctx)
await self.session.rollback()
return []
except Exception as e:
log_batch_aborted(batch_ctx, e)
await self.session.rollback()
return []
async def index(
self, document: Document, connector_doc: ConnectorDocument, llm
) -> Document:
"""
Run summarization, embedding, and chunking for a document and persist the results.
"""
ctx = PipelineLogContext(
connector_id=connector_doc.connector_id,
search_space_id=connector_doc.search_space_id,
unique_id=connector_doc.unique_id,
doc_id=document.id,
)
perf = get_perf_logger()
t_index = time.perf_counter()
try:
log_index_started(ctx)
document.status = DocumentStatus.processing()
await self.session.commit()
t_step = time.perf_counter()
if connector_doc.should_summarize and llm is not None:
content = await summarize_document(
connector_doc.source_markdown, llm, connector_doc.metadata
)
perf.info(
"[indexing] summarize_document doc=%d in %.3fs",
document.id,
time.perf_counter() - t_step,
)
elif connector_doc.should_summarize and connector_doc.fallback_summary:
content = connector_doc.fallback_summary
else:
content = connector_doc.source_markdown
t_step = time.perf_counter()
embedding = embed_text(content)
perf.debug(
"[indexing] embed_text (summary) doc=%d in %.3fs",
document.id,
time.perf_counter() - t_step,
)
await self.session.execute(
delete(Chunk).where(Chunk.document_id == document.id)
)
t_step = time.perf_counter()
chunks = [
Chunk(content=text, embedding=embed_text(text))
for text in chunk_text(
connector_doc.source_markdown,
use_code_chunker=connector_doc.should_use_code_chunker,
)
]
perf.info(
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
document.id,
len(chunks),
time.perf_counter() - t_step,
)
document.content = content
document.embedding = embedding
attach_chunks_to_document(document, chunks)
document.updated_at = datetime.now(UTC)
document.status = DocumentStatus.ready()
await self.session.commit()
perf.info(
"[indexing] index TOTAL doc=%d chunks=%d in %.3fs",
document.id,
len(chunks),
time.perf_counter() - t_index,
)
log_index_success(ctx, chunk_count=len(chunks))
except RETRYABLE_LLM_ERRORS as e:
log_retryable_llm_error(ctx, e)
await rollback_and_persist_failure(
self.session, document, llm_retryable_message(e)
)
except PERMANENT_LLM_ERRORS as e:
log_permanent_llm_error(ctx, e)
await rollback_and_persist_failure(
self.session, document, llm_permanent_message(e)
)
except RecursionError as e:
log_chunking_overflow(ctx, e)
await rollback_and_persist_failure(
self.session, document, PipelineMessages.CHUNKING_OVERFLOW
)
except EMBEDDING_ERRORS as e:
log_embedding_error(ctx, e)
await rollback_and_persist_failure(
self.session, document, embedding_message(e)
)
except Exception as e:
log_unexpected_error(ctx, e)
await rollback_and_persist_failure(
self.session, document, safe_exception_message(e)
)
with contextlib.suppress(Exception):
await self.session.refresh(document)
return document

View file

@ -0,0 +1,126 @@
import logging
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class PipelineLogContext:
connector_id: int | None
search_space_id: int
unique_id: str # always available from ConnectorDocument
doc_id: int | None = None # set once the DB row exists (index phase only)
class LogMessages:
# prepare_for_indexing
DOCUMENT_QUEUED = "New document queued for indexing."
DOCUMENT_UPDATED = "Document content changed, re-queued for indexing."
DOCUMENT_REQUEUED = "Stuck document re-queued for indexing."
DOC_SKIPPED_UNKNOWN = "Unexpected error — document skipped."
BATCH_ABORTED = "Fatal DB error — aborting prepare batch."
RACE_CONDITION = "Concurrent worker beat us to the commit — rolling back batch."
# index
INDEX_STARTED = "Document indexing started."
INDEX_SUCCESS = "Document indexed successfully."
LLM_RETRYABLE = (
"Retryable LLM error — document marked failed, will retry on next sync."
)
LLM_PERMANENT = "Permanent LLM error — document marked failed."
EMBEDDING_FAILED = "Embedding error — document marked failed."
CHUNKING_OVERFLOW = "Chunking overflow — document marked failed."
UNEXPECTED = "Unexpected error — document marked failed."
def _format_context(ctx: PipelineLogContext) -> str:
parts = [
f"connector_id={ctx.connector_id}",
f"search_space_id={ctx.search_space_id}",
f"unique_id={ctx.unique_id}",
]
if ctx.doc_id is not None:
parts.append(f"doc_id={ctx.doc_id}")
return " ".join(parts)
def _build_message(msg: str, ctx: PipelineLogContext, **extra) -> str:
try:
parts = [msg, _format_context(ctx)]
for key, val in extra.items():
parts.append(f"{key}={val}")
return " ".join(parts)
except Exception:
return msg
def _safe_log(
level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extra
) -> None:
# Logging must never raise — a broken log call inside an except block would
# chain with the original exception and mask it entirely.
try:
message = _build_message(msg, ctx, **extra)
level_fn(message, exc_info=exc_info)
except Exception:
pass
# ── prepare_for_indexing ──────────────────────────────────────────────────────
def log_document_queued(ctx: PipelineLogContext) -> None:
_safe_log(logger.info, LogMessages.DOCUMENT_QUEUED, ctx)
def log_document_updated(ctx: PipelineLogContext) -> None:
_safe_log(logger.info, LogMessages.DOCUMENT_UPDATED, ctx)
def log_document_requeued(ctx: PipelineLogContext) -> None:
_safe_log(logger.info, LogMessages.DOCUMENT_REQUEUED, ctx)
def log_doc_skipped_unknown(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(
logger.warning, LogMessages.DOC_SKIPPED_UNKNOWN, ctx, exc_info=exc, error=exc
)
def log_race_condition(ctx: PipelineLogContext) -> None:
_safe_log(logger.warning, LogMessages.RACE_CONDITION, ctx)
def log_batch_aborted(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.BATCH_ABORTED, ctx, exc_info=exc, error=exc)
# ── index ─────────────────────────────────────────────────────────────────────
def log_index_started(ctx: PipelineLogContext) -> None:
_safe_log(logger.info, LogMessages.INDEX_STARTED, ctx)
def log_index_success(ctx: PipelineLogContext, chunk_count: int) -> None:
_safe_log(logger.info, LogMessages.INDEX_SUCCESS, ctx, chunk_count=chunk_count)
def log_retryable_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.warning, LogMessages.LLM_RETRYABLE, ctx, exc_info=exc, error=exc)
def log_permanent_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.LLM_PERMANENT, ctx, exc_info=exc, error=exc)
def log_embedding_error(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.EMBEDDING_FAILED, ctx, exc_info=exc, error=exc)
def log_chunking_overflow(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.CHUNKING_OVERFLOW, ctx, exc_info=exc, error=exc)
def log_unexpected_error(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.UNEXPECTED, ctx, exc_info=exc, error=exc)

View file

@ -1,5 +1,10 @@
import time
from datetime import datetime
from app.utils.perf import get_perf_logger
_MAX_FETCH_CHUNKS_PER_DOC = 30
class ChucksHybridSearchRetriever:
def __init__(self, db_session):
@ -38,9 +43,17 @@ class ChucksHybridSearchRetriever:
from app.config import config
from app.db import Chunk, Document
perf = get_perf_logger()
t0 = time.perf_counter()
# Get embedding for the query
embedding_model = config.embedding_model_instance
t_embed = time.perf_counter()
query_embedding = embedding_model.embed(query_text)
perf.debug(
"[chunk_search] vector_search embedding in %.3fs",
time.perf_counter() - t_embed,
)
# Build the query filtered by search space
query = (
@ -60,8 +73,16 @@ class ChucksHybridSearchRetriever:
query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k)
# Execute the query
t_db = time.perf_counter()
result = await self.db_session.execute(query)
chunks = result.scalars().all()
perf.info(
"[chunk_search] vector_search DB query in %.3fs results=%d (total %.3fs) space=%d",
time.perf_counter() - t_db,
len(chunks),
time.perf_counter() - t0,
search_space_id,
)
return chunks
@ -91,6 +112,9 @@ class ChucksHybridSearchRetriever:
from app.db import Chunk, Document
perf = get_perf_logger()
t0 = time.perf_counter()
# Create tsvector and tsquery for PostgreSQL full-text search
tsvector = func.to_tsvector("english", Chunk.content)
tsquery = func.plainto_tsquery("english", query_text)
@ -118,6 +142,12 @@ class ChucksHybridSearchRetriever:
# Execute the query
result = await self.db_session.execute(query)
chunks = result.scalars().all()
perf.info(
"[chunk_search] full_text_search in %.3fs results=%d space=%d",
time.perf_counter() - t0,
len(chunks),
search_space_id,
)
return chunks
@ -129,6 +159,7 @@ class ChucksHybridSearchRetriever:
document_type: str | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
query_embedding: list | None = None,
) -> list:
"""
Hybrid search that returns **documents** (not individual chunks).
@ -143,6 +174,7 @@ class ChucksHybridSearchRetriever:
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
start_date: Optional start date for filtering documents by updated_at
end_date: Optional end date for filtering documents by updated_at
query_embedding: Pre-computed embedding vector. If None, will be computed here.
Returns:
List of dictionaries containing document data and relevance scores. Each dict contains:
@ -157,9 +189,17 @@ class ChucksHybridSearchRetriever:
from app.config import config
from app.db import Chunk, Document, DocumentType
# Get embedding for the query
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
perf = get_perf_logger()
t0 = time.perf_counter()
if query_embedding is None:
embedding_model = config.embedding_model_instance
t_embed = time.perf_counter()
query_embedding = embedding_model.embed(query_text)
perf.debug(
"[chunk_search] hybrid_search embedding in %.3fs",
time.perf_counter() - t_embed,
)
# RRF constants
k = 60
@ -254,9 +294,17 @@ class ChucksHybridSearchRetriever:
.limit(top_k)
)
# Execute the query
# Execute the RRF query
t_rrf = time.perf_counter()
result = await self.db_session.execute(final_query)
chunks_with_scores = result.all()
perf.info(
"[chunk_search] hybrid_search RRF query in %.3fs results=%d space=%d type=%s",
time.perf_counter() - t_rrf,
len(chunks_with_scores),
search_space_id,
document_type,
)
# If no results were found, return an empty list
if not chunks_with_scores:
@ -300,8 +348,9 @@ class ChucksHybridSearchRetriever:
if not doc_ids:
return []
# Fetch ALL chunks for selected documents in a single query so the final prompt can cite
# any chunk from those documents.
# Fetch chunks for selected documents. We cap per document to avoid
# loading hundreds of chunks for a single large file while still
# ensuring the chunks that matched the RRF query are always included.
chunk_query = (
select(Chunk)
.options(joinedload(Chunk.document))
@ -311,7 +360,20 @@ class ChucksHybridSearchRetriever:
.order_by(Chunk.document_id, Chunk.id)
)
chunks_result = await self.db_session.execute(chunk_query)
all_chunks = chunks_result.scalars().all()
raw_chunks = chunks_result.scalars().all()
matched_chunk_ids: set[int] = {
item["chunk_id"] for item in serialized_chunk_results
}
doc_chunk_counts: dict[int, int] = {}
all_chunks: list = []
for chunk in raw_chunks:
did = chunk.document_id
count = doc_chunk_counts.get(did, 0)
if chunk.id in matched_chunk_ids or count < _MAX_FETCH_CHUNKS_PER_DOC:
all_chunks.append(chunk)
doc_chunk_counts[did] = count + 1
# Assemble final doc-grouped results in the same order as doc_ids
doc_map: dict[int, dict] = {
@ -354,4 +416,11 @@ class ChucksHybridSearchRetriever:
)
final_docs.append(entry)
perf.info(
"[chunk_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
time.perf_counter() - t0,
len(final_docs),
search_space_id,
document_type,
)
return final_docs

View file

@ -1,5 +1,10 @@
import time
from datetime import datetime
from app.utils.perf import get_perf_logger
_MAX_FETCH_CHUNKS_PER_DOC = 30
class DocumentHybridSearchRetriever:
def __init__(self, db_session):
@ -38,6 +43,9 @@ class DocumentHybridSearchRetriever:
from app.config import config
from app.db import Document
perf = get_perf_logger()
t0 = time.perf_counter()
# Get embedding for the query
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
@ -63,6 +71,12 @@ class DocumentHybridSearchRetriever:
# Execute the query
result = await self.db_session.execute(query)
documents = result.scalars().all()
perf.info(
"[doc_search] vector_search in %.3fs results=%d space=%d",
time.perf_counter() - t0,
len(documents),
search_space_id,
)
return documents
@ -92,6 +106,9 @@ class DocumentHybridSearchRetriever:
from app.db import Document
perf = get_perf_logger()
t0 = time.perf_counter()
# Create tsvector and tsquery for PostgreSQL full-text search
tsvector = func.to_tsvector("english", Document.content)
tsquery = func.plainto_tsquery("english", query_text)
@ -118,6 +135,12 @@ class DocumentHybridSearchRetriever:
# Execute the query
result = await self.db_session.execute(query)
documents = result.scalars().all()
perf.info(
"[doc_search] full_text_search in %.3fs results=%d space=%d",
time.perf_counter() - t0,
len(documents),
search_space_id,
)
return documents
@ -129,6 +152,7 @@ class DocumentHybridSearchRetriever:
document_type: str | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
query_embedding: list | None = None,
) -> list:
"""
Hybrid search that returns **documents** (not individual chunks).
@ -143,7 +167,7 @@ class DocumentHybridSearchRetriever:
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
start_date: Optional start date for filtering documents by updated_at
end_date: Optional end date for filtering documents by updated_at
query_embedding: Pre-computed embedding vector. If None, will be computed here.
"""
from sqlalchemy import func, select, text
from sqlalchemy.orm import joinedload
@ -151,9 +175,12 @@ class DocumentHybridSearchRetriever:
from app.config import config
from app.db import Chunk, Document, DocumentType
# Get embedding for the query
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
perf = get_perf_logger()
t0 = time.perf_counter()
if query_embedding is None:
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
# RRF constants
k = 60
@ -254,7 +281,8 @@ class DocumentHybridSearchRetriever:
# Collect document IDs for chunk fetching
doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores]
# Fetch ALL chunks for these documents in a single query
# Fetch chunks for these documents, capped per document to avoid
# loading hundreds of chunks for a single large file.
chunks_query = (
select(Chunk)
.options(joinedload(Chunk.document))
@ -262,7 +290,16 @@ class DocumentHybridSearchRetriever:
.order_by(Chunk.document_id, Chunk.id)
)
chunks_result = await self.db_session.execute(chunks_query)
chunks = chunks_result.scalars().all()
raw_chunks = chunks_result.scalars().all()
doc_chunk_counts: dict[int, int] = {}
chunks: list = []
for chunk in raw_chunks:
did = chunk.document_id
count = doc_chunk_counts.get(did, 0)
if count < _MAX_FETCH_CHUNKS_PER_DOC:
chunks.append(chunk)
doc_chunk_counts[did] = count + 1
# Assemble doc-grouped results
doc_map: dict[int, dict] = {
@ -303,4 +340,11 @@ class DocumentHybridSearchRetriever:
)
final_docs.append(entry)
perf.info(
"[doc_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
time.perf_counter() - t0,
len(final_docs),
search_space_id,
document_type,
)
return final_docs

View file

@ -36,6 +36,7 @@ from .podcasts_routes import router as podcasts_router
from .public_chat_routes import router as public_chat_router
from .rbac_routes import router as rbac_router
from .reports_routes import router as reports_router
from .sandbox_routes import router as sandbox_router
from .search_source_connectors_routes import router as search_source_connectors_router
from .search_spaces_routes import router as search_spaces_router
from .slack_add_connector_route import router as slack_add_connector_router
@ -50,6 +51,7 @@ router.include_router(editor_router)
router.include_router(documents_router)
router.include_router(notes_router)
router.include_router(new_chat_router) # Chat with assistant-ui persistence
router.include_router(sandbox_router) # Sandbox file downloads (Daytona)
router.include_router(chat_comments_router)
router.include_router(podcasts_router) # Podcast task status and audio
router.include_router(reports_router) # Report CRUD and export (PDF/DOCX)

View file

@ -7,6 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, get_async_session
from app.schemas.chat_comments import (
CommentBatchRequest,
CommentBatchResponse,
CommentCreateRequest,
CommentListResponse,
CommentReplyResponse,
@ -19,6 +21,7 @@ from app.services.chat_comments_service import (
create_reply,
delete_comment,
get_comments_for_message,
get_comments_for_messages_batch,
get_user_mentions,
update_comment,
)
@ -27,6 +30,16 @@ from app.users import current_active_user
router = APIRouter()
@router.post("/messages/comments/batch", response_model=CommentBatchResponse)
async def batch_list_comments(
request: CommentBatchRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Batch-fetch comments for multiple messages in one request."""
return await get_comments_for_messages_batch(session, request.message_ids, user)
@router.get("/messages/{message_id}/comments", response_model=CommentListResponse)
async def list_comments(
message_id: int,

View file

@ -28,6 +28,7 @@ from app.schemas import (
DocumentWithChunksRead,
PaginatedResponse,
)
from app.services.task_dispatcher import TaskDispatcher, get_task_dispatcher
from app.users import current_active_user
from app.utils.rbac import check_permission
@ -44,6 +45,10 @@ os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
router = APIRouter()
MAX_FILES_PER_UPLOAD = 10
MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB per file
MAX_TOTAL_SIZE_BYTES = 200 * 1024 * 1024 # 200 MB total
@router.post("/documents")
async def create_documents(
@ -114,8 +119,10 @@ async def create_documents(
async def create_documents_file_upload(
files: list[UploadFile],
search_space_id: int = Form(...),
should_summarize: bool = Form(False),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
dispatcher: TaskDispatcher = Depends(get_task_dispatcher),
):
"""
Upload files as documents with real-time status tracking.
@ -126,6 +133,8 @@ async def create_documents_file_upload(
Requires DOCUMENTS_CREATE permission.
"""
import os
import tempfile
from datetime import datetime
from app.db import DocumentStatus
@ -136,7 +145,6 @@ async def create_documents_file_upload(
from app.utils.document_converters import generate_unique_identifier_hash
try:
# Check permission
await check_permission(
session,
user,
@ -148,51 +156,88 @@ async def create_documents_file_upload(
if not files:
raise HTTPException(status_code=400, detail="No files provided")
if len(files) > MAX_FILES_PER_UPLOAD:
raise HTTPException(
status_code=413,
detail=f"Too many files. Maximum {MAX_FILES_PER_UPLOAD} files per upload.",
)
total_size = 0
for file in files:
file_size = file.size or 0
if file_size > MAX_FILE_SIZE_BYTES:
raise HTTPException(
status_code=413,
detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
)
total_size += file_size
if total_size > MAX_TOTAL_SIZE_BYTES:
raise HTTPException(
status_code=413,
detail=f"Total upload size ({total_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
)
# ===== Read all files concurrently to avoid blocking the event loop =====
async def _read_and_save(file: UploadFile) -> tuple[str, str, int]:
"""Read upload content and write to temp file off the event loop."""
content = await file.read()
file_size = len(content)
filename = file.filename or "unknown"
if file_size > MAX_FILE_SIZE_BYTES:
raise HTTPException(
status_code=413,
detail=f"File '{filename}' ({file_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
)
def _write_temp() -> str:
with tempfile.NamedTemporaryFile(
delete=False, suffix=os.path.splitext(filename)[1]
) as tmp:
tmp.write(content)
return tmp.name
temp_path = await asyncio.to_thread(_write_temp)
return temp_path, filename, file_size
saved_files = await asyncio.gather(*(_read_and_save(f) for f in files))
actual_total_size = sum(size for _, _, size in saved_files)
if actual_total_size > MAX_TOTAL_SIZE_BYTES:
for temp_path, _, _ in saved_files:
os.unlink(temp_path)
raise HTTPException(
status_code=413,
detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
)
# ===== PHASE 1: Create pending documents for all files =====
created_documents: list[Document] = []
files_to_process: list[
tuple[Document, str, str]
] = [] # (document, temp_path, filename)
files_to_process: list[tuple[Document, str, str]] = []
skipped_duplicates = 0
duplicate_document_ids: list[int] = []
# ===== PHASE 1: Create pending documents for all files =====
# This makes ALL documents visible in the UI immediately with pending status
for file in files:
for temp_path, filename, file_size in saved_files:
try:
import os
import tempfile
# Save file to temp location
with tempfile.NamedTemporaryFile(
delete=False, suffix=os.path.splitext(file.filename or "")[1]
) as temp_file:
temp_path = temp_file.name
content = await file.read()
with open(temp_path, "wb") as f:
f.write(content)
file_size = len(content)
# Generate unique identifier for deduplication check
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.FILE, file.filename or "unknown", search_space_id
DocumentType.FILE, filename, search_space_id
)
# Check if document already exists (by unique identifier)
existing = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
if existing:
if DocumentStatus.is_state(existing.status, DocumentStatus.READY):
# True duplicate — content already indexed, skip
os.unlink(temp_path)
skipped_duplicates += 1
duplicate_document_ids.append(existing.id)
continue
# Existing document is stuck (failed/pending/processing)
# Reset it to pending and re-dispatch for processing
existing.status = DocumentStatus.pending()
existing.content = "Processing..."
existing.document_metadata = {
@ -202,61 +247,53 @@ async def create_documents_file_upload(
}
existing.updated_at = get_current_timestamp()
created_documents.append(existing)
files_to_process.append(
(existing, temp_path, file.filename or "unknown")
)
files_to_process.append((existing, temp_path, filename))
continue
# Create pending document (visible immediately in UI via ElectricSQL)
document = Document(
search_space_id=search_space_id,
title=file.filename or "Uploaded File",
title=filename if filename != "unknown" else "Uploaded File",
document_type=DocumentType.FILE,
document_metadata={
"FILE_NAME": file.filename,
"FILE_NAME": filename,
"file_size": file_size,
"upload_time": datetime.now().isoformat(),
},
content="Processing...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary, updated when ready
content="Processing...",
content_hash=unique_identifier_hash,
unique_identifier_hash=unique_identifier_hash,
embedding=None,
status=DocumentStatus.pending(), # Shows "pending" in UI
status=DocumentStatus.pending(),
updated_at=get_current_timestamp(),
created_by_id=str(user.id),
)
session.add(document)
created_documents.append(document)
files_to_process.append(
(document, temp_path, file.filename or "unknown")
)
files_to_process.append((document, temp_path, filename))
except HTTPException:
raise
except Exception as e:
os.unlink(temp_path)
raise HTTPException(
status_code=422,
detail=f"Failed to process file {file.filename}: {e!s}",
detail=f"Failed to process file {filename}: {e!s}",
) from e
# Commit all pending documents - they appear in UI immediately via ElectricSQL
if created_documents:
await session.commit()
# Refresh to get generated IDs
for doc in created_documents:
await session.refresh(doc)
# ===== PHASE 2: Dispatch Celery tasks for each file =====
# Each task will update document status: pending → processing → ready/failed
from app.tasks.celery_tasks.document_tasks import (
process_file_upload_with_document_task,
)
# ===== PHASE 2: Dispatch tasks for each file =====
for document, temp_path, filename in files_to_process:
process_file_upload_with_document_task.delay(
await dispatcher.dispatch_file_processing(
document_id=document.id,
temp_path=temp_path,
filename=filename,
search_space_id=search_space_id,
user_id=str(user.id),
should_summarize=should_summarize,
)
return {

View file

@ -10,6 +10,8 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
- POST /threads/{thread_id}/messages - Append message
"""
import asyncio
import logging
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Request
@ -30,6 +32,7 @@ from app.db import (
SearchSpace,
User,
get_async_session,
shielded_async_session,
)
from app.schemas.new_chat import (
NewChatMessageAppend,
@ -52,9 +55,50 @@ from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.users import current_active_user
from app.utils.rbac import check_permission
_logger = logging.getLogger(__name__)
_background_tasks: set[asyncio.Task] = set()
router = APIRouter()
def _try_delete_sandbox(thread_id: int) -> None:
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
from app.agents.new_chat.sandbox import (
delete_local_sandbox_files,
delete_sandbox,
is_sandbox_enabled,
)
if not is_sandbox_enabled():
return
async def _bg() -> None:
try:
await delete_sandbox(thread_id)
except Exception:
_logger.warning(
"Background sandbox delete failed for thread %s",
thread_id,
exc_info=True,
)
try:
delete_local_sandbox_files(thread_id)
except Exception:
_logger.warning(
"Local sandbox file cleanup failed for thread %s",
thread_id,
exc_info=True,
)
try:
loop = asyncio.get_running_loop()
task = loop.create_task(_bg())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
except RuntimeError:
pass
async def check_thread_access(
session: AsyncSession,
thread: NewChatThread,
@ -648,6 +692,9 @@ async def delete_thread(
await session.delete(db_thread)
await session.commit()
_try_delete_sandbox(thread_id)
return {"message": "Thread deleted successfully"}
except HTTPException:
@ -1046,13 +1093,18 @@ async def handle_new_chat(
# on searchspaces/documents for the entire duration of the stream.
# expire_on_commit=False keeps loaded ORM attrs usable.
await session.commit()
# Close the dependency session now so its connection returns to
# the pool before streaming begins. Without this, Starlette's
# BaseHTTPMiddleware cancels the scope on client disconnect and
# the dependency generator's __aexit__ never runs, orphaning the
# connection (the "Exception terminating connection" errors).
await session.close()
return StreamingResponse(
stream_new_chat(
user_query=request.user_query,
search_space_id=request.search_space_id,
chat_id=request.chat_id,
session=session,
user_id=str(user.id),
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
@ -1277,6 +1329,7 @@ async def regenerate_response(
# on searchspaces/documents for the entire duration of the stream.
# expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
await session.commit()
await session.close()
# Create a wrapper generator that deletes messages only AFTER streaming succeeds
# This prevents data loss if streaming fails (network error, LLM error, etc.)
@ -1287,7 +1340,6 @@ async def regenerate_response(
user_query=user_query_to_use,
search_space_id=request.search_space_id,
chat_id=thread_id,
session=session,
user_id=str(user.id),
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
@ -1298,29 +1350,35 @@ async def regenerate_response(
current_user_display_name=user.display_name or "A team member",
):
yield chunk
# If we get here, streaming completed successfully
streaming_completed = True
finally:
# Only delete old messages if streaming completed successfully
# This ensures we don't lose data on streaming failures
if streaming_completed and messages_to_delete:
# Only delete old messages if streaming completed successfully.
# Uses a fresh session since stream_new_chat manages its own.
if streaming_completed and message_ids_to_delete:
try:
for msg in messages_to_delete:
await session.delete(msg)
await session.commit()
async with shielded_async_session() as cleanup_session:
for msg_id in message_ids_to_delete:
_res = await cleanup_session.execute(
select(NewChatMessage).filter(
NewChatMessage.id == msg_id
)
)
_msg = _res.scalars().first()
if _msg:
await cleanup_session.delete(_msg)
await cleanup_session.commit()
# Delete any public snapshots that contain the modified messages
from app.services.public_chat_service import (
delete_affected_snapshots,
)
from app.services.public_chat_service import (
delete_affected_snapshots,
)
await delete_affected_snapshots(
session, thread_id, message_ids_to_delete
)
await delete_affected_snapshots(
cleanup_session, thread_id, message_ids_to_delete
)
except Exception as cleanup_error:
# Log but don't fail - the new messages are already streamed
print(
f"[regenerate] Warning: Failed to delete old messages: {cleanup_error}"
_logger.warning(
"[regenerate] Failed to delete old messages: %s",
cleanup_error,
)
# Return streaming response with checkpoint_id for rewinding
@ -1394,13 +1452,13 @@ async def resume_chat(
# Release the read-transaction so we don't hold ACCESS SHARE locks
# on searchspaces/documents for the entire duration of the stream.
await session.commit()
await session.close()
return StreamingResponse(
stream_resume_chat(
chat_id=thread_id,
search_space_id=request.search_space_id,
decisions=decisions,
session=session,
user_id=str(user.id),
llm_config_id=llm_config_id,
thread_visibility=thread.visibility,

View file

@ -17,7 +17,7 @@ import logging
import os
import re
import tempfile
from enum import Enum
from enum import StrEnum
import pypandoc
import typst
@ -46,7 +46,7 @@ router = APIRouter()
MAX_REPORT_LIST_LIMIT = 500
class ExportFormat(str, Enum):
class ExportFormat(StrEnum):
PDF = "pdf"
DOCX = "docx"

View file

@ -0,0 +1,105 @@
"""Routes for downloading files from Daytona sandbox environments."""
from __future__ import annotations
import asyncio
import logging
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import NewChatThread, Permission, User, get_async_session
from app.users import current_active_user
from app.utils.rbac import check_permission
logger = logging.getLogger(__name__)
router = APIRouter()
MIME_TYPES: dict[str, str] = {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".webp": "image/webp",
".svg": "image/svg+xml",
".pdf": "application/pdf",
".csv": "text/csv",
".json": "application/json",
".txt": "text/plain",
".html": "text/html",
".md": "text/markdown",
".py": "text/x-python",
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
".zip": "application/zip",
}
def _guess_media_type(filename: str) -> str:
ext = ("." + filename.rsplit(".", 1)[-1].lower()) if "." in filename else ""
return MIME_TYPES.get(ext, "application/octet-stream")
@router.get("/threads/{thread_id}/sandbox/download")
async def download_sandbox_file(
thread_id: int,
path: str = Query(..., description="Absolute path of the file inside the sandbox"),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Download a file from the Daytona sandbox associated with a chat thread."""
from app.agents.new_chat.sandbox import get_or_create_sandbox, is_sandbox_enabled
if not is_sandbox_enabled():
raise HTTPException(status_code=404, detail="Sandbox is not enabled")
result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == thread_id)
)
thread = result.scalars().first()
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
await check_permission(
session,
user,
thread.search_space_id,
Permission.CHATS_READ.value,
"You don't have permission to access files in this thread",
)
from app.agents.new_chat.sandbox import get_local_sandbox_file
# Prefer locally-persisted copy (sandbox may already be deleted)
local_content = get_local_sandbox_file(thread_id, path)
if local_content is not None:
filename = path.rsplit("/", 1)[-1] if "/" in path else path
media_type = _guess_media_type(filename)
return Response(
content=local_content,
media_type=media_type,
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
# Fall back to live sandbox download
try:
sandbox = await get_or_create_sandbox(thread_id)
raw_sandbox = sandbox._sandbox
content: bytes = await asyncio.to_thread(raw_sandbox.fs.download_file, path)
except Exception as exc:
logger.warning("Sandbox file download failed for %s: %s", path, exc)
raise HTTPException(
status_code=404, detail=f"Could not download file: {exc}"
) from exc
filename = path.rsplit("/", 1)[-1] if "/" in path else path
media_type = _guess_media_type(filename)
return Response(
content=content,
media_type=media_type,
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)

View file

@ -2735,7 +2735,10 @@ async def create_mcp_connector(
f"for user {user.id} in search space {search_space_id}"
)
# Convert to read schema
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
invalidate_mcp_tools_cache(search_space_id)
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
return MCPConnectorRead.from_connector(connector_read)
@ -2910,6 +2913,10 @@ async def update_mcp_connector(
logger.info(f"Updated MCP connector {connector_id}")
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
invalidate_mcp_tools_cache(connector.search_space_id)
connector_read = SearchSourceConnectorRead.model_validate(connector)
return MCPConnectorRead.from_connector(connector_read)
@ -2960,9 +2967,14 @@ async def delete_mcp_connector(
"You don't have permission to delete this connector",
)
search_space_id = connector.search_space_id
await session.delete(connector)
await session.commit()
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
invalidate_mcp_tools_cache(search_space_id)
logger.info(f"Deleted MCP connector {connector_id}")
except HTTPException:

View file

@ -87,6 +87,18 @@ class CommentListResponse(BaseModel):
total_count: int
class CommentBatchRequest(BaseModel):
"""Request for batch-fetching comments for multiple messages."""
message_ids: list[int] = Field(..., min_length=1, max_length=200)
class CommentBatchResponse(BaseModel):
"""Batch response keyed by message_id."""
comments_by_message: dict[int, CommentListResponse]
# =============================================================================
# Mention Schemas
# =============================================================================

View file

@ -1,13 +1,13 @@
"""Podcast schemas for API responses."""
from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel
class PodcastStatusEnum(str, Enum):
class PodcastStatusEnum(StrEnum):
PENDING = "pending"
GENERATING = "generating"
READY = "ready"

View file

@ -16,6 +16,7 @@ class SearchSourceConnectorBase(BaseModel):
is_indexable: bool
last_indexed_at: datetime | None = None
config: dict[str, Any]
enable_summary: bool = False
periodic_indexing_enabled: bool = False
indexing_frequency_minutes: int | None = None
next_scheduled_at: datetime | None = None
@ -65,6 +66,7 @@ class SearchSourceConnectorUpdate(BaseModel):
is_indexable: bool | None = None
last_indexed_at: datetime | None = None
config: dict[str, Any] | None = None
enable_summary: bool | None = None
periodic_indexing_enabled: bool | None = None
indexing_frequency_minutes: int | None = None
next_scheduled_at: datetime | None = None

View file

@ -22,6 +22,7 @@ from app.db import (
)
from app.schemas.chat_comments import (
AuthorResponse,
CommentBatchResponse,
CommentListResponse,
CommentReplyResponse,
CommentResponse,
@ -264,6 +265,146 @@ async def get_comments_for_message(
)
async def get_comments_for_messages_batch(
session: AsyncSession,
message_ids: list[int],
user: User,
) -> CommentBatchResponse:
"""
Batch-fetch comments for multiple messages in a single DB round-trip.
Validates that all messages exist and belong to search spaces the user
can read comments in, then loads all comments with eager-loaded authors
and replies.
"""
if not message_ids:
return CommentBatchResponse(comments_by_message={})
unique_ids = list(set(message_ids))
result = await session.execute(
select(NewChatMessage)
.options(selectinload(NewChatMessage.thread))
.filter(NewChatMessage.id.in_(unique_ids))
)
messages = result.scalars().all()
msg_map = {m.id: m for m in messages}
search_space_ids = {m.thread.search_space_id for m in messages}
permissions_cache: dict[int, set] = {}
for ss_id in search_space_ids:
await check_permission(
session,
user,
ss_id,
Permission.COMMENTS_READ.value,
"You don't have permission to read comments in this search space",
)
permissions_cache[ss_id] = await get_user_permissions(session, user.id, ss_id)
result = await session.execute(
select(ChatComment)
.options(
selectinload(ChatComment.author),
selectinload(ChatComment.replies).selectinload(ChatComment.author),
)
.filter(
ChatComment.message_id.in_(unique_ids),
ChatComment.parent_id.is_(None),
)
.order_by(ChatComment.created_at)
)
top_level_comments = result.scalars().all()
all_mentioned_uuids: set[UUID] = set()
for comment in top_level_comments:
all_mentioned_uuids.update(parse_mentions(comment.content))
for reply in comment.replies:
all_mentioned_uuids.update(parse_mentions(reply.content))
user_names = await get_user_names_for_mentions(session, all_mentioned_uuids)
comments_by_msg: dict[int, list[ChatComment]] = {mid: [] for mid in unique_ids}
for comment in top_level_comments:
comments_by_msg.setdefault(comment.message_id, []).append(comment)
comments_by_message: dict[int, CommentListResponse] = {}
for mid in unique_ids:
msg = msg_map.get(mid)
if msg is None:
comments_by_message[mid] = CommentListResponse(comments=[], total_count=0)
continue
ss_id = msg.thread.search_space_id
user_perms = permissions_cache.get(ss_id, set())
can_delete_any = has_permission(user_perms, Permission.COMMENTS_DELETE.value)
comment_responses = []
for comment in comments_by_msg.get(mid, []):
author = None
if comment.author:
author = AuthorResponse(
id=comment.author.id,
display_name=comment.author.display_name,
avatar_url=comment.author.avatar_url,
email=comment.author.email,
)
replies = []
for reply in sorted(comment.replies, key=lambda r: r.created_at):
reply_author = None
if reply.author:
reply_author = AuthorResponse(
id=reply.author.id,
display_name=reply.author.display_name,
avatar_url=reply.author.avatar_url,
email=reply.author.email,
)
is_reply_author = (
reply.author_id == user.id if reply.author_id else False
)
replies.append(
CommentReplyResponse(
id=reply.id,
content=reply.content,
content_rendered=render_mentions(reply.content, user_names),
author=reply_author,
created_at=reply.created_at,
updated_at=reply.updated_at,
is_edited=reply.updated_at > reply.created_at,
can_edit=is_reply_author,
can_delete=is_reply_author or can_delete_any,
)
)
is_comment_author = (
comment.author_id == user.id if comment.author_id else False
)
comment_responses.append(
CommentResponse(
id=comment.id,
message_id=comment.message_id,
content=comment.content,
content_rendered=render_mentions(comment.content, user_names),
author=author,
created_at=comment.created_at,
updated_at=comment.updated_at,
is_edited=comment.updated_at > comment.created_at,
can_edit=is_comment_author,
can_delete=is_comment_author or can_delete_any,
reply_count=len(replies),
replies=replies,
)
)
comments_by_message[mid] = CommentListResponse(
comments=comment_responses,
total_count=len(comment_responses),
)
return CommentBatchResponse(comments_by_message=comments_by_message)
async def create_comment(
session: AsyncSession,
message_id: int,

View file

@ -1,4 +1,5 @@
import asyncio
import time
from datetime import datetime
from typing import Any
from urllib.parse import urljoin
@ -15,9 +16,11 @@ from app.db import (
Document,
SearchSourceConnector,
SearchSourceConnectorType,
async_session_maker,
)
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.retriever.documents_hybrid_search import DocumentHybridSearchRetriever
from app.utils.perf import get_perf_logger
class ConnectorService:
@ -221,6 +224,7 @@ class ConnectorService:
top_k: int = 20,
start_date: datetime | None = None,
end_date: datetime | None = None,
query_embedding: list[float] | None = None,
) -> list[dict[str, Any]]:
"""
Perform combined search using both chunk-based and document-based hybrid search,
@ -246,34 +250,60 @@ class ConnectorService:
Returns:
List of combined and deduplicated document results
"""
from app.config import config
perf = get_perf_logger()
t0 = time.perf_counter()
# RRF constant
k = 60
# Get more results from each retriever for better fusion
retriever_top_k = top_k * 2
# IMPORTANT:
# These retrievers share the same AsyncSession. AsyncSession does not permit
# concurrent awaits that require DB IO on the same session/connection.
# Running these in parallel can raise:
# "This session is provisioning a new connection; concurrent operations are not permitted"
#
# So we run them sequentially.
chunk_results = await self.chunk_retriever.hybrid_search(
query_text=query_text,
top_k=retriever_top_k,
search_space_id=search_space_id,
document_type=document_type,
start_date=start_date,
end_date=end_date,
# Reuse caller-provided embedding or compute once for both retrievers.
if query_embedding is None:
t_embed = time.perf_counter()
query_embedding = config.embedding_model_instance.embed(query_text)
perf.info(
"[connector_svc] _combined_rrf embedding in %.3fs type=%s",
time.perf_counter() - t_embed,
document_type,
)
search_kwargs = {
"query_text": query_text,
"top_k": retriever_top_k,
"search_space_id": search_space_id,
"document_type": document_type,
"start_date": start_date,
"end_date": end_date,
"query_embedding": query_embedding,
}
# Run chunk and document retrievers in parallel using separate DB sessions
# so they don't contend on a shared AsyncSession connection.
async def _run_chunk_search() -> list[dict[str, Any]]:
async with async_session_maker() as session:
retriever = ChucksHybridSearchRetriever(session)
return await retriever.hybrid_search(**search_kwargs)
async def _run_doc_search() -> list[dict[str, Any]]:
async with async_session_maker() as session:
retriever = DocumentHybridSearchRetriever(session)
return await retriever.hybrid_search(**search_kwargs)
t_parallel = time.perf_counter()
chunk_results, doc_results = await asyncio.gather(
_run_chunk_search(), _run_doc_search()
)
doc_results = await self.document_retriever.hybrid_search(
query_text=query_text,
top_k=retriever_top_k,
search_space_id=search_space_id,
document_type=document_type,
start_date=start_date,
end_date=end_date,
perf.info(
"[connector_svc] _combined_rrf parallel retrievers in %.3fs "
"chunk_results=%d doc_results=%d type=%s",
time.perf_counter() - t_parallel,
len(chunk_results),
len(doc_results),
document_type,
)
# Helper to extract document_id from our doc-grouped result
@ -335,6 +365,13 @@ class ConnectorService:
result["chunks"] = doc_data[did]["chunks"]
combined_results.append(result)
perf.info(
"[connector_svc] _combined_rrf_search TOTAL in %.3fs results=%d type=%s space=%d",
time.perf_counter() - t0,
len(combined_results),
document_type,
search_space_id,
)
return combined_results
def _get_doc_url(self, metadata: dict[str, Any]) -> str:
@ -1303,10 +1340,9 @@ class ConnectorService:
sources_list = self._build_chunk_sources_from_documents(
github_docs,
description_fn=lambda chunk, _doc_info, metadata: metadata.get(
"description"
)
or chunk.get("content", ""),
description_fn=lambda chunk, _doc_info, metadata: (
metadata.get("description") or chunk.get("content", "")
),
url_fn=lambda _doc_info, metadata: metadata.get("url", "") or "",
)

View file

@ -4,12 +4,12 @@ from datetime import datetime
from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.linear_connector import LinearConnector
from app.db import Chunk, Document
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
)
@ -80,7 +80,7 @@ class LinearKBSyncService:
state = formatted_issue.get("state", "Unknown")
priority = issue_raw.get("priorityLabel", "Unknown")
comment_count = len(formatted_issue.get("comments", []))
description = formatted_issue.get("description", "")
formatted_issue.get("description", "")
user_llm = await get_user_long_context_llm(
self.db_session, user_id, search_space_id, disable_streaming=True
@ -100,18 +100,10 @@ class LinearKBSyncService:
issue_content, user_llm, document_metadata_for_summary
)
else:
if description and len(description) > 1000:
description = description[:997] + "..."
summary_content = (
f"Linear Issue {issue_identifier}: {issue_title}\n\n"
f"Status: {state}\n\n"
)
if description:
summary_content += f"Description: {description}\n\n"
summary_content += f"Comments: {comment_count}"
summary_embedding = config.embedding_model_instance.embed(
summary_content
f"Linear Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
)
summary_embedding = embed_text(summary_content)
await self.db_session.execute(
delete(Chunk).where(Chunk.document_id == document.id)

View file

@ -12,16 +12,42 @@ synchronous ChatLiteLLM-like interface and async methods.
"""
import logging
import re
import time
from typing import Any
import litellm
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.exceptions import ContextOverflowError
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from litellm import Router
from litellm.exceptions import (
BadRequestError as LiteLLMBadRequestError,
ContextWindowExceededError,
)
from app.utils.perf import get_perf_logger
litellm.json_logs = False
litellm.store_audit_logs = False
logger = logging.getLogger(__name__)
_CONTEXT_OVERFLOW_PATTERNS = re.compile(
r"(input tokens exceed|context.{0,20}(length|window|limit)|"
r"maximum context length|token.{0,20}(limit|exceed)|"
r"too many tokens|reduce the length)",
re.IGNORECASE,
)
def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool:
"""Check if a BadRequestError is actually a context window overflow."""
return bool(_CONTEXT_OVERFLOW_PATTERNS.search(str(exc)))
# Special ID for Auto mode - uses router for load balancing
AUTO_MODE_ID = 0
@ -133,26 +159,95 @@ class LLMRouterService:
# Merge with provided settings
final_settings = {**default_settings, **instance._router_settings}
# Build a "auto-large" fallback group with deployments whose context
# window exceeds the smallest deployment. This lets the router
# automatically fall back to a bigger-context model when gpt-4o (128K)
# hits ContextWindowExceededError.
full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)
try:
instance._router = Router(
model_list=model_list,
routing_strategy=final_settings.get(
router_kwargs: dict[str, Any] = {
"model_list": full_model_list,
"routing_strategy": final_settings.get(
"routing_strategy", "usage-based-routing"
),
num_retries=final_settings.get("num_retries", 3),
allowed_fails=final_settings.get("allowed_fails", 3),
cooldown_time=final_settings.get("cooldown_time", 60),
set_verbose=False, # Disable verbose logging in production
)
"num_retries": final_settings.get("num_retries", 3),
"allowed_fails": final_settings.get("allowed_fails", 3),
"cooldown_time": final_settings.get("cooldown_time", 60),
"set_verbose": False,
}
if ctx_fallbacks:
router_kwargs["context_window_fallbacks"] = ctx_fallbacks
instance._router = Router(**router_kwargs)
instance._initialized = True
logger.info(
f"LLM Router initialized with {len(model_list)} deployments, "
f"strategy: {final_settings.get('routing_strategy')}"
"LLM Router initialized with %d deployments, "
"strategy: %s, context_window_fallbacks: %s",
len(model_list),
final_settings.get("routing_strategy"),
ctx_fallbacks or "none",
)
except Exception as e:
logger.error(f"Failed to initialize LLM Router: {e}")
instance._router = None
@classmethod
def _build_context_fallback_groups(
cls, model_list: list[dict]
) -> tuple[list[dict], list[dict[str, list[str]]] | None]:
"""Create an ``auto-large`` model group for context-window fallbacks.
Uses ``litellm.get_model_info`` to discover the context window of each
deployment. Deployments whose ``max_input_tokens`` exceeds the smallest
window are duplicated into an ``auto-large`` group. The returned
fallback config tells the Router: on ``ContextWindowExceededError`` for
``auto``, retry with ``auto-large``.
Returns:
(full_model_list, context_window_fallbacks) ``full_model_list``
contains the original entries plus any ``auto-large`` duplicates.
``context_window_fallbacks`` is ``None`` when every deployment has
the same context size (no useful fallback).
"""
from litellm import get_model_info
ctx_map: dict[str, int] = {}
for dep in model_list:
params = dep.get("litellm_params", {})
base_model = params.get("base_model") or params.get("model", "")
try:
info = get_model_info(base_model)
ctx = info.get("max_input_tokens")
if isinstance(ctx, int) and ctx > 0:
ctx_map[base_model] = ctx
except Exception:
continue
if not ctx_map:
return model_list, None
min_ctx = min(ctx_map.values())
large_deployments: list[dict] = []
for dep in model_list:
params = dep.get("litellm_params", {})
base_model = params.get("base_model") or params.get("model", "")
if ctx_map.get(base_model, 0) > min_ctx:
dup = {**dep, "model_name": "auto-large"}
large_deployments.append(dup)
if not large_deployments:
return model_list, None
logger.info(
"Context-window fallback: %d large-context deployments "
"(min_ctx=%d) added to 'auto-large' group",
len(large_deployments),
min_ctx,
)
return model_list + large_deployments, [{"auto": ["auto-large"]}]
@classmethod
def _config_to_deployment(cls, config: dict) -> dict | None:
"""
@ -228,12 +323,62 @@ class LLMRouterService:
return len(instance._model_list)
_cached_context_profile: dict | None = None
_cached_context_profile_computed: bool = False
# Cached singleton instances keyed by (streaming,) to avoid re-creating on every call
_router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {}
def _get_cached_context_profile(router: Router) -> dict | None:
"""Compute and cache the min context profile across all router deployments.
Called once on first ChatLiteLLMRouter creation; subsequent calls return
the cached value. This avoids calling litellm.get_model_info() for every
deployment on every request.
"""
global _cached_context_profile, _cached_context_profile_computed
if _cached_context_profile_computed:
return _cached_context_profile
from litellm import get_model_info
min_ctx: int | None = None
for deployment in router.model_list:
params = deployment.get("litellm_params", {})
base_model = params.get("base_model") or params.get("model", "")
try:
info = get_model_info(base_model)
ctx = info.get("max_input_tokens")
if isinstance(ctx, int) and ctx > 0 and (min_ctx is None or ctx < min_ctx):
min_ctx = ctx
except Exception:
continue
if min_ctx is not None:
logger.info("ChatLiteLLMRouter profile: max_input_tokens=%d", min_ctx)
_cached_context_profile = {"max_input_tokens": min_ctx}
else:
_cached_context_profile = None
_cached_context_profile_computed = True
return _cached_context_profile
class ChatLiteLLMRouter(BaseChatModel):
"""
A LangChain-compatible chat model that uses LiteLLM Router for load balancing.
This wraps the LiteLLM Router to provide the same interface as ChatLiteLLM,
making it a drop-in replacement for auto-mode routing.
Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
window across all router deployments so that deepagents
SummarizationMiddleware can use fraction-based triggers.
**Singleton-ish**: Use ``get_auto_mode_llm()`` or call ``ChatLiteLLMRouter()``
directly instances without bound tools are cached per streaming flag to
avoid per-request re-initialization overhead and memory growth.
"""
# Use model_config for Pydantic v2 compatibility
@ -255,17 +400,8 @@ class ChatLiteLLMRouter(BaseChatModel):
tool_choice: str | dict | None = None,
**kwargs,
):
"""
Initialize the ChatLiteLLMRouter.
Args:
router: LiteLLM Router instance. If None, uses the global singleton.
bound_tools: Pre-bound tools for tool calling
tool_choice: Tool choice configuration
"""
try:
super().__init__(**kwargs)
# Store router and tools as private attributes
resolved_router = router or LLMRouterService.get_router()
object.__setattr__(self, "_router", resolved_router)
object.__setattr__(self, "_bound_tools", bound_tools)
@ -274,8 +410,16 @@ class ChatLiteLLMRouter(BaseChatModel):
raise ValueError(
"LLM Router not initialized. Call LLMRouterService.initialize() first."
)
logger.info(
f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
computed_profile = _get_cached_context_profile(self._router)
if computed_profile is not None:
object.__setattr__(self, "profile", computed_profile)
logger.debug(
"ChatLiteLLMRouter ready (models=%d, streaming=%s, has_tools=%s)",
LLMRouterService.get_model_count(),
self.streaming,
bound_tools is not None,
)
except Exception as e:
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
@ -349,6 +493,10 @@ class ChatLiteLLMRouter(BaseChatModel):
if not self._router:
raise ValueError("Router not initialized")
perf = get_perf_logger()
t0 = time.perf_counter()
msg_count = len(messages)
# Convert LangChain messages to OpenAI format
formatted_messages = self._convert_messages(messages)
@ -359,12 +507,36 @@ class ChatLiteLLMRouter(BaseChatModel):
if self._tool_choice is not None:
call_kwargs["tool_choice"] = self._tool_choice
# Call router completion
response = self._router.completion(
model=self.model,
messages=formatted_messages,
stop=stop,
**call_kwargs,
try:
response = self._router.completion(
model=self.model,
messages=formatted_messages,
stop=stop,
**call_kwargs,
)
except ContextWindowExceededError as e:
perf.warning(
"[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
perf.warning(
"[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
raise
elapsed = time.perf_counter() - t0
perf.info(
"[llm_router] _generate completed msgs=%d tools=%d in %.3fs",
msg_count,
len(self._bound_tools) if self._bound_tools else 0,
elapsed,
)
# Convert response to ChatResult with potential tool calls
@ -386,6 +558,10 @@ class ChatLiteLLMRouter(BaseChatModel):
if not self._router:
raise ValueError("Router not initialized")
perf = get_perf_logger()
t0 = time.perf_counter()
msg_count = len(messages)
# Convert LangChain messages to OpenAI format
formatted_messages = self._convert_messages(messages)
@ -396,12 +572,36 @@ class ChatLiteLLMRouter(BaseChatModel):
if self._tool_choice is not None:
call_kwargs["tool_choice"] = self._tool_choice
# Call router async completion
response = await self._router.acompletion(
model=self.model,
messages=formatted_messages,
stop=stop,
**call_kwargs,
try:
response = await self._router.acompletion(
model=self.model,
messages=formatted_messages,
stop=stop,
**call_kwargs,
)
except ContextWindowExceededError as e:
perf.warning(
"[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
perf.warning(
"[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
raise
elapsed = time.perf_counter() - t0
perf.info(
"[llm_router] _agenerate completed msgs=%d tools=%d in %.3fs",
msg_count,
len(self._bound_tools) if self._bound_tools else 0,
elapsed,
)
# Convert response to ChatResult with potential tool calls
@ -432,14 +632,20 @@ class ChatLiteLLMRouter(BaseChatModel):
if self._tool_choice is not None:
call_kwargs["tool_choice"] = self._tool_choice
# Call router completion with streaming
response = self._router.completion(
model=self.model,
messages=formatted_messages,
stop=stop,
stream=True,
**call_kwargs,
)
try:
response = self._router.completion(
model=self.model,
messages=formatted_messages,
stop=stop,
stream=True,
**call_kwargs,
)
except ContextWindowExceededError as e:
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
raise ContextOverflowError(str(e)) from e
raise
# Yield chunks
for chunk in response:
@ -462,6 +668,10 @@ class ChatLiteLLMRouter(BaseChatModel):
if not self._router:
raise ValueError("Router not initialized")
perf = get_perf_logger()
t0 = time.perf_counter()
msg_count = len(messages)
formatted_messages = self._convert_messages(messages)
# Add tools if bound
@ -471,23 +681,61 @@ class ChatLiteLLMRouter(BaseChatModel):
if self._tool_choice is not None:
call_kwargs["tool_choice"] = self._tool_choice
# Call router async completion with streaming
response = await self._router.acompletion(
model=self.model,
messages=formatted_messages,
stop=stop,
stream=True,
**call_kwargs,
try:
response = await self._router.acompletion(
model=self.model,
messages=formatted_messages,
stop=stop,
stream=True,
**call_kwargs,
)
except ContextWindowExceededError as e:
perf.warning(
"[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
perf.warning(
"[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
raise
t_first_chunk = time.perf_counter()
perf.info(
"[llm_router] _astream connection established msgs=%d in %.3fs",
msg_count,
t_first_chunk - t0,
)
# Yield chunks asynchronously
chunk_count = 0
first_chunk_logged = False
async for chunk in response:
if hasattr(chunk, "choices") and chunk.choices:
delta = chunk.choices[0].delta
chunk_msg = self._convert_delta_to_chunk(delta)
if chunk_msg:
chunk_count += 1
if not first_chunk_logged:
perf.info(
"[llm_router] _astream first chunk in %.3fs (total %.3fs from start)",
time.perf_counter() - t_first_chunk,
time.perf_counter() - t0,
)
first_chunk_logged = True
yield ChatGenerationChunk(message=chunk_msg)
perf.info(
"[llm_router] _astream completed chunks=%d total=%.3fs",
chunk_count,
time.perf_counter() - t0,
)
def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]:
"""Convert LangChain messages to OpenAI format."""
from langchain_core.messages import (
@ -602,19 +850,28 @@ class ChatLiteLLMRouter(BaseChatModel):
return None
def get_auto_mode_llm() -> ChatLiteLLMRouter | None:
"""
Get a ChatLiteLLMRouter instance for auto mode.
def get_auto_mode_llm(
*,
streaming: bool = True,
) -> ChatLiteLLMRouter | None:
"""Return a cached ChatLiteLLMRouter for auto mode.
Returns:
ChatLiteLLMRouter instance or None if router not initialized
Base (no tools) instances are cached per ``streaming`` flag so we
avoid re-constructing them on every request. ``bind_tools()`` still
returns a fresh instance because bound tools differ per agent.
"""
if not LLMRouterService.is_initialized():
logger.warning("LLM Router not initialized for auto mode")
return None
cached = _router_instance_cache.get(streaming)
if cached is not None:
return cached
try:
return ChatLiteLLMRouter()
instance = ChatLiteLLMRouter(streaming=streaming)
_router_instance_cache[streaming] = instance
return instance
except Exception as e:
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
return None

View file

@ -12,12 +12,20 @@ from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
LLMRouterService,
get_auto_mode_llm,
is_auto_mode,
)
# Configure litellm to automatically drop unsupported parameters
litellm.drop_params = True
# Memory controls: prevent unbounded internal accumulation
litellm.telemetry = False
litellm.cache = None
litellm.success_callback = []
litellm.failure_callback = []
litellm.input_callback = []
logger = logging.getLogger(__name__)
@ -221,7 +229,7 @@ async def get_search_space_llm_instance(
logger.debug(
f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}"
)
return ChatLiteLLMRouter(disable_streaming=disable_streaming)
return get_auto_mode_llm(streaming=not disable_streaming)
except Exception as e:
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
return None

View file

@ -4,11 +4,11 @@ from datetime import datetime
from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import Chunk, Document
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
)
@ -127,10 +127,8 @@ class NotionKBSyncService:
logger.debug(f"Generated summary length: {len(summary_content)} chars")
else:
logger.warning("No LLM configured - using fallback summary")
summary_content = f"Notion Page: {document.document_metadata.get('page_title')}\n\n{full_content[:500]}..."
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Notion Page: {document.document_metadata.get('page_title')}\n\n{full_content}"
summary_embedding = embed_text(summary_content)
logger.debug(f"Deleting old chunks for document {document_id}")
await self.db_session.execute(

View file

@ -0,0 +1,53 @@
"""Task dispatcher abstraction for background document processing.
Decouples the upload endpoint from Celery so tests can swap in a
synchronous (inline) implementation that needs only PostgreSQL.
"""
from __future__ import annotations
from typing import Protocol
class TaskDispatcher(Protocol):
async def dispatch_file_processing(
self,
*,
document_id: int,
temp_path: str,
filename: str,
search_space_id: int,
user_id: str,
should_summarize: bool = False,
) -> None: ...
class CeleryTaskDispatcher:
"""Production dispatcher — fires Celery tasks via Redis broker."""
async def dispatch_file_processing(
self,
*,
document_id: int,
temp_path: str,
filename: str,
search_space_id: int,
user_id: str,
should_summarize: bool = False,
) -> None:
from app.tasks.celery_tasks.document_tasks import (
process_file_upload_with_document_task,
)
process_file_upload_with_document_task.delay(
document_id=document_id,
temp_path=temp_path,
filename=filename,
search_space_id=search_space_id,
user_id=user_id,
should_summarize=should_summarize,
)
async def get_task_dispatcher() -> TaskDispatcher:
return CeleryTaskDispatcher()

View file

@ -1 +1,28 @@
"""Celery tasks package."""
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.config import config
_celery_engine = None
_celery_session_maker = None
def get_celery_session_maker() -> async_sessionmaker:
"""Return a shared async session maker for Celery tasks.
A single NullPool engine is created per worker process and reused
across all task invocations to avoid leaking engine objects.
"""
global _celery_engine, _celery_session_maker
if _celery_session_maker is None:
_celery_engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool,
echo=False,
)
_celery_session_maker = async_sessionmaker(
_celery_engine, expire_on_commit=False
)
return _celery_session_maker

View file

@ -3,11 +3,8 @@
import logging
import traceback
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@ -42,20 +39,6 @@ def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> N
)
def get_celery_session_maker():
"""
Create a new async session maker for Celery tasks.
This is necessary because Celery tasks run in a new event loop,
and the default session maker is bound to the main app's event loop.
"""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool, # Don't use connection pooling for Celery tasks
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="index_slack_messages", bind=True)
def index_slack_messages_task(
self,

View file

@ -2,35 +2,20 @@
import logging
from sqlalchemy import delete, select
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import selectinload
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.db import Document
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
generate_document_summary,
)
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
def get_celery_session_maker():
"""Create async session maker for Celery tasks."""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool,
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="reindex_document", bind=True)
def reindex_document_task(self, document_id: int, user_id: str):
"""
@ -54,7 +39,6 @@ def reindex_document_task(self, document_id: int, user_id: str):
async def _reindex_document(document_id: int, user_id: str):
"""Async function to reindex a document."""
async with get_celery_session_maker()() as session:
# First, get the document to get search_space_id for logging
result = await session.execute(
select(Document)
.options(selectinload(Document.chunks))
@ -66,10 +50,8 @@ async def _reindex_document(document_id: int, user_id: str):
logger.error(f"Document {document_id} not found")
return
# Initialize task logger
task_logger = TaskLoggingService(session, document.search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="document_reindex",
source="editor",
@ -83,10 +65,7 @@ async def _reindex_document(document_id: int, user_id: str):
)
try:
# Read markdown directly from source_markdown
markdown_content = document.source_markdown
if not markdown_content:
if not document.source_markdown:
await task_logger.log_task_failure(
log_entry,
f"Document {document_id} has no source_markdown to reindex",
@ -97,51 +76,17 @@ async def _reindex_document(document_id: int, user_id: str):
logger.info(f"Reindexing document {document_id} ({document.title})")
# 1. Delete old chunks explicitly
from app.db import Chunk
await session.execute(delete(Chunk).where(Chunk.document_id == document_id))
await session.flush() # Ensure old chunks are deleted
# 2. Create new chunks from source_markdown
new_chunks = await create_document_chunks(markdown_content)
# 3. Add new chunks to session
for chunk in new_chunks:
chunk.document_id = document_id
session.add(chunk)
logger.info(f"Created {len(new_chunks)} chunks for document {document_id}")
# 4. Regenerate summary
user_llm = await get_user_long_context_llm(
session, user_id, document.search_space_id
)
document_metadata = {
"title": document.title,
"document_type": document.document_type.value,
}
adapter = UploadDocumentAdapter(session)
await adapter.reindex(document=document, llm=user_llm)
summary_content, summary_embedding = await generate_document_summary(
markdown_content, user_llm, document_metadata
)
# 5. Update document
document.content = summary_content
document.embedding = summary_embedding
document.content_needs_reindexing = False
await session.commit()
# Log success
await task_logger.log_task_success(
log_entry,
f"Successfully reindexed document: {document.title}",
{
"chunks_created": len(new_chunks),
"document_id": document_id,
},
{"document_id": document_id},
)
logger.info(f"Successfully reindexed document {document_id}")

View file

@ -5,13 +5,11 @@ import logging
import os
from uuid import UUID
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.document_processors import (
add_extension_received_document,
add_youtube_video_document,
@ -91,20 +89,6 @@ async def _run_heartbeat_loop(notification_id: int):
pass # Normal cancellation when task completes
def get_celery_session_maker():
"""
Create a new async session maker for Celery tasks.
This is necessary because Celery tasks run in a new event loop,
and the default session maker is bound to the main app's event loop.
"""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool, # Don't use connection pooling for Celery tasks
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="process_extension_document", bind=True)
def process_extension_document_task(
self, individual_document_dict, search_space_id: int, user_id: str
@ -626,6 +610,7 @@ def process_file_upload_with_document_task(
filename: str,
search_space_id: int,
user_id: str,
should_summarize: bool = False,
):
"""
Celery task to process uploaded file with existing pending document.
@ -640,6 +625,7 @@ def process_file_upload_with_document_task(
filename: Original filename
search_space_id: ID of the search space
user_id: ID of the user
should_summarize: Whether to generate an LLM summary
"""
import traceback
@ -674,7 +660,12 @@ def process_file_upload_with_document_task(
try:
loop.run_until_complete(
_process_file_with_document(
document_id, temp_path, filename, search_space_id, user_id
document_id,
temp_path,
filename,
search_space_id,
user_id,
should_summarize=should_summarize,
)
)
logger.info(
@ -710,6 +701,7 @@ async def _process_file_with_document(
filename: str,
search_space_id: int,
user_id: str,
should_summarize: bool = False,
):
"""
Process file and update existing pending document status.
@ -811,6 +803,7 @@ async def _process_file_with_document(
task_logger=task_logger,
log_entry=log_entry,
notification=notification,
should_summarize=should_summarize,
)
# Update notification on success

View file

@ -5,14 +5,13 @@ import logging
import sys
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app
from app.config import config
from app.db import Podcast, PodcastStatus
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@ -25,20 +24,6 @@ if sys.platform.startswith("win"):
)
def get_celery_session_maker():
"""
Create a new async session maker for Celery tasks.
This is necessary because Celery tasks run in a new event loop,
and the default session maker is bound to the main app's event loop.
"""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool, # Don't use connection pooling for Celery tasks
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
# =============================================================================
# Content-based podcast generation (for new-chat)
# =============================================================================

View file

@ -3,28 +3,16 @@
import logging
from datetime import UTC, datetime
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.future import select
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
from app.tasks.celery_tasks import get_celery_session_maker
from app.utils.indexing_locks import is_connector_indexing_locked
logger = logging.getLogger(__name__)
def get_celery_session_maker():
"""Create async session maker for Celery tasks."""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool,
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="check_periodic_schedules")
def check_periodic_schedules_task():
"""

View file

@ -29,20 +29,17 @@ from datetime import UTC, datetime
import redis
from sqlalchemy import and_, or_, text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.future import select
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.db import Document, DocumentStatus, Notification
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
# Redis client for checking heartbeats
_redis_client: redis.Redis | None = None
# Error messages shown to users when tasks are interrupted
STALE_SYNC_ERROR_MESSAGE = "Sync was interrupted unexpectedly. Please retry."
STALE_PROCESSING_ERROR_MESSAGE = "Syncing was interrupted unexpectedly. Please retry."
@ -60,16 +57,6 @@ def _get_heartbeat_key(notification_id: int) -> str:
return f"indexing:heartbeat:{notification_id}"
def get_celery_session_maker():
"""Create async session maker for Celery tasks."""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool,
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="cleanup_stale_indexing_notifications")
def cleanup_stale_indexing_notifications_task():
"""

View file

@ -9,17 +9,23 @@ Supports loading LLM configurations from:
- NewLLMConfig database table (positive IDs for user-created configs with prompt settings)
"""
import asyncio
import contextlib
import gc
import json
import logging
import re
import time
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
import logging
import anyio
from langchain_core.messages import HumanMessage
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import func
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
@ -30,7 +36,21 @@ from app.agents.new_chat.llm_config import (
load_agent_config,
load_llm_config_from_yaml,
)
from app.db import ChatVisibility, Document, Report, SurfsenseDocsDocument, async_session_maker
from app.agents.new_chat.sandbox import (
get_or_create_sandbox,
is_sandbox_enabled,
)
from app.db import (
ChatVisibility,
Document,
NewChatMessage,
NewChatThread,
Report,
SearchSourceConnectorType,
SurfsenseDocsDocument,
async_session_maker,
shielded_async_session,
)
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
from app.services.chat_session_state_service import (
clear_ai_responding,
@ -39,6 +59,11 @@ from app.services.chat_session_state_service import (
from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService
from app.utils.content_utils import bootstrap_history_from_db
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
_perf_log = get_perf_logger()
_background_tasks: set[asyncio.Task] = set()
def format_mentioned_documents_as_context(documents: list[Document]) -> str:
@ -187,6 +212,7 @@ class StreamResult:
accumulated_text: str = ""
is_interrupted: bool = False
interrupt_value: dict[str, Any] | None = None
sandbox_files: list[str] = field(default_factory=list)
async def _stream_agent_events(
@ -404,6 +430,21 @@ async def _stream_agent_events(
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "execute":
cmd = (
tool_input.get("command", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_cmd = cmd[:80] + ("" if len(cmd) > 80 else "")
last_active_step_title = "Running command"
last_active_step_items = [f"$ {display_cmd}"]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Running command",
status="in_progress",
items=last_active_step_items,
)
else:
last_active_step_title = f"Using {tool_name.replace('_', ' ')}"
last_active_step_items = []
@ -620,6 +661,32 @@ async def _stream_agent_events(
status="completed",
items=completed_items,
)
elif tool_name == "execute":
raw_text = (
tool_output.get("result", "")
if isinstance(tool_output, dict)
else str(tool_output)
)
m = re.match(r"^Exit code:\s*(\d+)", raw_text)
exit_code_val = int(m.group(1)) if m else None
if exit_code_val is not None and exit_code_val == 0:
completed_items = [
*last_active_step_items,
"Completed successfully",
]
elif exit_code_val is not None:
completed_items = [
*last_active_step_items,
f"Exit code: {exit_code_val}",
]
else:
completed_items = [*last_active_step_items, "Finished"]
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Running command",
status="completed",
items=completed_items,
)
elif tool_name == "ls":
if isinstance(tool_output, dict):
ls_output = tool_output.get("result", "")
@ -813,6 +880,36 @@ async def _stream_agent_events(
if isinstance(tool_output, dict)
else {"result": tool_output},
)
elif tool_name == "execute":
raw_text = (
tool_output.get("result", "")
if isinstance(tool_output, dict)
else str(tool_output)
)
exit_code: int | None = None
output_text = raw_text
m = re.match(r"^Exit code:\s*(\d+)", raw_text)
if m:
exit_code = int(m.group(1))
om = re.search(r"\nOutput:\n([\s\S]*)", raw_text)
output_text = om.group(1) if om else ""
thread_id_str = config.get("configurable", {}).get("thread_id", "")
for sf_match in re.finditer(
r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE
):
fpath = sf_match.group(1).strip()
if fpath and fpath not in result.sandbox_files:
result.sandbox_files.append(fpath)
yield streaming_service.format_tool_output_available(
tool_call_id,
{
"exit_code": exit_code,
"output": output_text,
"thread_id": thread_id_str,
},
)
else:
yield streaming_service.format_tool_output_available(
tool_call_id,
@ -881,11 +978,42 @@ async def _stream_agent_events(
yield streaming_service.format_interrupt_request(result.interrupt_value)
def _try_persist_and_delete_sandbox(
thread_id: int,
sandbox_files: list[str],
) -> None:
"""Fire-and-forget: persist sandbox files locally then delete the sandbox."""
from app.agents.new_chat.sandbox import (
is_sandbox_enabled,
persist_and_delete_sandbox,
)
if not is_sandbox_enabled():
return
async def _run() -> None:
try:
await persist_and_delete_sandbox(thread_id, sandbox_files)
except Exception:
logging.getLogger(__name__).warning(
"persist_and_delete_sandbox failed for thread %s",
thread_id,
exc_info=True,
)
try:
loop = asyncio.get_running_loop()
task = loop.create_task(_run())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
except RuntimeError:
pass
async def stream_new_chat(
user_query: str,
search_space_id: int,
chat_id: int,
session: AsyncSession,
user_id: str | None = None,
llm_config_id: int = -1,
mentioned_document_ids: list[int] | None = None,
@ -901,11 +1029,13 @@ async def stream_new_chat(
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
The chat_id is used as LangGraph's thread_id for memory/checkpointing.
The function creates and manages its own database session to guarantee proper
cleanup even when Starlette's middleware cancels the task on client disconnect.
Args:
user_query: The user's query
search_space_id: The search space ID
chat_id: The chat ID (used as LangGraph thread_id for memory)
session: The database session
user_id: The current user's UUID string (for memory tools and session state)
llm_config_id: The LLM configuration ID (default: -1 for first global config)
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
@ -917,7 +1047,11 @@ async def stream_new_chat(
str: SSE formatted response strings
"""
streaming_service = VercelStreamingService()
stream_result = StreamResult()
_t_total = time.perf_counter()
log_system_snapshot("stream_new_chat_START")
session = async_session_maker()
try:
# Mark AI as responding to this user for live collaboration
if user_id:
@ -925,6 +1059,7 @@ async def stream_new_chat(
# Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
agent_config: AgentConfig | None = None
_t0 = time.perf_counter()
if llm_config_id >= 0:
# Positive ID: Load from NewLLMConfig database table
agent_config = await load_agent_config(
@ -955,6 +1090,11 @@ async def stream_new_chat(
llm = create_chat_litellm_from_config(llm_config)
# Create AgentConfig from YAML for consistency (uses defaults for prompt settings)
agent_config = AgentConfig.from_yaml_config(llm_config)
_perf_log.info(
"[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
time.perf_counter() - _t0,
llm_config_id,
)
if not llm:
yield streaming_service.format_error("Failed to create LLM instance")
@ -962,22 +1102,45 @@ async def stream_new_chat(
return
# Create connector service
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
# Get Firecrawl API key from webcrawler connector if configured
from app.db import SearchSourceConnectorType
firecrawl_api_key = None
webcrawler_connector = await connector_service.get_connector_by_type(
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
)
if webcrawler_connector and webcrawler_connector.config:
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
_perf_log.info(
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
# Get the PostgreSQL checkpointer for persistent conversation memory
_t0 = time.perf_counter()
checkpointer = await get_checkpointer()
_perf_log.info(
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
sandbox_backend = None
_t0 = time.perf_counter()
if is_sandbox_enabled():
try:
sandbox_backend = await get_or_create_sandbox(chat_id)
except Exception as sandbox_err:
logging.getLogger(__name__).warning(
"Sandbox creation failed, continuing without execute tool: %s",
sandbox_err,
)
_perf_log.info(
"[stream_new_chat] Sandbox provisioning in %.3fs (enabled=%s)",
time.perf_counter() - _t0,
sandbox_backend is not None,
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
@ -989,20 +1152,22 @@ async def stream_new_chat(
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
sandbox_backend=sandbox_backend,
)
_perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
)
# Build input with message history
langchain_messages = []
_t0 = time.perf_counter()
# Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
if needs_history_bootstrap:
langchain_messages = await bootstrap_history_from_db(
session, chat_id, thread_visibility=visibility
)
# Clear the flag so we don't bootstrap again on next message
from app.db import NewChatThread
thread_result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == chat_id)
)
@ -1014,11 +1179,9 @@ async def stream_new_chat(
# Fetch mentioned documents if any (with chunks for proper citations)
mentioned_documents: list[Document] = []
if mentioned_document_ids:
from sqlalchemy.orm import selectinload as doc_selectinload
result = await session.execute(
select(Document)
.options(doc_selectinload(Document.chunks))
.options(selectinload(Document.chunks))
.filter(
Document.id.in_(mentioned_document_ids),
Document.search_space_id == search_space_id,
@ -1029,8 +1192,6 @@ async def stream_new_chat(
# Fetch mentioned SurfSense docs if any
mentioned_surfsense_docs: list[SurfsenseDocsDocument] = []
if mentioned_surfsense_doc_ids:
from sqlalchemy.orm import selectinload
result = await session.execute(
select(SurfsenseDocsDocument)
.options(selectinload(SurfsenseDocsDocument.chunks))
@ -1114,6 +1275,11 @@ async def stream_new_chat(
"search_space_id": search_space_id,
}
_perf_log.info(
"[stream_new_chat] History bootstrap + doc/report queries in %.3fs",
time.perf_counter() - _t0,
)
# All pre-streaming DB reads are done. Commit to release the
# transaction and its ACCESS SHARE locks so we don't block DDL
# (e.g. migrations) for the entire duration of LLM streaming.
@ -1121,6 +1287,18 @@ async def stream_new_chat(
# short-lived transactions (or use isolated sessions).
await session.commit()
# Detach heavy ORM objects (documents with chunks, reports, etc.)
# from the session identity map now that we've extracted the data
# we need. This prevents them from accumulating in memory for the
# entire duration of LLM streaming (which can be several minutes).
session.expunge_all()
_perf_log.info(
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
time.perf_counter() - _t_total,
chat_id,
)
# Configure LangGraph with thread_id for memory
# If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
configurable = {"thread_id": str(chat_id)}
@ -1182,7 +1360,14 @@ async def stream_new_chat(
items=initial_items,
)
stream_result = StreamResult()
# These ORM objects (with eagerly-loaded chunks) can be very large.
# They're only needed to build context strings already copied into
# final_query / langchain_messages — release them before streaming.
del mentioned_documents, mentioned_surfsense_docs, recent_reports
del langchain_messages, final_query
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
agent=agent,
config=config,
@ -1194,8 +1379,24 @@ async def stream_new_chat(
initial_step_title=initial_title,
initial_step_items=initial_items,
):
if not _first_event_logged:
_perf_log.info(
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
"%.3fs (total since request start) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
_first_event_logged = True
yield sse
_perf_log.info(
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
time.perf_counter() - _t_stream_start,
chat_id,
)
log_system_snapshot("stream_new_chat_END")
if stream_result.is_interrupted:
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
@ -1204,12 +1405,6 @@ async def stream_new_chat(
accumulated_text = stream_result.accumulated_text
# Generate LLM title for new chats after first response
# Check if this is the first assistant response by counting existing assistant messages
from sqlalchemy import func
from app.db import NewChatMessage, NewChatThread
assistant_count_result = await session.execute(
select(func.count(NewChatMessage.id)).filter(
NewChatMessage.thread_id == chat_id,
@ -1280,39 +1475,72 @@ async def stream_new_chat(
yield streaming_service.format_done()
finally:
# Clear AI responding state for live collaboration.
# The original session may be broken (client disconnect / CancelledError
# can corrupt the underlying DB connection), so we try a rollback first
# and fall back to a fresh session if the original is unusable.
try:
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
# Shield the ENTIRE async cleanup from anyio cancel-scope
# cancellation. Starlette's BaseHTTPMiddleware uses anyio task
# groups; on client disconnect, it cancels the scope with
# level-triggered cancellation — every unshielded `await` inside
# the cancelled scope raises CancelledError immediately. Without
# this shield the very first `await` (session.rollback) would
# raise CancelledError, `except Exception` wouldn't catch it
# (CancelledError is a BaseException), and the rest of the
# finally block — including session.close() — would never run.
with anyio.CancelScope(shield=True):
try:
async with async_session_maker() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
try:
async with shielded_async_session() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
with contextlib.suppress(Exception):
session.expunge_all()
with contextlib.suppress(Exception):
await session.close()
# Break circular refs held by the agent graph, tools, and LLM
# wrappers so the GC can reclaim them in a single pass.
agent = llm = connector_service = sandbox_backend = None
input_state = stream_result = None
session = None
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
if collected:
_perf_log.info(
"[stream_new_chat] gc.collect() reclaimed %d objects (chat_id=%s)",
collected,
chat_id,
)
trim_native_heap()
log_system_snapshot("stream_new_chat_END")
async def stream_resume_chat(
chat_id: int,
search_space_id: int,
decisions: list[dict],
session: AsyncSession,
user_id: str | None = None,
llm_config_id: int = -1,
thread_visibility: ChatVisibility | None = None,
) -> AsyncGenerator[str, None]:
streaming_service = VercelStreamingService()
stream_result = StreamResult()
_t_total = time.perf_counter()
session = async_session_maker()
try:
if user_id:
await set_ai_responding(session, chat_id, UUID(user_id))
agent_config: AgentConfig | None = None
_t0 = time.perf_counter()
if llm_config_id >= 0:
agent_config = await load_agent_config(
session=session,
@ -1336,26 +1564,54 @@ async def stream_resume_chat(
return
llm = create_chat_litellm_from_config(llm_config)
agent_config = AgentConfig.from_yaml_config(llm_config)
_perf_log.info(
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
)
if not llm:
yield streaming_service.format_error("Failed to create LLM instance")
yield streaming_service.format_done()
return
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
from app.db import SearchSourceConnectorType
firecrawl_api_key = None
webcrawler_connector = await connector_service.get_connector_by_type(
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
)
if webcrawler_connector and webcrawler_connector.config:
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
_perf_log.info(
"[stream_resume] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
_t0 = time.perf_counter()
checkpointer = await get_checkpointer()
_perf_log.info(
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
sandbox_backend = None
_t0 = time.perf_counter()
if is_sandbox_enabled():
try:
sandbox_backend = await get_or_create_sandbox(chat_id)
except Exception as sandbox_err:
logging.getLogger(__name__).warning(
"Sandbox creation failed, continuing without execute tool: %s",
sandbox_err,
)
_perf_log.info(
"[stream_resume] Sandbox provisioning in %.3fs (enabled=%s)",
time.perf_counter() - _t0,
sandbox_backend is not None,
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
@ -1367,10 +1623,21 @@ async def stream_resume_chat(
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
sandbox_backend=sandbox_backend,
)
_perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
)
# Release the transaction before streaming (same rationale as stream_new_chat).
await session.commit()
session.expunge_all()
_perf_log.info(
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
time.perf_counter() - _t_total,
chat_id,
)
from langgraph.types import Command
@ -1382,7 +1649,8 @@ async def stream_resume_chat(
yield streaming_service.format_message_start()
yield streaming_service.format_start_step()
stream_result = StreamResult()
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
agent=agent,
config=config,
@ -1391,7 +1659,20 @@ async def stream_resume_chat(
result=stream_result,
step_prefix="thinking-resume",
):
if not _first_event_logged:
_perf_log.info(
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
_first_event_logged = True
yield sse
_perf_log.info(
"[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
time.perf_counter() - _t_stream_start,
chat_id,
)
if stream_result.is_interrupted:
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
@ -1414,14 +1695,37 @@ async def stream_resume_chat(
yield streaming_service.format_done()
finally:
try:
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
with anyio.CancelScope(shield=True):
try:
async with async_session_maker() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
try:
async with shielded_async_session() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
with contextlib.suppress(Exception):
session.expunge_all()
with contextlib.suppress(Exception):
await session.close()
agent = llm = connector_service = sandbox_backend = None
stream_result = None
session = None
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
if collected:
_perf_log.info(
"[stream_resume] gc.collect() reclaimed %d objects (chat_id=%s)",
collected,
chat_id,
)
trim_native_heap()
log_system_snapshot("stream_resume_chat_END")

View file

@ -12,13 +12,13 @@ from collections.abc import Awaitable, Callable
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.airtable_history import AirtableHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -399,7 +399,7 @@ async def index_airtable_records(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"record_id": item["record_id"],
"created_time": item["record"].get("CREATED_TIME()", ""),
@ -415,11 +415,8 @@ async def index_airtable_records(
document_metadata_for_summary,
)
else:
# Fallback to simple summary if no LLM configured
summary_content = f"Airtable Record: {item['record_id']}\n\n"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Airtable Record: {item['record_id']}\n\n{item['markdown_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])

View file

@ -13,13 +13,13 @@ from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.bookstack_connector import BookStackConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -403,7 +403,7 @@ async def index_bookstack_pages(
"connector_id": connector_id,
}
if user_llm:
if user_llm and connector.enable_summary:
summary_metadata = {
"page_name": item["page_name"],
"page_id": item["page_id"],
@ -418,17 +418,8 @@ async def index_bookstack_pages(
item["full_content"], user_llm, summary_metadata
)
else:
# Fallback to simple summary if no LLM configured
summary_content = f"BookStack Page: {item['page_name']}\n\nBook ID: {item['book_id']}\n\n"
if item["page_content"]:
# Take first 1000 characters of content for summary
content_preview = item["page_content"][:1000]
if len(item["page_content"]) > 1000:
content_preview += "..."
summary_content += f"Content Preview: {content_preview}\n\n"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"BookStack Page: {item['page_name']}\n\nBook ID: {item['book_id']}\n\n{item['full_content']}"
summary_embedding = embed_text(summary_content)
# Process chunks - using the full page content
chunks = await create_document_chunks(item["full_content"])

View file

@ -14,13 +14,13 @@ from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.clickup_history import ClickUpHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -398,7 +398,7 @@ async def index_clickup_tasks(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"task_id": item["task_id"],
"task_name": item["task_name"],
@ -418,9 +418,7 @@ async def index_clickup_tasks(
)
else:
summary_content = item["task_content"]
summary_embedding = config.embedding_model_instance.embed(
item["task_content"]
)
summary_embedding = embed_text(item["task_content"])
chunks = await create_document_chunks(item["task_content"])

View file

@ -14,13 +14,13 @@ from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -378,7 +378,7 @@ async def index_confluence_pages(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata = {
"page_title": item["page_title"],
"page_id": item["page_id"],
@ -394,18 +394,8 @@ async def index_confluence_pages(
item["full_content"], user_llm, document_metadata
)
else:
# Fallback to simple summary if no LLM configured
summary_content = f"Confluence Page: {item['page_title']}\n\nSpace ID: {item['space_id']}\n\n"
if item["page_content"]:
# Take first 1000 characters of content for summary
content_preview = item["page_content"][:1000]
if len(item["page_content"]) > 1000:
content_preview += "..."
summary_content += f"Content Preview: {content_preview}\n\n"
summary_content += f"Comments: {item['comment_count']}"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Confluence Page: {item['page_title']}\n\nSpace ID: {item['space_id']}\n\n{item['full_content']}"
summary_embedding = embed_text(summary_content)
# Process chunks - using the full page content with comments
chunks = await create_document_chunks(item["full_content"])

View file

@ -23,6 +23,7 @@ from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnector
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_unique_identifier_hash,
)
@ -669,9 +670,7 @@ async def index_discord_messages(
# Heavy processing (embeddings, chunks)
chunks = await create_document_chunks(item["combined_document_string"])
doc_embedding = config.embedding_model_instance.embed(
item["combined_document_string"]
)
doc_embedding = embed_text(item["combined_document_string"])
# Update document to READY with actual content
document.title = f"{item['guild_name']}#{item['channel_name']}"

View file

@ -16,13 +16,13 @@ from datetime import UTC, datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.github_connector import GitHubConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -367,7 +367,7 @@ async def index_github_repos(
"estimated_tokens": digest.estimated_tokens,
}
if user_llm:
if user_llm and connector.enable_summary:
# Prepare content for summarization
summary_content = digest.full_digest
if len(summary_content) > MAX_DIGEST_CHARS:
@ -381,15 +381,12 @@ async def index_github_repos(
summary_content, user_llm, document_metadata_for_summary
)
else:
# Fallback to simple summary if no LLM configured
summary_text = (
f"# GitHub Repository: {repo_full_name}\n\n"
f"## Summary\n{digest.summary}\n\n"
f"## File Structure\n{digest.tree[:3000]}"
)
summary_embedding = config.embedding_model_instance.embed(
summary_text
f"## File Structure\n{digest.tree}"
)
summary_embedding = embed_text(summary_text)
# Chunk the full digest content for granular search
try:
@ -551,7 +548,7 @@ async def _simple_chunk_content(content: str, chunk_size: int = 4000) -> list:
chunks.append(
Chunk(
content=chunk_text,
embedding=config.embedding_model_instance.embed(chunk_text),
embedding=embed_text(chunk_text),
)
)

View file

@ -20,6 +20,7 @@ from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -489,7 +490,7 @@ async def index_google_calendar_events(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"event_id": item["event_id"],
"event_summary": item["event_summary"],
@ -507,22 +508,8 @@ async def index_google_calendar_events(
item["event_markdown"], user_llm, document_metadata_for_summary
)
else:
summary_content = (
f"Google Calendar Event: {item['event_summary']}\n\n"
)
summary_content += f"Calendar: {item['calendar_id']}\n"
summary_content += f"Start: {item['start_time']}\n"
summary_content += f"End: {item['end_time']}\n"
if item["location"]:
summary_content += f"Location: {item['location']}\n"
if item["description"]:
desc_preview = item["description"][:1000]
if len(item["description"]) > 1000:
desc_preview += "..."
summary_content += f"Description: {desc_preview}\n"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Google Calendar Event: {item['event_summary']}\n\n{item['event_markdown']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["event_markdown"])

Some files were not shown because too many files have changed in this diff Show more