mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-29 19:35:20 +02:00
Merge pull request #1 from AnishSarkar22/feat/test-ci
testing backend test CI workflow
This commit is contained in:
commit
f53759f0e1
220 changed files with 16886 additions and 6950 deletions
112
.cursor/skills/tdd/SKILL.md
Normal file
112
.cursor/skills/tdd/SKILL.md
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
---
|
||||
name: tdd
|
||||
description: Strict Python TDD workflow using pytest (Red-Green-Refactor).
|
||||
---
|
||||
|
||||
---
|
||||
name: tdd
|
||||
description: Test-driven development with red-green-refactor loop. Use when user wants to build features or fix bugs using TDD, mentions "red-green-refactor", wants integration tests, or asks for test-first development.
|
||||
---
|
||||
|
||||
# Test-Driven Development
|
||||
|
||||
## Philosophy
|
||||
|
||||
**Core principle**: Tests should verify behavior through public interfaces, not implementation details. Code can change entirely; tests shouldn't.
|
||||
|
||||
**Good tests** are integration-style: they exercise real code paths through public APIs. They describe _what_ the system does, not _how_ it does it. A good test reads like a specification - "user can checkout with valid cart" tells you exactly what capability exists. These tests survive refactors because they don't care about internal structure.
|
||||
|
||||
**Bad tests** are coupled to implementation. They mock internal collaborators, test private methods, or verify through external means (like querying a database directly instead of using the interface). The warning sign: your test breaks when you refactor, but behavior hasn't changed. If you rename an internal function and tests fail, those tests were testing implementation, not behavior.
|
||||
|
||||
See [tests.md](tests.md) for examples and [mocking.md](mocking.md) for mocking guidelines.
|
||||
|
||||
## Anti-Pattern: Horizontal Slices
|
||||
|
||||
**DO NOT write all tests first, then all implementation.** This is "horizontal slicing" - treating RED as "write all tests" and GREEN as "write all code."
|
||||
|
||||
This produces **crap tests**:
|
||||
|
||||
- Tests written in bulk test _imagined_ behavior, not _actual_ behavior
|
||||
- You end up testing the _shape_ of things (data structures, function signatures) rather than user-facing behavior
|
||||
- Tests become insensitive to real changes - they pass when behavior breaks, fail when behavior is fine
|
||||
- You outrun your headlights, committing to test structure before understanding the implementation
|
||||
|
||||
**Correct approach**: Vertical slices via tracer bullets. One test → one implementation → repeat. Each test responds to what you learned from the previous cycle. Because you just wrote the code, you know exactly what behavior matters and how to verify it.
|
||||
|
||||
```
|
||||
WRONG (horizontal):
|
||||
RED: test1, test2, test3, test4, test5
|
||||
GREEN: impl1, impl2, impl3, impl4, impl5
|
||||
|
||||
RIGHT (vertical):
|
||||
RED→GREEN: test1→impl1
|
||||
RED→GREEN: test2→impl2
|
||||
RED→GREEN: test3→impl3
|
||||
...
|
||||
```
|
||||
|
||||
## Workflow
|
||||
|
||||
### 1. Planning
|
||||
|
||||
Before writing any code:
|
||||
|
||||
- [ ] Confirm with user what interface changes are needed
|
||||
- [ ] Confirm with user which behaviors to test (prioritize)
|
||||
- [ ] Identify opportunities for [deep modules](deep-modules.md) (small interface, deep implementation)
|
||||
- [ ] Design interfaces for [testability](interface-design.md)
|
||||
- [ ] List the behaviors to test (not implementation steps)
|
||||
- [ ] Get user approval on the plan
|
||||
|
||||
Ask: "What should the public interface look like? Which behaviors are most important to test?"
|
||||
|
||||
**You can't test everything.** Confirm with the user exactly which behaviors matter most. Focus testing effort on critical paths and complex logic, not every possible edge case.
|
||||
|
||||
### 2. Tracer Bullet
|
||||
|
||||
Write ONE test that confirms ONE thing about the system:
|
||||
|
||||
```
|
||||
RED: Write test for first behavior → test fails
|
||||
GREEN: Write minimal code to pass → test passes
|
||||
```
|
||||
|
||||
This is your tracer bullet - proves the path works end-to-end.
|
||||
|
||||
### 3. Incremental Loop
|
||||
|
||||
For each remaining behavior:
|
||||
|
||||
```
|
||||
RED: Write next test → fails
|
||||
GREEN: Minimal code to pass → passes
|
||||
```
|
||||
|
||||
Rules:
|
||||
|
||||
- One test at a time
|
||||
- Only enough code to pass current test
|
||||
- Don't anticipate future tests
|
||||
- Keep tests focused on observable behavior
|
||||
|
||||
### 4. Refactor
|
||||
|
||||
After all tests pass, look for [refactor candidates](refactoring.md):
|
||||
|
||||
- [ ] Extract duplication
|
||||
- [ ] Deepen modules (move complexity behind simple interfaces)
|
||||
- [ ] Apply SOLID principles where natural
|
||||
- [ ] Consider what new code reveals about existing code
|
||||
- [ ] Run tests after each refactor step
|
||||
|
||||
**Never refactor while RED.** Get to GREEN first.
|
||||
|
||||
## Checklist Per Cycle
|
||||
|
||||
```
|
||||
[ ] Test describes behavior, not implementation
|
||||
[ ] Test uses public interface only
|
||||
[ ] Test would survive internal refactor
|
||||
[ ] Code is minimal for this test
|
||||
[ ] No speculative features added
|
||||
```
|
||||
33
.cursor/skills/tdd/deep-modules.md
Normal file
33
.cursor/skills/tdd/deep-modules.md
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
# Deep Modules
|
||||
|
||||
From "A Philosophy of Software Design":
|
||||
|
||||
**Deep module** = small interface + lots of implementation
|
||||
|
||||
```
|
||||
┌─────────────────────┐
|
||||
│ Small Interface │ ← Few methods, simple params
|
||||
├─────────────────────┤
|
||||
│ │
|
||||
│ │
|
||||
│ Deep Implementation│ ← Complex logic hidden
|
||||
│ │
|
||||
│ │
|
||||
└─────────────────────┘
|
||||
```
|
||||
|
||||
**Shallow module** = large interface + little implementation (avoid)
|
||||
|
||||
```
|
||||
┌─────────────────────────────────┐
|
||||
│ Large Interface │ ← Many methods, complex params
|
||||
├─────────────────────────────────┤
|
||||
│ Thin Implementation │ ← Just passes through
|
||||
└─────────────────────────────────┘
|
||||
```
|
||||
|
||||
When designing interfaces, ask:
|
||||
|
||||
- Can I reduce the number of methods?
|
||||
- Can I simplify the parameters?
|
||||
- Can I hide more complexity inside?
|
||||
33
.cursor/skills/tdd/interface-design.md
Normal file
33
.cursor/skills/tdd/interface-design.md
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
# Interface Design for Testability
|
||||
|
||||
Good interfaces make testing natural:
|
||||
|
||||
1. **Accept dependencies, don't create them**
|
||||
```python
|
||||
# Testable
|
||||
def process_order(order, payment_gateway):
|
||||
pass
|
||||
|
||||
# Hard to test
|
||||
def process_order(order):
|
||||
gateway = StripeGateway()
|
||||
|
||||
```
|
||||
|
||||
|
||||
2. **Return results, don't produce side effects**
|
||||
```python
|
||||
# Testable
|
||||
def calculate_discount(cart) -> float:
|
||||
return discount
|
||||
|
||||
# Hard to test
|
||||
def apply_discount(cart) -> None:
|
||||
cart.total -= discount
|
||||
|
||||
```
|
||||
|
||||
|
||||
3. **Small surface area**
|
||||
* Fewer methods = fewer tests needed
|
||||
* Fewer params = simpler test setup
|
||||
69
.cursor/skills/tdd/mocking.md
Normal file
69
.cursor/skills/tdd/mocking.md
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
|
||||
# When to Mock
|
||||
|
||||
Mock at **system boundaries** only:
|
||||
|
||||
* External APIs (payment, email, etc.)
|
||||
* Databases (sometimes - prefer test DB)
|
||||
* Time/randomness
|
||||
* File system (sometimes)
|
||||
|
||||
Don't mock:
|
||||
|
||||
* Your own classes/modules
|
||||
* Internal collaborators
|
||||
* Anything you control
|
||||
|
||||
## Designing for Mockability
|
||||
|
||||
At system boundaries, design interfaces that are easy to mock:
|
||||
|
||||
**1. Use dependency injection**
|
||||
|
||||
Pass external dependencies in rather than creating them internally:
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
# Easy to mock
|
||||
def process_payment(order, payment_client):
|
||||
return payment_client.charge(order.total)
|
||||
|
||||
# Hard to mock
|
||||
def process_payment(order):
|
||||
client = StripeClient(os.getenv("STRIPE_KEY"))
|
||||
return client.charge(order.total)
|
||||
|
||||
```
|
||||
|
||||
**2. Prefer SDK-style interfaces over generic fetchers**
|
||||
|
||||
Create specific functions for each external operation instead of one generic function with conditional logic:
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
# GOOD: Each function is independently mockable
|
||||
class UserAPI:
|
||||
def get_user(self, user_id):
|
||||
return requests.get(f"/users/{user_id}")
|
||||
|
||||
def get_orders(self, user_id):
|
||||
return requests.get(f"/users/{user_id}/orders")
|
||||
|
||||
def create_order(self, data):
|
||||
return requests.post("/orders", json=data)
|
||||
|
||||
# BAD: Mocking requires conditional logic inside the mock
|
||||
class GenericAPI:
|
||||
def fetch(self, endpoint, method="GET", data=None):
|
||||
return requests.request(method, endpoint, json=data)
|
||||
|
||||
```
|
||||
|
||||
The SDK approach means:
|
||||
|
||||
* Each mock returns one specific shape
|
||||
* No conditional logic in test setup
|
||||
* Easier to see which endpoints a test exercises
|
||||
* Type safety per endpoint
|
||||
10
.cursor/skills/tdd/refactoring.md
Normal file
10
.cursor/skills/tdd/refactoring.md
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
# Refactor Candidates
|
||||
|
||||
After TDD cycle, look for:
|
||||
|
||||
- **Duplication** → Extract function/class
|
||||
- **Long methods** → Break into private helpers (keep tests on public interface)
|
||||
- **Shallow modules** → Combine or deepen
|
||||
- **Feature envy** → Move logic to where data lives
|
||||
- **Primitive obsession** → Introduce value objects
|
||||
- **Existing code** the new code reveals as problematic
|
||||
60
.cursor/skills/tdd/tests.md
Normal file
60
.cursor/skills/tdd/tests.md
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
# Good and Bad Tests
|
||||
|
||||
## Good Tests
|
||||
|
||||
**Integration-style**: Test through real interfaces, not mocks of internal parts.
|
||||
|
||||
```python
|
||||
# GOOD: Tests observable behavior
|
||||
def test_user_can_checkout_with_valid_cart():
|
||||
cart = create_cart()
|
||||
cart.add(product)
|
||||
result = checkout(cart, payment_method)
|
||||
assert result.status == "confirmed"
|
||||
|
||||
```
|
||||
|
||||
Characteristics:
|
||||
|
||||
* Tests behavior users/callers care about
|
||||
* Uses public API only
|
||||
* Survives internal refactors
|
||||
* Describes WHAT, not HOW
|
||||
* One logical assertion per test
|
||||
|
||||
## Bad Tests
|
||||
|
||||
**Implementation-detail tests**: Coupled to internal structure.
|
||||
|
||||
```python
|
||||
# BAD: Tests implementation details
|
||||
def test_checkout_calls_payment_service_process():
|
||||
mock_payment = MagicMock()
|
||||
checkout(cart, mock_payment)
|
||||
mock_payment.process.assert_called_with(cart.total)
|
||||
|
||||
```
|
||||
|
||||
Red flags:
|
||||
|
||||
* Mocking internal collaborators
|
||||
* Testing private methods
|
||||
* Asserting on call counts/order
|
||||
* Test breaks when refactoring without behavior change
|
||||
* Test name describes HOW not WHAT
|
||||
* Verifying through external means instead of interface
|
||||
|
||||
```python
|
||||
# BAD: Bypasses interface to verify
|
||||
def test_create_user_saves_to_database():
|
||||
create_user({"name": "Alice"})
|
||||
row = db.query("SELECT * FROM users WHERE name = ?", ["Alice"])
|
||||
assert row is not None
|
||||
|
||||
# GOOD: Verifies through interface
|
||||
def test_create_user_makes_user_retrievable():
|
||||
user = create_user({"name": "Alice"})
|
||||
retrieved = get_user(user.id)
|
||||
assert retrieved.name == "Alice"
|
||||
|
||||
```
|
||||
41
.env.example
41
.env.example
|
|
@ -1,41 +0,0 @@
|
|||
# Docker Specific Env's Only - Can skip if needed
|
||||
|
||||
# Celery Config
|
||||
REDIS_PORT=6379
|
||||
FLOWER_PORT=5555
|
||||
|
||||
# Frontend Configuration
|
||||
FRONTEND_PORT=3000
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000 (Default: http://localhost:8000)
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL or GOOGLE (Default: LOCAL)
|
||||
NEXT_PUBLIC_ETL_SERVICE=UNSTRUCTURED or LLAMACLOUD or DOCLING (Default: DOCLING)
|
||||
# Backend Configuration
|
||||
BACKEND_PORT=8000
|
||||
# Auth type for backend login flow (Default: LOCAL)
|
||||
# Set to GOOGLE if using Google OAuth
|
||||
AUTH_TYPE=LOCAL
|
||||
# Frontend URL used by backend for CORS allowed origins and OAuth redirects
|
||||
# Must match the URL your browser uses to access the frontend
|
||||
NEXT_FRONTEND_URL=http://localhost:3000
|
||||
|
||||
# Database Configuration
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=postgres
|
||||
POSTGRES_DB=surfsense
|
||||
POSTGRES_PORT=5432
|
||||
|
||||
# Electric-SQL Configuration
|
||||
ELECTRIC_PORT=5133
|
||||
# PostgreSQL host for Electric connection
|
||||
# - 'db' for Docker PostgreSQL (service name in docker-compose)
|
||||
# - 'host.docker.internal' for local PostgreSQL (recommended when Electric runs in Docker)
|
||||
# Note: host.docker.internal works on Docker Desktop (Mac/Windows) and can be enabled on Linux
|
||||
POSTGRES_HOST=db
|
||||
ELECTRIC_DB_USER=electric
|
||||
ELECTRIC_DB_PASSWORD=electric_password
|
||||
NEXT_PUBLIC_ELECTRIC_URL=http://localhost:5133
|
||||
|
||||
# pgAdmin Configuration
|
||||
PGADMIN_PORT=5050
|
||||
PGADMIN_DEFAULT_EMAIL=admin@surfsense.com
|
||||
PGADMIN_DEFAULT_PASSWORD=surfsense
|
||||
161
.github/workflows/backend-tests.yml
vendored
Normal file
161
.github/workflows/backend-tests.yml
vendored
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
name: Backend Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main, dev]
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
unit-tests:
|
||||
name: Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.pull_request.draft == false
|
||||
env:
|
||||
EMBEDDING_MODEL: sentence-transformers/all-MiniLM-L6-v2
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Check if backend files changed
|
||||
id: backend-changes
|
||||
uses: dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
backend:
|
||||
- 'surfsense_backend/**'
|
||||
|
||||
- name: Set up Python
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install UV
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Cache dependencies
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: |
|
||||
~/.cache/uv
|
||||
surfsense_backend/.venv
|
||||
key: python-deps-${{ hashFiles('surfsense_backend/uv.lock') }}
|
||||
restore-keys: |
|
||||
python-deps-
|
||||
|
||||
- name: Cache HuggingFace models
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/huggingface
|
||||
key: hf-models-${{ env.EMBEDDING_MODEL }}
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
working-directory: surfsense_backend
|
||||
run: uv sync
|
||||
|
||||
- name: Run unit tests
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
working-directory: surfsense_backend
|
||||
run: uv run pytest -m unit
|
||||
|
||||
integration-tests:
|
||||
name: Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.pull_request.draft == false
|
||||
env:
|
||||
EMBEDDING_MODEL: sentence-transformers/all-MiniLM-L6-v2
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: pgvector/pgvector:pg17
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: surfsense_test
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd "pg_isready -U postgres -d surfsense_test"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Check if backend files changed
|
||||
id: backend-changes
|
||||
uses: dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
backend:
|
||||
- 'surfsense_backend/**'
|
||||
|
||||
- name: Set up Python
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install UV
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Cache dependencies
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: |
|
||||
~/.cache/uv
|
||||
surfsense_backend/.venv
|
||||
key: python-deps-${{ hashFiles('surfsense_backend/uv.lock') }}
|
||||
restore-keys: |
|
||||
python-deps-
|
||||
|
||||
- name: Cache HuggingFace models
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/huggingface
|
||||
key: hf-models-${{ env.EMBEDDING_MODEL }}
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
working-directory: surfsense_backend
|
||||
run: uv sync
|
||||
|
||||
- name: Run integration tests
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
working-directory: surfsense_backend
|
||||
env:
|
||||
TEST_DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test
|
||||
SECRET_KEY: ci-test-secret-key-not-for-production
|
||||
ETL_SERVICE: DOCLING
|
||||
run: uv run pytest -m integration
|
||||
|
||||
test-gate:
|
||||
name: Test Gate
|
||||
runs-on: ubuntu-latest
|
||||
needs: [unit-tests, integration-tests]
|
||||
if: always()
|
||||
|
||||
steps:
|
||||
- name: Check all test jobs
|
||||
run: |
|
||||
if [[ "${{ needs.unit-tests.result }}" == "failure" ||
|
||||
"${{ needs.integration-tests.result }}" == "failure" ]]; then
|
||||
echo "Backend tests failed"
|
||||
exit 1
|
||||
else
|
||||
echo "All backend tests passed"
|
||||
fi
|
||||
170
.github/workflows/docker_build.yaml
vendored
170
.github/workflows/docker_build.yaml
vendored
|
|
@ -1,6 +1,13 @@
|
|||
name: Build and Push Docker Image
|
||||
name: Build and Push Docker Images
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
paths:
|
||||
- 'surfsense_backend/**'
|
||||
- 'surfsense_web/**'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
branch:
|
||||
|
|
@ -8,6 +15,10 @@ on:
|
|||
required: false
|
||||
default: ''
|
||||
|
||||
concurrency:
|
||||
group: docker-build
|
||||
cancel-in-progress: false
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
|
|
@ -28,33 +39,28 @@ jobs:
|
|||
- name: Read app version and calculate next Docker build version
|
||||
id: tag_version
|
||||
run: |
|
||||
# Read version from pyproject.toml
|
||||
APP_VERSION=$(grep -E '^version = ' surfsense_backend/pyproject.toml | sed 's/version = "\(.*\)"/\1/')
|
||||
echo "App version from pyproject.toml: $APP_VERSION"
|
||||
|
||||
|
||||
if [ -z "$APP_VERSION" ]; then
|
||||
echo "Error: Could not read version from surfsense_backend/pyproject.toml"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Fetch all tags
|
||||
|
||||
git fetch --tags
|
||||
|
||||
# Find the latest docker build tag for this app version (format: APP_VERSION.BUILD_NUMBER)
|
||||
# Tags follow pattern: 0.0.11.1, 0.0.11.2, etc.
|
||||
|
||||
LATEST_BUILD_TAG=$(git tag --list "${APP_VERSION}.*" --sort='-v:refname' | head -n 1)
|
||||
|
||||
|
||||
if [ -z "$LATEST_BUILD_TAG" ]; then
|
||||
echo "No previous Docker build tag found for version ${APP_VERSION}. Starting with ${APP_VERSION}.1"
|
||||
NEXT_VERSION="${APP_VERSION}.1"
|
||||
else
|
||||
echo "Latest Docker build tag found: $LATEST_BUILD_TAG"
|
||||
# Extract the build number (4th component)
|
||||
BUILD_NUMBER=$(echo "$LATEST_BUILD_TAG" | rev | cut -d. -f1 | rev)
|
||||
NEXT_BUILD=$((BUILD_NUMBER + 1))
|
||||
NEXT_VERSION="${APP_VERSION}.${NEXT_BUILD}"
|
||||
fi
|
||||
|
||||
|
||||
echo "Calculated next Docker version: $NEXT_VERSION"
|
||||
echo "next_version=$NEXT_VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
|
|
@ -78,67 +84,35 @@ jobs:
|
|||
git ls-remote --tags origin | grep "refs/tags/${{ steps.tag_version.outputs.next_version }}" || (echo "Tag push verification failed!" && exit 1)
|
||||
echo "Tag successfully pushed."
|
||||
|
||||
# Build for AMD64 on native x64 runner
|
||||
build_amd64:
|
||||
runs-on: ubuntu-latest
|
||||
build:
|
||||
needs: tag_release
|
||||
runs-on: ${{ matrix.os }}
|
||||
permissions:
|
||||
packages: write
|
||||
contents: read
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform: [linux/amd64, linux/arm64]
|
||||
image: [backend, web]
|
||||
include:
|
||||
- platform: linux/amd64
|
||||
suffix: amd64
|
||||
os: ubuntu-latest
|
||||
- platform: linux/arm64
|
||||
suffix: arm64
|
||||
os: ubuntu-24.04-arm
|
||||
- image: backend
|
||||
name: surfsense-backend
|
||||
context: ./surfsense_backend
|
||||
file: ./surfsense_backend/Dockerfile
|
||||
- image: web
|
||||
name: surfsense-web
|
||||
context: ./surfsense_web
|
||||
file: ./surfsense_web/Dockerfile
|
||||
env:
|
||||
REGISTRY_IMAGE: ghcr.io/${{ github.repository_owner }}/surfsense
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
REGISTRY_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ matrix.name }}
|
||||
|
||||
- name: Set lowercase image name
|
||||
id: image
|
||||
run: echo "name=${REGISTRY_IMAGE,,}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /usr/local/share/boost
|
||||
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
|
||||
docker system prune -af
|
||||
|
||||
- name: Build and push AMD64 image
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile.allinone
|
||||
push: true
|
||||
tags: ${{ steps.image.outputs.name }}:${{ needs.tag_release.outputs.new_tag }}-amd64
|
||||
platforms: linux/amd64
|
||||
cache-from: type=gha,scope=amd64
|
||||
cache-to: type=gha,mode=max,scope=amd64
|
||||
provenance: false
|
||||
|
||||
# Build for ARM64 on native arm64 runner (no QEMU emulation!)
|
||||
build_arm64:
|
||||
runs-on: ubuntu-24.04-arm
|
||||
needs: tag_release
|
||||
permissions:
|
||||
packages: write
|
||||
contents: read
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: ghcr.io/${{ github.repository_owner }}/surfsense
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
|
@ -165,28 +139,41 @@ jobs:
|
|||
sudo rm -rf "$AGENT_TOOLSDIRECTORY" || true
|
||||
docker system prune -af
|
||||
|
||||
- name: Build and push ARM64 image
|
||||
- name: Build and push ${{ matrix.name }} (${{ matrix.suffix }})
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile.allinone
|
||||
context: ${{ matrix.context }}
|
||||
file: ${{ matrix.file }}
|
||||
push: true
|
||||
tags: ${{ steps.image.outputs.name }}:${{ needs.tag_release.outputs.new_tag }}-arm64
|
||||
platforms: linux/arm64
|
||||
cache-from: type=gha,scope=arm64
|
||||
cache-to: type=gha,mode=max,scope=arm64
|
||||
tags: ${{ steps.image.outputs.name }}:${{ needs.tag_release.outputs.new_tag }}-${{ matrix.suffix }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
cache-from: type=gha,scope=${{ matrix.image }}-${{ matrix.suffix }}
|
||||
cache-to: type=gha,mode=max,scope=${{ matrix.image }}-${{ matrix.suffix }}
|
||||
provenance: false
|
||||
build-args: |
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ELECTRIC_URL=__NEXT_PUBLIC_ELECTRIC_URL__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ELECTRIC_AUTH_MODE=__NEXT_PUBLIC_ELECTRIC_AUTH_MODE__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_DEPLOYMENT_MODE=__NEXT_PUBLIC_DEPLOYMENT_MODE__' || '' }}
|
||||
|
||||
# Create multi-arch manifest combining both platform images
|
||||
create_manifest:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [tag_release, build_amd64, build_arm64]
|
||||
needs: [tag_release, build]
|
||||
permissions:
|
||||
packages: write
|
||||
contents: read
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- name: surfsense-backend
|
||||
- name: surfsense-web
|
||||
env:
|
||||
REGISTRY_IMAGE: ghcr.io/${{ github.repository_owner }}/surfsense
|
||||
REGISTRY_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ matrix.name }}
|
||||
|
||||
steps:
|
||||
- name: Set lowercase image name
|
||||
id: image
|
||||
|
|
@ -203,28 +190,31 @@ jobs:
|
|||
run: |
|
||||
VERSION_TAG="${{ needs.tag_release.outputs.new_tag }}"
|
||||
IMAGE="${{ steps.image.outputs.name }}"
|
||||
|
||||
# Create manifest for version tag
|
||||
APP_VERSION=$(echo "$VERSION_TAG" | rev | cut -d. -f2- | rev)
|
||||
|
||||
docker manifest create ${IMAGE}:${VERSION_TAG} \
|
||||
${IMAGE}:${VERSION_TAG}-amd64 \
|
||||
${IMAGE}:${VERSION_TAG}-arm64
|
||||
|
||||
|
||||
docker manifest push ${IMAGE}:${VERSION_TAG}
|
||||
|
||||
# Create/update latest tag if on default branch
|
||||
|
||||
if [[ "${{ github.ref }}" == "refs/heads/${{ github.event.repository.default_branch }}" ]] || [[ "${{ github.event.inputs.branch }}" == "${{ github.event.repository.default_branch }}" ]]; then
|
||||
docker manifest create ${IMAGE}:${APP_VERSION} \
|
||||
${IMAGE}:${VERSION_TAG}-amd64 \
|
||||
${IMAGE}:${VERSION_TAG}-arm64
|
||||
|
||||
docker manifest push ${IMAGE}:${APP_VERSION}
|
||||
|
||||
docker manifest create ${IMAGE}:latest \
|
||||
${IMAGE}:${VERSION_TAG}-amd64 \
|
||||
${IMAGE}:${VERSION_TAG}-arm64
|
||||
|
||||
|
||||
docker manifest push ${IMAGE}:latest
|
||||
fi
|
||||
|
||||
- name: Clean up architecture-specific tags (optional)
|
||||
continue-on-error: true
|
||||
- name: Summary
|
||||
run: |
|
||||
# Note: GHCR doesn't support tag deletion via API easily
|
||||
# The arch-specific tags will remain but users should use the main tags
|
||||
echo "Multi-arch manifest created successfully!"
|
||||
echo "Users should pull: ${{ steps.image.outputs.name }}:${{ needs.tag_release.outputs.new_tag }}"
|
||||
echo "Or for latest: ${{ steps.image.outputs.name }}:latest"
|
||||
echo "Multi-arch manifest created for ${{ matrix.name }}!"
|
||||
echo "Versioned: ${{ steps.image.outputs.name }}:${{ needs.tag_release.outputs.new_tag }}"
|
||||
echo "App version: ${{ steps.image.outputs.name }}:$(echo '${{ needs.tag_release.outputs.new_tag }}' | rev | cut -d. -f2- | rev)"
|
||||
echo "Latest: ${{ steps.image.outputs.name }}:latest"
|
||||
|
|
|
|||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -5,3 +5,4 @@ node_modules/
|
|||
.ruff_cache/
|
||||
.venv
|
||||
.pnpm-store
|
||||
.DS_Store
|
||||
|
|
|
|||
|
|
@ -1,285 +0,0 @@
|
|||
# SurfSense All-in-One Docker Image
|
||||
# This image bundles PostgreSQL+pgvector, Redis, Electric SQL, Backend, and Frontend
|
||||
# Usage: docker run -d -p 3000:3000 -p 8000:8000 -p 5133:5133 -v surfsense-data:/data --name surfsense ghcr.io/modsetter/surfsense:latest
|
||||
#
|
||||
# Included Services (all run locally by default):
|
||||
# - PostgreSQL 14 + pgvector (vector database)
|
||||
# - Redis (task queue)
|
||||
# - Electric SQL (real-time sync)
|
||||
# - Docling (document processing, CPU-only, OCR disabled)
|
||||
# - Kokoro TTS (local text-to-speech for podcasts)
|
||||
# - Faster-Whisper (local speech-to-text for audio files)
|
||||
# - Playwright Chromium (web scraping)
|
||||
#
|
||||
# Note: This is the CPU-only version. A :cuda tagged image with GPU support
|
||||
# will be available in the future for faster AI inference.
|
||||
|
||||
# ====================
|
||||
# Stage 1: Get Electric SQL Binary
|
||||
# ====================
|
||||
FROM electricsql/electric:latest AS electric-builder
|
||||
|
||||
# ====================
|
||||
# Stage 2: Build Frontend
|
||||
# ====================
|
||||
FROM node:20-alpine AS frontend-builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install pnpm
|
||||
RUN corepack enable pnpm
|
||||
|
||||
# Copy package files
|
||||
COPY surfsense_web/package.json surfsense_web/pnpm-lock.yaml* ./
|
||||
COPY surfsense_web/source.config.ts ./
|
||||
COPY surfsense_web/content ./content
|
||||
|
||||
# Install dependencies (skip postinstall which requires all source files)
|
||||
RUN pnpm install --frozen-lockfile --ignore-scripts
|
||||
|
||||
# Copy source
|
||||
COPY surfsense_web/ ./
|
||||
|
||||
# Run fumadocs-mdx postinstall now that source files are available
|
||||
RUN pnpm fumadocs-mdx
|
||||
|
||||
# Build with placeholder values that will be replaced at runtime
|
||||
# These unique strings allow runtime substitution via entrypoint script
|
||||
ENV NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__
|
||||
ENV NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__
|
||||
ENV NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__
|
||||
ENV NEXT_PUBLIC_ELECTRIC_URL=__NEXT_PUBLIC_ELECTRIC_URL__
|
||||
ENV NEXT_PUBLIC_ELECTRIC_AUTH_MODE=__NEXT_PUBLIC_ELECTRIC_AUTH_MODE__
|
||||
ENV NEXT_PUBLIC_DEPLOYMENT_MODE=__NEXT_PUBLIC_DEPLOYMENT_MODE__
|
||||
|
||||
# Build
|
||||
RUN pnpm run build
|
||||
|
||||
# ====================
|
||||
# Stage 3: Runtime Image
|
||||
# ====================
|
||||
FROM ubuntu:22.04 AS runtime
|
||||
|
||||
# Prevent interactive prompts
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
# PostgreSQL
|
||||
postgresql-14 \
|
||||
postgresql-contrib-14 \
|
||||
# Build tools for pgvector
|
||||
build-essential \
|
||||
postgresql-server-dev-14 \
|
||||
git \
|
||||
# Redis
|
||||
redis-server \
|
||||
# Node.js prerequisites
|
||||
curl \
|
||||
ca-certificates \
|
||||
gnupg \
|
||||
# Backend dependencies
|
||||
gcc \
|
||||
wget \
|
||||
unzip \
|
||||
dos2unix \
|
||||
# For PPAs
|
||||
software-properties-common \
|
||||
# ============================
|
||||
# Local TTS (Kokoro) dependencies
|
||||
# ============================
|
||||
espeak-ng \
|
||||
libespeak-ng1 \
|
||||
# ============================
|
||||
# Local STT (Faster-Whisper) dependencies
|
||||
# ============================
|
||||
ffmpeg \
|
||||
# ============================
|
||||
# Audio processing (soundfile)
|
||||
# ============================
|
||||
libsndfile1 \
|
||||
# ============================
|
||||
# Image/OpenCV dependencies (for Docling)
|
||||
# ============================
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
libsm6 \
|
||||
libxext6 \
|
||||
libxrender1 \
|
||||
# ============================
|
||||
# Playwright browser dependencies
|
||||
# ============================
|
||||
libnspr4 \
|
||||
libnss3 \
|
||||
libatk1.0-0 \
|
||||
libatk-bridge2.0-0 \
|
||||
libcups2 \
|
||||
libxkbcommon0 \
|
||||
libatspi2.0-0 \
|
||||
libxcomposite1 \
|
||||
libxdamage1 \
|
||||
libxrandr2 \
|
||||
libgbm1 \
|
||||
libcairo2 \
|
||||
libpango-1.0-0 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Pandoc 3.x from GitHub (apt ships 2.9 which has broken table rendering).
|
||||
RUN ARCH=$(dpkg --print-architecture) && \
|
||||
wget -qO /tmp/pandoc.deb "https://github.com/jgm/pandoc/releases/download/3.9/pandoc-3.9-1-${ARCH}.deb" && \
|
||||
dpkg -i /tmp/pandoc.deb && \
|
||||
rm /tmp/pandoc.deb
|
||||
|
||||
|
||||
# Install Node.js 20.x (for running frontend)
|
||||
RUN curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
|
||||
&& apt-get install -y nodejs \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Python 3.12 from deadsnakes PPA
|
||||
RUN add-apt-repository ppa:deadsnakes/ppa -y \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
python3.12 \
|
||||
python3.12-venv \
|
||||
python3.12-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Set Python 3.12 as default
|
||||
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1 \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1
|
||||
|
||||
# Install pip for Python 3.12
|
||||
RUN python3.12 -m ensurepip --upgrade \
|
||||
&& python3.12 -m pip install --upgrade pip
|
||||
|
||||
# Install supervisor via pip (system package incompatible with Python 3.12)
|
||||
RUN pip install --no-cache-dir supervisor
|
||||
|
||||
# Build and install pgvector
|
||||
RUN cd /tmp \
|
||||
&& git clone --branch v0.7.4 https://github.com/pgvector/pgvector.git \
|
||||
&& cd pgvector \
|
||||
&& make \
|
||||
&& make install \
|
||||
&& rm -rf /tmp/pgvector
|
||||
|
||||
# Update certificates
|
||||
RUN update-ca-certificates
|
||||
|
||||
# Create data directories
|
||||
RUN mkdir -p /data/postgres /data/redis /data/surfsense \
|
||||
&& chown -R postgres:postgres /data/postgres
|
||||
|
||||
# ====================
|
||||
# Copy Frontend Build
|
||||
# ====================
|
||||
WORKDIR /app/frontend
|
||||
|
||||
# Copy only the standalone build (not node_modules)
|
||||
COPY --from=frontend-builder /app/.next/standalone ./
|
||||
COPY --from=frontend-builder /app/.next/static ./.next/static
|
||||
COPY --from=frontend-builder /app/public ./public
|
||||
|
||||
COPY surfsense_web/content/docs /app/surfsense_web/content/docs
|
||||
|
||||
# ====================
|
||||
# Copy Electric SQL Release
|
||||
# ====================
|
||||
COPY --from=electric-builder /app /app/electric-release
|
||||
|
||||
# ====================
|
||||
# Setup Backend
|
||||
# ====================
|
||||
WORKDIR /app/backend
|
||||
|
||||
# Copy backend dependency files
|
||||
COPY surfsense_backend/pyproject.toml surfsense_backend/uv.lock ./
|
||||
|
||||
# Install PyTorch CPU-only (Docling needs it but OCR is disabled, no GPU needed)
|
||||
RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Install python dependencies
|
||||
RUN pip install --no-cache-dir certifi pip-system-certs uv \
|
||||
&& uv pip install --system --no-cache-dir -e .
|
||||
|
||||
# Set SSL environment variables
|
||||
RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") \
|
||||
&& echo "export SSL_CERT_FILE=$CERTIFI_PATH" >> /etc/profile.d/ssl.sh \
|
||||
&& echo "export REQUESTS_CA_BUNDLE=$CERTIFI_PATH" >> /etc/profile.d/ssl.sh
|
||||
|
||||
# Note: EasyOCR models NOT downloaded - OCR is disabled in docling_service.py
|
||||
# GPU support will be added in a future :cuda tagged image
|
||||
|
||||
# Install Playwright browsers
|
||||
RUN pip install --no-cache-dir playwright \
|
||||
&& playwright install chromium \
|
||||
&& rm -rf /root/.cache/ms-playwright/ffmpeg*
|
||||
|
||||
# Copy backend source
|
||||
COPY surfsense_backend/ ./
|
||||
|
||||
# ====================
|
||||
# Configuration
|
||||
# ====================
|
||||
WORKDIR /app
|
||||
|
||||
# Copy supervisor configuration
|
||||
COPY scripts/docker/supervisor-allinone.conf /etc/supervisor/conf.d/surfsense.conf
|
||||
|
||||
# Copy entrypoint script
|
||||
COPY scripts/docker/entrypoint-allinone.sh /app/entrypoint.sh
|
||||
RUN dos2unix /app/entrypoint.sh && chmod +x /app/entrypoint.sh
|
||||
|
||||
# PostgreSQL initialization script
|
||||
COPY scripts/docker/init-postgres.sh /app/init-postgres.sh
|
||||
RUN dos2unix /app/init-postgres.sh && chmod +x /app/init-postgres.sh
|
||||
|
||||
# Clean up build dependencies to reduce image size
|
||||
RUN apt-get purge -y build-essential postgresql-server-dev-14 \
|
||||
&& apt-get autoremove -y \
|
||||
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
|
||||
|
||||
# Environment variables with defaults
|
||||
ENV POSTGRES_USER=surfsense
|
||||
ENV POSTGRES_PASSWORD=surfsense
|
||||
ENV POSTGRES_DB=surfsense
|
||||
ENV DATABASE_URL=postgresql+asyncpg://surfsense:surfsense@localhost:5432/surfsense
|
||||
ENV CELERY_BROKER_URL=redis://localhost:6379/0
|
||||
ENV CELERY_RESULT_BACKEND=redis://localhost:6379/0
|
||||
ENV CELERY_TASK_DEFAULT_QUEUE=surfsense
|
||||
ENV PYTHONPATH=/app/backend
|
||||
ENV NEXT_FRONTEND_URL=http://localhost:3000
|
||||
ENV AUTH_TYPE=LOCAL
|
||||
ENV ETL_SERVICE=DOCLING
|
||||
ENV EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
||||
|
||||
# Frontend configuration (can be overridden at runtime)
|
||||
# These are injected into the Next.js build at container startup
|
||||
ENV NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000
|
||||
ENV NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL
|
||||
ENV NEXT_PUBLIC_ETL_SERVICE=DOCLING
|
||||
|
||||
# Electric SQL configuration (ELECTRIC_DATABASE_URL is built dynamically by entrypoint from these values)
|
||||
ENV ELECTRIC_DB_USER=electric
|
||||
ENV ELECTRIC_DB_PASSWORD=electric_password
|
||||
# Note: ELECTRIC_DATABASE_URL is NOT set here - entrypoint builds it dynamically from ELECTRIC_DB_USER/PASSWORD
|
||||
ENV ELECTRIC_INSECURE=true
|
||||
ENV ELECTRIC_WRITE_TO_PG_MODE=direct
|
||||
ENV ELECTRIC_PORT=5133
|
||||
ENV PORT=5133
|
||||
ENV NEXT_PUBLIC_ELECTRIC_URL=http://localhost:5133
|
||||
ENV NEXT_PUBLIC_ELECTRIC_AUTH_MODE=insecure
|
||||
|
||||
# Data volume
|
||||
VOLUME ["/data"]
|
||||
|
||||
# Expose ports (Frontend: 3000, Backend: 8000, Electric: 5133)
|
||||
EXPOSE 3000 8000 5133
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=120s --retries=3 \
|
||||
CMD curl -f http://localhost:3000 || exit 1
|
||||
|
||||
# Run entrypoint
|
||||
CMD ["/app/entrypoint.sh"]
|
||||
14
README.es.md
14
README.es.md
|
|
@ -81,13 +81,21 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
|||
|
||||
Ejecuta SurfSense en tu propia infraestructura para control total de datos y privacidad.
|
||||
|
||||
**Inicio Rápido (Docker en un solo comando):**
|
||||
**Requisitos previos:** [Docker Desktop](https://www.docker.com/products/docker-desktop/) debe estar instalado y en ejecución.
|
||||
|
||||
#### Para usuarios de Linux/MacOS:
|
||||
|
||||
```bash
|
||||
docker run -d -p 3000:3000 -p 8000:8000 -p 5133:5133 -v surfsense-data:/data --name surfsense --restart unless-stopped ghcr.io/modsetter/surfsense:latest
|
||||
curl -fsSL https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
Después de iniciar, abre [http://localhost:3000](http://localhost:3000) en tu navegador.
|
||||
#### Para usuarios de Windows:
|
||||
|
||||
```powershell
|
||||
irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.ps1 | iex
|
||||
```
|
||||
|
||||
El script de instalación configura [Watchtower](https://github.com/nicholas-fedor/watchtower) automáticamente para actualizaciones diarias. Para omitirlo, agrega la bandera `--no-watchtower`.
|
||||
|
||||
Para Docker Compose, instalación manual y otras opciones de despliegue, consulta la [documentación](https://www.surfsense.com/docs/).
|
||||
|
||||
|
|
|
|||
14
README.hi.md
14
README.hi.md
|
|
@ -81,13 +81,21 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
|||
|
||||
पूर्ण डेटा नियंत्रण और गोपनीयता के लिए SurfSense को अपने स्वयं के बुनियादी ढांचे पर चलाएं।
|
||||
|
||||
**त्वरित शुरुआत (Docker एक कमांड में):**
|
||||
**आवश्यकताएँ:** [Docker Desktop](https://www.docker.com/products/docker-desktop/) इंस्टॉल और चालू होना चाहिए।
|
||||
|
||||
#### Linux/MacOS उपयोगकर्ताओं के लिए:
|
||||
|
||||
```bash
|
||||
docker run -d -p 3000:3000 -p 8000:8000 -p 5133:5133 -v surfsense-data:/data --name surfsense --restart unless-stopped ghcr.io/modsetter/surfsense:latest
|
||||
curl -fsSL https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
शुरू करने के बाद, अपने ब्राउज़र में [http://localhost:3000](http://localhost:3000) खोलें।
|
||||
#### Windows उपयोगकर्ताओं के लिए:
|
||||
|
||||
```powershell
|
||||
irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.ps1 | iex
|
||||
```
|
||||
|
||||
इंस्टॉल स्क्रिप्ट दैनिक ऑटो-अपडेट के लिए स्वचालित रूप से [Watchtower](https://github.com/nicholas-fedor/watchtower) सेटअप करती है। इसे छोड़ने के लिए, `--no-watchtower` फ्लैग जोड़ें।
|
||||
|
||||
Docker Compose, मैनुअल इंस्टॉलेशन और अन्य डिप्लॉयमेंट विकल्पों के लिए, [डॉक्स](https://www.surfsense.com/docs/) देखें।
|
||||
|
||||
|
|
|
|||
16
README.md
16
README.md
|
|
@ -81,21 +81,23 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
|||
|
||||
Run SurfSense on your own infrastructure for full data control and privacy.
|
||||
|
||||
**Quick Start (Docker one-liner):**
|
||||
**Prerequisites:** [Docker Desktop](https://www.docker.com/products/docker-desktop/) must be installed and running.
|
||||
|
||||
#### For Linux/MacOS users:
|
||||
|
||||
```bash
|
||||
docker run -d -p 3000:3000 -p 8000:8000 -p 5133:5133 -v surfsense-data:/data --name surfsense --restart unless-stopped ghcr.io/modsetter/surfsense:latest
|
||||
curl -fsSL https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
After starting, open [http://localhost:3000](http://localhost:3000) in your browser.
|
||||
|
||||
**Update (Automatic updates with Watchtower):**
|
||||
#### For Windows users:
|
||||
|
||||
```bash
|
||||
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock nickfedor/watchtower --run-once surfsense
|
||||
irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.ps1 | iex
|
||||
```
|
||||
|
||||
For Docker Compose, manual installation, and other deployment options, check the [docs](https://www.surfsense.com/docs/).
|
||||
The install script sets up [Watchtower](https://github.com/nicholas-fedor/watchtower) automatically for daily auto-updates. To skip it, add the `--no-watchtower` flag.
|
||||
|
||||
For Docker Compose, manual installation, and other deployment options, see the [docs](https://www.surfsense.com/docs/).
|
||||
|
||||
### How to Realtime Collaborate (Beta)
|
||||
|
||||
|
|
|
|||
|
|
@ -81,13 +81,21 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
|||
|
||||
Execute o SurfSense na sua própria infraestrutura para controle total de dados e privacidade.
|
||||
|
||||
**Início Rápido (Docker em um único comando):**
|
||||
**Pré-requisitos:** [Docker Desktop](https://www.docker.com/products/docker-desktop/) deve estar instalado e em execução.
|
||||
|
||||
#### Para usuários de Linux/MacOS:
|
||||
|
||||
```bash
|
||||
docker run -d -p 3000:3000 -p 8000:8000 -p 5133:5133 -v surfsense-data:/data --name surfsense --restart unless-stopped ghcr.io/modsetter/surfsense:latest
|
||||
curl -fsSL https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
Após iniciar, abra [http://localhost:3000](http://localhost:3000) no seu navegador.
|
||||
#### Para usuários do Windows:
|
||||
|
||||
```powershell
|
||||
irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.ps1 | iex
|
||||
```
|
||||
|
||||
O script de instalação configura o [Watchtower](https://github.com/nicholas-fedor/watchtower) automaticamente para atualizações diárias. Para pular, adicione a flag `--no-watchtower`.
|
||||
|
||||
Para Docker Compose, instalação manual e outras opções de implantação, consulte a [documentação](https://www.surfsense.com/docs/).
|
||||
|
||||
|
|
|
|||
|
|
@ -81,13 +81,21 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
|||
|
||||
在您自己的基础设施上运行 SurfSense,实现完全的数据控制和隐私保护。
|
||||
|
||||
**快速开始(Docker 一行命令):**
|
||||
**前置条件:** 需要安装并运行 [Docker Desktop](https://www.docker.com/products/docker-desktop/)。
|
||||
|
||||
#### Linux/MacOS 用户:
|
||||
|
||||
```bash
|
||||
docker run -d -p 3000:3000 -p 8000:8000 -p 5133:5133 -v surfsense-data:/data --name surfsense --restart unless-stopped ghcr.io/modsetter/surfsense:latest
|
||||
curl -fsSL https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
启动后,在浏览器中打开 [http://localhost:3000](http://localhost:3000)。
|
||||
#### Windows 用户:
|
||||
|
||||
```powershell
|
||||
irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.ps1 | iex
|
||||
```
|
||||
|
||||
安装脚本会自动配置 [Watchtower](https://github.com/nicholas-fedor/watchtower) 以实现每日自动更新。如需跳过,请添加 `--no-watchtower` 参数。
|
||||
|
||||
如需 Docker Compose、手动安装及其他部署方式,请查看[文档](https://www.surfsense.com/docs/)。
|
||||
|
||||
|
|
|
|||
|
|
@ -1,80 +0,0 @@
|
|||
# SurfSense Quick Start Docker Compose
|
||||
#
|
||||
# This is a simplified docker-compose for quick local deployment using pre-built images.
|
||||
# For production or customized deployments, use the main docker-compose.yml
|
||||
#
|
||||
# Usage:
|
||||
# 1. (Optional) Create a .env file with your configuration
|
||||
# 2. Run: docker compose -f docker-compose.quickstart.yml up -d
|
||||
# 3. Access SurfSense at http://localhost:3000
|
||||
#
|
||||
# All Environment Variables are Optional:
|
||||
# - SECRET_KEY: JWT secret key (auto-generated and persisted if not set)
|
||||
# - EMBEDDING_MODEL: Embedding model to use (default: sentence-transformers/all-MiniLM-L6-v2)
|
||||
# - ETL_SERVICE: Document parsing service - DOCLING, UNSTRUCTURED, or LLAMACLOUD (default: DOCLING)
|
||||
# - TTS_SERVICE: Text-to-speech service for podcasts (default: local/kokoro)
|
||||
# - STT_SERVICE: Speech-to-text service with model size (default: local/base)
|
||||
# - FIRECRAWL_API_KEY: For web crawling features
|
||||
|
||||
version: "3.8"
|
||||
|
||||
services:
|
||||
# All-in-one SurfSense container
|
||||
surfsense:
|
||||
image: ghcr.io/modsetter/surfsense:latest
|
||||
container_name: surfsense
|
||||
ports:
|
||||
- "${FRONTEND_PORT:-3000}:3000"
|
||||
- "${BACKEND_PORT:-8000}:8000"
|
||||
volumes:
|
||||
- surfsense-data:/data
|
||||
environment:
|
||||
# Authentication (auto-generated if not set)
|
||||
- SECRET_KEY=${SECRET_KEY:-}
|
||||
|
||||
# Auth Configuration
|
||||
- AUTH_TYPE=${AUTH_TYPE:-LOCAL}
|
||||
- GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
|
||||
- GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
|
||||
|
||||
# AI/ML Configuration
|
||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||
- RERANKERS_ENABLED=${RERANKERS_ENABLED:-FALSE}
|
||||
- RERANKERS_MODEL_NAME=${RERANKERS_MODEL_NAME:-}
|
||||
- RERANKERS_MODEL_TYPE=${RERANKERS_MODEL_TYPE:-}
|
||||
|
||||
# Document Processing
|
||||
- ETL_SERVICE=${ETL_SERVICE:-DOCLING}
|
||||
- UNSTRUCTURED_API_KEY=${UNSTRUCTURED_API_KEY:-}
|
||||
- LLAMA_CLOUD_API_KEY=${LLAMA_CLOUD_API_KEY:-}
|
||||
|
||||
# Audio Services
|
||||
- TTS_SERVICE=${TTS_SERVICE:-local/kokoro}
|
||||
- TTS_SERVICE_API_KEY=${TTS_SERVICE_API_KEY:-}
|
||||
- STT_SERVICE=${STT_SERVICE:-local/base}
|
||||
- STT_SERVICE_API_KEY=${STT_SERVICE_API_KEY:-}
|
||||
|
||||
# Web Crawling
|
||||
- FIRECRAWL_API_KEY=${FIRECRAWL_API_KEY:-}
|
||||
|
||||
# Optional Features
|
||||
- REGISTRATION_ENABLED=${REGISTRATION_ENABLED:-TRUE}
|
||||
- SCHEDULE_CHECKER_INTERVAL=${SCHEDULE_CHECKER_INTERVAL:-1m}
|
||||
|
||||
# LangSmith Observability (optional)
|
||||
- LANGSMITH_TRACING=${LANGSMITH_TRACING:-false}
|
||||
- LANGSMITH_ENDPOINT=${LANGSMITH_ENDPOINT:-}
|
||||
- LANGSMITH_API_KEY=${LANGSMITH_API_KEY:-}
|
||||
- LANGSMITH_PROJECT=${LANGSMITH_PROJECT:-}
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:3000", "&&", "curl", "-f", "http://localhost:8000/docs"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 120s
|
||||
|
||||
volumes:
|
||||
surfsense-data:
|
||||
name: surfsense-data
|
||||
|
||||
|
|
@ -1,167 +0,0 @@
|
|||
version: "3.8"
|
||||
|
||||
services:
|
||||
db:
|
||||
image: ankane/pgvector:latest
|
||||
ports:
|
||||
- "${POSTGRES_PORT:-5432}:5432"
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
- ./scripts/docker/postgresql.conf:/etc/postgresql/postgresql.conf:ro
|
||||
- ./scripts/docker/init-electric-user.sh:/docker-entrypoint-initdb.d/init-electric-user.sh:ro
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-postgres}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-surfsense}
|
||||
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
|
||||
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
command: postgres -c config_file=/etc/postgresql/postgresql.conf
|
||||
|
||||
pgadmin:
|
||||
image: dpage/pgadmin4
|
||||
ports:
|
||||
- "${PGADMIN_PORT:-5050}:80"
|
||||
environment:
|
||||
- PGADMIN_DEFAULT_EMAIL=${PGADMIN_DEFAULT_EMAIL:-admin@surfsense.com}
|
||||
- PGADMIN_DEFAULT_PASSWORD=${PGADMIN_DEFAULT_PASSWORD:-surfsense}
|
||||
volumes:
|
||||
- pgadmin_data:/var/lib/pgadmin
|
||||
depends_on:
|
||||
- db
|
||||
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
ports:
|
||||
- "${REDIS_PORT:-6379}:6379"
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
command: redis-server --appendonly yes
|
||||
|
||||
backend:
|
||||
build: ./surfsense_backend
|
||||
# image: ghcr.io/modsetter/surfsense_backend:latest
|
||||
ports:
|
||||
- "${BACKEND_PORT:-8000}:8000"
|
||||
volumes:
|
||||
- ./surfsense_backend/app:/app/app
|
||||
- shared_temp:/tmp
|
||||
# Uncomment and edit the line below to enable Obsidian vault indexing
|
||||
# - /path/to/your/obsidian/vault:/obsidian-vault:ro
|
||||
env_file:
|
||||
- ./surfsense_backend/.env
|
||||
environment:
|
||||
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-surfsense}
|
||||
- CELERY_BROKER_URL=redis://redis:${REDIS_PORT:-6379}/0
|
||||
- CELERY_RESULT_BACKEND=redis://redis:${REDIS_PORT:-6379}/0
|
||||
- REDIS_APP_URL=redis://redis:${REDIS_PORT:-6379}/0
|
||||
# Queue name isolation - prevents task collision if Redis is shared with other apps
|
||||
- CELERY_TASK_DEFAULT_QUEUE=surfsense
|
||||
- PYTHONPATH=/app
|
||||
- UVICORN_LOOP=asyncio
|
||||
- UNSTRUCTURED_HAS_PATCHED_LOOP=1
|
||||
- LANGCHAIN_TRACING_V2=false
|
||||
- LANGSMITH_TRACING=false
|
||||
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
|
||||
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
- AUTH_TYPE=${AUTH_TYPE:-LOCAL}
|
||||
- NEXT_FRONTEND_URL=${NEXT_FRONTEND_URL:-http://localhost:3000}
|
||||
depends_on:
|
||||
- db
|
||||
- redis
|
||||
|
||||
# Run these services separately in production
|
||||
# celery_worker:
|
||||
# build: ./surfsense_backend
|
||||
# # image: ghcr.io/modsetter/surfsense_backend:latest
|
||||
# command: celery -A app.celery_app worker --loglevel=info --concurrency=1 --pool=solo
|
||||
# volumes:
|
||||
# - ./surfsense_backend:/app
|
||||
# - shared_temp:/tmp
|
||||
# env_file:
|
||||
# - ./surfsense_backend/.env
|
||||
# environment:
|
||||
# - DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-surfsense}
|
||||
# - CELERY_BROKER_URL=redis://redis:${REDIS_PORT:-6379}/0
|
||||
# - CELERY_RESULT_BACKEND=redis://redis:${REDIS_PORT:-6379}/0
|
||||
# - PYTHONPATH=/app
|
||||
# depends_on:
|
||||
# - db
|
||||
# - redis
|
||||
# - backend
|
||||
|
||||
# celery_beat:
|
||||
# build: ./surfsense_backend
|
||||
# # image: ghcr.io/modsetter/surfsense_backend:latest
|
||||
# command: celery -A app.celery_app beat --loglevel=info
|
||||
# volumes:
|
||||
# - ./surfsense_backend:/app
|
||||
# - shared_temp:/tmp
|
||||
# env_file:
|
||||
# - ./surfsense_backend/.env
|
||||
# environment:
|
||||
# - DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-surfsense}
|
||||
# - CELERY_BROKER_URL=redis://redis:${REDIS_PORT:-6379}/0
|
||||
# - CELERY_RESULT_BACKEND=redis://redis:${REDIS_PORT:-6379}/0
|
||||
# - PYTHONPATH=/app
|
||||
# depends_on:
|
||||
# - db
|
||||
# - redis
|
||||
# - celery_worker
|
||||
|
||||
# flower:
|
||||
# build: ./surfsense_backend
|
||||
# # image: ghcr.io/modsetter/surfsense_backend:latest
|
||||
# command: celery -A app.celery_app flower --port=5555
|
||||
# ports:
|
||||
# - "${FLOWER_PORT:-5555}:5555"
|
||||
# env_file:
|
||||
# - ./surfsense_backend/.env
|
||||
# environment:
|
||||
# - CELERY_BROKER_URL=redis://redis:${REDIS_PORT:-6379}/0
|
||||
# - CELERY_RESULT_BACKEND=redis://redis:${REDIS_PORT:-6379}/0
|
||||
# - PYTHONPATH=/app
|
||||
# depends_on:
|
||||
# - redis
|
||||
# - celery_worker
|
||||
|
||||
electric:
|
||||
image: electricsql/electric:latest
|
||||
ports:
|
||||
- "${ELECTRIC_PORT:-5133}:3000"
|
||||
environment:
|
||||
- DATABASE_URL=${ELECTRIC_DATABASE_URL:-postgresql://${ELECTRIC_DB_USER:-electric}:${ELECTRIC_DB_PASSWORD:-electric_password}@${POSTGRES_HOST:-db}:${POSTGRES_PORT:-5432}/${POSTGRES_DB:-surfsense}?sslmode=disable}
|
||||
- ELECTRIC_INSECURE=true
|
||||
- ELECTRIC_WRITE_TO_PG_MODE=direct
|
||||
restart: unless-stopped
|
||||
# depends_on:
|
||||
# - db
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:3000/v1/health"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
frontend:
|
||||
build:
|
||||
context: ./surfsense_web
|
||||
# image: ghcr.io/modsetter/surfsense_ui:latest
|
||||
args:
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000}
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL}
|
||||
NEXT_PUBLIC_ETL_SERVICE: ${NEXT_PUBLIC_ETL_SERVICE:-DOCLING}
|
||||
ports:
|
||||
- "${FRONTEND_PORT:-3000}:3000"
|
||||
env_file:
|
||||
- ./surfsense_web/.env
|
||||
environment:
|
||||
- NEXT_PUBLIC_ELECTRIC_URL=${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:5133}
|
||||
- NEXT_PUBLIC_ELECTRIC_AUTH_MODE=insecure
|
||||
depends_on:
|
||||
- backend
|
||||
- electric
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
pgadmin_data:
|
||||
redis_data:
|
||||
shared_temp:
|
||||
256
docker/.env.example
Normal file
256
docker/.env.example
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
# ==============================================================================
|
||||
# SurfSense Docker Configuration
|
||||
# ==============================================================================
|
||||
# Database, Redis, and internal service wiring are handled automatically.
|
||||
# ==============================================================================
|
||||
|
||||
# SurfSense version (use "latest", a clean version like "0.0.14", or a specific build like "0.0.14.1")
|
||||
SURFSENSE_VERSION=latest
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Core Settings
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# REQUIRED: Generate a secret key with: openssl rand -base64 32
|
||||
SECRET_KEY=replace_me_with_a_random_string
|
||||
|
||||
# Auth type: LOCAL (email/password) or GOOGLE (OAuth)
|
||||
AUTH_TYPE=LOCAL
|
||||
|
||||
# Allow new user registrations (TRUE or FALSE)
|
||||
# REGISTRATION_ENABLED=TRUE
|
||||
|
||||
# Document parsing service: DOCLING, UNSTRUCTURED, or LLAMACLOUD
|
||||
ETL_SERVICE=DOCLING
|
||||
|
||||
# Embedding model for vector search
|
||||
# Local: sentence-transformers/all-MiniLM-L6-v2
|
||||
# OpenAI: openai://text-embedding-ada-002 (set OPENAI_API_KEY below)
|
||||
# Cohere: cohere://embed-english-light-v3.0 (set COHERE_API_KEY below)
|
||||
EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Ports (change to avoid conflicts with other services on your machine)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# BACKEND_PORT=8000
|
||||
# FRONTEND_PORT=3000
|
||||
# ELECTRIC_PORT=5133
|
||||
# FLOWER_PORT=5555
|
||||
|
||||
# ==============================================================================
|
||||
# DEV COMPOSE ONLY (docker-compose.dev.yml)
|
||||
# You only need them only if you are running `docker-compose.dev.yml`.
|
||||
# ==============================================================================
|
||||
|
||||
# -- pgAdmin (database GUI) --
|
||||
# PGADMIN_PORT=5050
|
||||
# PGADMIN_DEFAULT_EMAIL=admin@surfsense.com
|
||||
# PGADMIN_DEFAULT_PASSWORD=surfsense
|
||||
|
||||
# -- Redis exposed port (dev only; Redis is internal-only in prod) --
|
||||
# REDIS_PORT=6379
|
||||
|
||||
# -- Frontend Build Args --
|
||||
# In dev, the frontend is built from source and these are passed as build args.
|
||||
# In prod, they are automatically derived from AUTH_TYPE, ETL_SERVICE, and the port settings above.
|
||||
# NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL
|
||||
# NEXT_PUBLIC_ETL_SERVICE=DOCLING
|
||||
# NEXT_PUBLIC_DEPLOYMENT_MODE=self-hosted
|
||||
# NEXT_PUBLIC_ELECTRIC_AUTH_MODE=insecure
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Custom Domain / Reverse Proxy
|
||||
# ------------------------------------------------------------------------------
|
||||
# ONLY set these if you are serving SurfSense on a real domain via a reverse
|
||||
# proxy (e.g. Caddy, Nginx, Cloudflare Tunnel).
|
||||
# For standard localhost deployments, leave all of these commented out —
|
||||
# they are automatically derived from the port settings above.
|
||||
#
|
||||
# NEXT_FRONTEND_URL=https://app.yourdomain.com
|
||||
# BACKEND_URL=https://api.yourdomain.com
|
||||
# NEXT_PUBLIC_FASTAPI_BACKEND_URL=https://api.yourdomain.com
|
||||
# NEXT_PUBLIC_ELECTRIC_URL=https://electric.yourdomain.com
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Database (defaults work out of the box, change for security)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# DB_USER=surfsense
|
||||
# DB_PASSWORD=surfsense
|
||||
# DB_NAME=surfsense
|
||||
# DB_HOST=db
|
||||
# DB_PORT=5432
|
||||
|
||||
# SSL mode for database connections: disable, require, verify-ca, verify-full
|
||||
# DB_SSLMODE=disable
|
||||
|
||||
# Full DATABASE_URL override — when set, takes precedence over the individual
|
||||
# DB_USER / DB_PASSWORD / DB_NAME / DB_HOST / DB_PORT settings above.
|
||||
# Use this for managed databases (AWS RDS, GCP Cloud SQL, Supabase, etc.)
|
||||
# DATABASE_URL=postgresql+asyncpg://user:password@your-rds-host:5432/surfsense?sslmode=require
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Redis (defaults work out of the box)
|
||||
# ------------------------------------------------------------------------------
|
||||
# Full Redis URL override for Celery broker, result backend, and app cache.
|
||||
# Use this for managed Redis (AWS ElastiCache, Redis Cloud, etc.)
|
||||
# Supports auth: redis://:password@host:port/0
|
||||
# Supports TLS: rediss://:password@host:6380/0
|
||||
# REDIS_URL=redis://redis:6379/0
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Electric SQL (real-time sync credentials)
|
||||
# ------------------------------------------------------------------------------
|
||||
# These must match on the db, backend, and electric services.
|
||||
# Change for security; defaults work out of the box.
|
||||
|
||||
# ELECTRIC_DB_USER=electric
|
||||
# ELECTRIC_DB_PASSWORD=electric_password
|
||||
# Full override for the Electric → Postgres connection URL.
|
||||
# Leave commented out to use the Docker-managed `db` container (default).
|
||||
# Uncomment and set `db` to `host.docker.internal` when pointing Electric at a local Postgres instance (e.g. Postgres.app on macOS):
|
||||
# ELECTRIC_DATABASE_URL=postgresql://electric:electric_password@db:5432/surfsense?sslmode=disable
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# TTS & STT (Text-to-Speech / Speech-to-Text)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Local Kokoro TTS (default) or LiteLLM provider
|
||||
TTS_SERVICE=local/kokoro
|
||||
# TTS_SERVICE_API_KEY=
|
||||
# TTS_SERVICE_API_BASE=
|
||||
|
||||
# Local Faster-Whisper STT: local/MODEL_SIZE (tiny, base, small, medium, large-v3)
|
||||
STT_SERVICE=local/base
|
||||
# Or use LiteLLM: openai/whisper-1
|
||||
# STT_SERVICE_API_KEY=
|
||||
# STT_SERVICE_API_BASE=
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Rerankers (optional, disabled by default)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# RERANKERS_ENABLED=TRUE
|
||||
# RERANKERS_MODEL_NAME=ms-marco-MiniLM-L-12-v2
|
||||
# RERANKERS_MODEL_TYPE=flashrank
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Google OAuth (only if AUTH_TYPE=GOOGLE)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# GOOGLE_OAUTH_CLIENT_ID=
|
||||
# GOOGLE_OAUTH_CLIENT_SECRET=
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Connector OAuth Keys (uncomment connectors you want to use)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# -- Google Connectors --
|
||||
# GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/calendar/connector/callback
|
||||
# GOOGLE_GMAIL_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/gmail/connector/callback
|
||||
# GOOGLE_DRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/drive/connector/callback
|
||||
|
||||
# -- Notion --
|
||||
# NOTION_CLIENT_ID=
|
||||
# NOTION_CLIENT_SECRET=
|
||||
# NOTION_REDIRECT_URI=http://localhost:8000/api/v1/auth/notion/connector/callback
|
||||
|
||||
# -- Slack --
|
||||
# SLACK_CLIENT_ID=
|
||||
# SLACK_CLIENT_SECRET=
|
||||
# SLACK_REDIRECT_URI=http://localhost:8000/api/v1/auth/slack/connector/callback
|
||||
|
||||
# -- Discord --
|
||||
# DISCORD_CLIENT_ID=
|
||||
# DISCORD_CLIENT_SECRET=
|
||||
# DISCORD_REDIRECT_URI=http://localhost:8000/api/v1/auth/discord/connector/callback
|
||||
# DISCORD_BOT_TOKEN=
|
||||
|
||||
# -- Atlassian (Jira & Confluence) --
|
||||
# ATLASSIAN_CLIENT_ID=
|
||||
# ATLASSIAN_CLIENT_SECRET=
|
||||
# JIRA_REDIRECT_URI=http://localhost:8000/api/v1/auth/jira/connector/callback
|
||||
# CONFLUENCE_REDIRECT_URI=http://localhost:8000/api/v1/auth/confluence/connector/callback
|
||||
|
||||
# -- Linear --
|
||||
# LINEAR_CLIENT_ID=
|
||||
# LINEAR_CLIENT_SECRET=
|
||||
# LINEAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/linear/connector/callback
|
||||
|
||||
# -- ClickUp --
|
||||
# CLICKUP_CLIENT_ID=
|
||||
# CLICKUP_CLIENT_SECRET=
|
||||
# CLICKUP_REDIRECT_URI=http://localhost:8000/api/v1/auth/clickup/connector/callback
|
||||
|
||||
# -- Airtable --
|
||||
# AIRTABLE_CLIENT_ID=
|
||||
# AIRTABLE_CLIENT_SECRET=
|
||||
# AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callback
|
||||
|
||||
# -- Microsoft Teams --
|
||||
# TEAMS_CLIENT_ID=
|
||||
# TEAMS_CLIENT_SECRET=
|
||||
# TEAMS_REDIRECT_URI=http://localhost:8000/api/v1/auth/teams/connector/callback
|
||||
|
||||
# -- Composio --
|
||||
# COMPOSIO_API_KEY=
|
||||
# COMPOSIO_ENABLED=TRUE
|
||||
# COMPOSIO_REDIRECT_URI=http://localhost:8000/api/v1/auth/composio/connector/callback
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Daytona Sandbox (optional — cloud code execution for the deep agent)
|
||||
# ------------------------------------------------------------------------------
|
||||
# Set DAYTONA_SANDBOX_ENABLED=TRUE and provide credentials to give the agent
|
||||
# an isolated code execution environment via the Daytona cloud API.
|
||||
# DAYTONA_SANDBOX_ENABLED=FALSE
|
||||
# DAYTONA_API_KEY=
|
||||
# DAYTONA_API_URL=https://app.daytona.io/api
|
||||
# DAYTONA_TARGET=us
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# External API Keys (optional)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Firecrawl (web scraping)
|
||||
# FIRECRAWL_API_KEY=
|
||||
|
||||
# Unstructured (if ETL_SERVICE=UNSTRUCTURED)
|
||||
# UNSTRUCTURED_API_KEY=
|
||||
|
||||
# LlamaCloud (if ETL_SERVICE=LLAMACLOUD)
|
||||
# LLAMA_CLOUD_API_KEY=
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Observability (optional)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# LANGSMITH_TRACING=true
|
||||
# LANGSMITH_ENDPOINT=https://api.smith.langchain.com
|
||||
# LANGSMITH_API_KEY=
|
||||
# LANGSMITH_PROJECT=surfsense
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Advanced (optional)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Periodic connector sync interval (default: 5m)
|
||||
# SCHEDULE_CHECKER_INTERVAL=5m
|
||||
|
||||
# JWT token lifetimes
|
||||
# ACCESS_TOKEN_LIFETIME_SECONDS=86400
|
||||
# REFRESH_TOKEN_LIFETIME_SECONDS=1209600
|
||||
|
||||
# Pages limit per user for ETL (default: unlimited)
|
||||
# PAGES_LIMIT=500
|
||||
|
||||
# Connector indexing lock TTL in seconds (default: 28800 = 8 hours)
|
||||
# CONNECTOR_INDEXING_LOCK_TTL_SECONDS=28800
|
||||
|
||||
# Residential proxy for web crawling
|
||||
# RESIDENTIAL_PROXY_USERNAME=
|
||||
# RESIDENTIAL_PROXY_PASSWORD=
|
||||
# RESIDENTIAL_PROXY_HOSTNAME=
|
||||
# RESIDENTIAL_PROXY_LOCATION=
|
||||
# RESIDENTIAL_PROXY_TYPE=1
|
||||
206
docker/docker-compose.dev.yml
Normal file
206
docker/docker-compose.dev.yml
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
# =============================================================================
|
||||
# SurfSense — Development Docker Compose
|
||||
# =============================================================================
|
||||
# Usage (from repo root):
|
||||
# docker compose -f docker/docker-compose.dev.yml up --build
|
||||
#
|
||||
# This file builds from source and includes dev tools like pgAdmin.
|
||||
# For production with prebuilt images, use docker/docker-compose.yml instead.
|
||||
# =============================================================================
|
||||
|
||||
name: surfsense
|
||||
|
||||
services:
|
||||
db:
|
||||
image: pgvector/pgvector:pg17
|
||||
ports:
|
||||
- "${POSTGRES_PORT:-5432}:5432"
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
- ./postgresql.conf:/etc/postgresql/postgresql.conf:ro
|
||||
- ./scripts/init-electric-user.sh:/docker-entrypoint-initdb.d/init-electric-user.sh:ro
|
||||
environment:
|
||||
- POSTGRES_USER=${DB_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${DB_PASSWORD:-postgres}
|
||||
- POSTGRES_DB=${DB_NAME:-surfsense}
|
||||
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
|
||||
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
command: postgres -c config_file=/etc/postgresql/postgresql.conf
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${DB_USER:-postgres} -d ${DB_NAME:-surfsense}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
pgadmin:
|
||||
image: dpage/pgadmin4
|
||||
ports:
|
||||
- "${PGADMIN_PORT:-5050}:80"
|
||||
environment:
|
||||
- PGADMIN_DEFAULT_EMAIL=${PGADMIN_DEFAULT_EMAIL:-admin@surfsense.com}
|
||||
- PGADMIN_DEFAULT_PASSWORD=${PGADMIN_DEFAULT_PASSWORD:-surfsense}
|
||||
volumes:
|
||||
- pgadmin_data:/var/lib/pgadmin
|
||||
depends_on:
|
||||
- db
|
||||
|
||||
redis:
|
||||
image: redis:8-alpine
|
||||
ports:
|
||||
- "${REDIS_PORT:-6379}:6379"
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
command: redis-server --appendonly yes
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
backend:
|
||||
build: ../surfsense_backend
|
||||
ports:
|
||||
- "${BACKEND_PORT:-8000}:8000"
|
||||
volumes:
|
||||
- ../surfsense_backend/app:/app/app
|
||||
- shared_temp:/shared_tmp
|
||||
env_file:
|
||||
- ../surfsense_backend/.env
|
||||
environment:
|
||||
- DATABASE_URL=${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
|
||||
- CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}
|
||||
- CELERY_RESULT_BACKEND=${REDIS_URL:-redis://redis:6379/0}
|
||||
- REDIS_APP_URL=${REDIS_URL:-redis://redis:6379/0}
|
||||
- CELERY_TASK_DEFAULT_QUEUE=surfsense
|
||||
- PYTHONPATH=/app
|
||||
- UVICORN_LOOP=asyncio
|
||||
- UNSTRUCTURED_HAS_PATCHED_LOOP=1
|
||||
- LANGCHAIN_TRACING_V2=false
|
||||
- LANGSMITH_TRACING=false
|
||||
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
|
||||
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
- AUTH_TYPE=${AUTH_TYPE:-LOCAL}
|
||||
- NEXT_FRONTEND_URL=${NEXT_FRONTEND_URL:-http://localhost:3000}
|
||||
# Daytona Sandbox – uncomment and set credentials to enable cloud code execution
|
||||
# - DAYTONA_SANDBOX_ENABLED=TRUE
|
||||
# - DAYTONA_API_KEY=${DAYTONA_API_KEY:-}
|
||||
# - DAYTONA_API_URL=${DAYTONA_API_URL:-https://app.daytona.io/api}
|
||||
# - DAYTONA_TARGET=${DAYTONA_TARGET:-us}
|
||||
- SERVICE_ROLE=api
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 15s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
start_period: 200s
|
||||
|
||||
celery_worker:
|
||||
build: ../surfsense_backend
|
||||
volumes:
|
||||
- ../surfsense_backend/app:/app/app
|
||||
- shared_temp:/shared_tmp
|
||||
env_file:
|
||||
- ../surfsense_backend/.env
|
||||
environment:
|
||||
- DATABASE_URL=${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
|
||||
- CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}
|
||||
- CELERY_RESULT_BACKEND=${REDIS_URL:-redis://redis:6379/0}
|
||||
- REDIS_APP_URL=${REDIS_URL:-redis://redis:6379/0}
|
||||
- CELERY_TASK_DEFAULT_QUEUE=surfsense
|
||||
- PYTHONPATH=/app
|
||||
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
|
||||
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
- SERVICE_ROLE=worker
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
backend:
|
||||
condition: service_healthy
|
||||
|
||||
celery_beat:
|
||||
build: ../surfsense_backend
|
||||
env_file:
|
||||
- ../surfsense_backend/.env
|
||||
environment:
|
||||
- DATABASE_URL=${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
|
||||
- CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}
|
||||
- CELERY_RESULT_BACKEND=${REDIS_URL:-redis://redis:6379/0}
|
||||
- CELERY_TASK_DEFAULT_QUEUE=surfsense
|
||||
- PYTHONPATH=/app
|
||||
- SERVICE_ROLE=beat
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
celery_worker:
|
||||
condition: service_started
|
||||
|
||||
# flower:
|
||||
# build: ../surfsense_backend
|
||||
# ports:
|
||||
# - "${FLOWER_PORT:-5555}:5555"
|
||||
# env_file:
|
||||
# - ../surfsense_backend/.env
|
||||
# environment:
|
||||
# - CELERY_BROKER_URL=${REDIS_URL:-redis://redis:6379/0}
|
||||
# - CELERY_RESULT_BACKEND=${REDIS_URL:-redis://redis:6379/0}
|
||||
# - PYTHONPATH=/app
|
||||
# command: celery -A app.celery_app flower --port=5555
|
||||
# depends_on:
|
||||
# - redis
|
||||
# - celery_worker
|
||||
|
||||
electric:
|
||||
image: electricsql/electric:1.4.10
|
||||
ports:
|
||||
- "${ELECTRIC_PORT:-5133}:3000"
|
||||
# depends_on:
|
||||
# - db
|
||||
environment:
|
||||
- DATABASE_URL=${ELECTRIC_DATABASE_URL:-postgresql://${ELECTRIC_DB_USER:-electric}:${ELECTRIC_DB_PASSWORD:-electric_password}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
|
||||
- ELECTRIC_INSECURE=true
|
||||
- ELECTRIC_WRITE_TO_PG_MODE=direct
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:3000/v1/health"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
frontend:
|
||||
build:
|
||||
context: ../surfsense_web
|
||||
args:
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000}
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL}
|
||||
NEXT_PUBLIC_ETL_SERVICE: ${NEXT_PUBLIC_ETL_SERVICE:-DOCLING}
|
||||
NEXT_PUBLIC_ELECTRIC_URL: ${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:5133}
|
||||
NEXT_PUBLIC_ELECTRIC_AUTH_MODE: ${NEXT_PUBLIC_ELECTRIC_AUTH_MODE:-insecure}
|
||||
NEXT_PUBLIC_DEPLOYMENT_MODE: ${NEXT_PUBLIC_DEPLOYMENT_MODE:-self-hosted}
|
||||
ports:
|
||||
- "${FRONTEND_PORT:-3000}:3000"
|
||||
env_file:
|
||||
- ../surfsense_web/.env
|
||||
depends_on:
|
||||
backend:
|
||||
condition: service_healthy
|
||||
electric:
|
||||
condition: service_healthy
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
name: surfsense-postgres
|
||||
pgadmin_data:
|
||||
name: surfsense-pgadmin
|
||||
redis_data:
|
||||
name: surfsense-redis
|
||||
shared_temp:
|
||||
name: surfsense-shared-temp
|
||||
195
docker/docker-compose.yml
Normal file
195
docker/docker-compose.yml
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
# =============================================================================
|
||||
# SurfSense — Production Docker Compose
|
||||
# Docs: https://docs.surfsense.com/docs/docker-installation
|
||||
# =============================================================================
|
||||
# Usage:
|
||||
# 1. Copy .env.example to .env and edit the required values
|
||||
# 2. docker compose up -d
|
||||
# =============================================================================
|
||||
|
||||
name: surfsense
|
||||
|
||||
services:
|
||||
db:
|
||||
image: pgvector/pgvector:pg17
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
- ./postgresql.conf:/etc/postgresql/postgresql.conf:ro
|
||||
- ./scripts/init-electric-user.sh:/docker-entrypoint-initdb.d/init-electric-user.sh:ro
|
||||
environment:
|
||||
POSTGRES_USER: ${DB_USER:-surfsense}
|
||||
POSTGRES_PASSWORD: ${DB_PASSWORD:-surfsense}
|
||||
POSTGRES_DB: ${DB_NAME:-surfsense}
|
||||
ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
|
||||
ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
command: postgres -c config_file=/etc/postgresql/postgresql.conf
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${DB_USER:-surfsense} -d ${DB_NAME:-surfsense}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
redis:
|
||||
image: redis:8-alpine
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
command: redis-server --appendonly yes
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
backend:
|
||||
image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}
|
||||
ports:
|
||||
- "${BACKEND_PORT:-8000}:8000"
|
||||
volumes:
|
||||
- shared_temp:/shared_tmp
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
|
||||
CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}
|
||||
CELERY_RESULT_BACKEND: ${REDIS_URL:-redis://redis:6379/0}
|
||||
REDIS_APP_URL: ${REDIS_URL:-redis://redis:6379/0}
|
||||
CELERY_TASK_DEFAULT_QUEUE: surfsense
|
||||
PYTHONPATH: /app
|
||||
UVICORN_LOOP: asyncio
|
||||
UNSTRUCTURED_HAS_PATCHED_LOOP: "1"
|
||||
ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
|
||||
ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
NEXT_FRONTEND_URL: ${NEXT_FRONTEND_URL:-http://localhost:${FRONTEND_PORT:-3000}}
|
||||
# Daytona Sandbox – uncomment and set credentials to enable cloud code execution
|
||||
# DAYTONA_SANDBOX_ENABLED: "TRUE"
|
||||
# DAYTONA_API_KEY: ${DAYTONA_API_KEY:-}
|
||||
# DAYTONA_API_URL: ${DAYTONA_API_URL:-https://app.daytona.io/api}
|
||||
# DAYTONA_TARGET: ${DAYTONA_TARGET:-us}
|
||||
SERVICE_ROLE: api
|
||||
labels:
|
||||
- "com.centurylinklabs.watchtower.enable=true"
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 15s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
start_period: 200s
|
||||
|
||||
celery_worker:
|
||||
image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}
|
||||
volumes:
|
||||
- shared_temp:/shared_tmp
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
|
||||
CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}
|
||||
CELERY_RESULT_BACKEND: ${REDIS_URL:-redis://redis:6379/0}
|
||||
REDIS_APP_URL: ${REDIS_URL:-redis://redis:6379/0}
|
||||
CELERY_TASK_DEFAULT_QUEUE: surfsense
|
||||
PYTHONPATH: /app
|
||||
ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
|
||||
ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
SERVICE_ROLE: worker
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
backend:
|
||||
condition: service_healthy
|
||||
labels:
|
||||
- "com.centurylinklabs.watchtower.enable=true"
|
||||
restart: unless-stopped
|
||||
|
||||
celery_beat:
|
||||
image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}}
|
||||
CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}
|
||||
CELERY_RESULT_BACKEND: ${REDIS_URL:-redis://redis:6379/0}
|
||||
CELERY_TASK_DEFAULT_QUEUE: surfsense
|
||||
PYTHONPATH: /app
|
||||
SERVICE_ROLE: beat
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
celery_worker:
|
||||
condition: service_started
|
||||
labels:
|
||||
- "com.centurylinklabs.watchtower.enable=true"
|
||||
restart: unless-stopped
|
||||
|
||||
# flower:
|
||||
# image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}
|
||||
# ports:
|
||||
# - "${FLOWER_PORT:-5555}:5555"
|
||||
# env_file:
|
||||
# - .env
|
||||
# environment:
|
||||
# CELERY_BROKER_URL: ${REDIS_URL:-redis://redis:6379/0}
|
||||
# CELERY_RESULT_BACKEND: ${REDIS_URL:-redis://redis:6379/0}
|
||||
# PYTHONPATH: /app
|
||||
# command: celery -A app.celery_app flower --port=5555
|
||||
# depends_on:
|
||||
# - redis
|
||||
# - celery_worker
|
||||
# restart: unless-stopped
|
||||
|
||||
electric:
|
||||
image: electricsql/electric:1.4.10
|
||||
ports:
|
||||
- "${ELECTRIC_PORT:-5133}:3000"
|
||||
environment:
|
||||
DATABASE_URL: ${ELECTRIC_DATABASE_URL:-postgresql://${ELECTRIC_DB_USER:-electric}:${ELECTRIC_DB_PASSWORD:-electric_password}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
|
||||
ELECTRIC_INSECURE: "true"
|
||||
ELECTRIC_WRITE_TO_PG_MODE: direct
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:3000/v1/health"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
frontend:
|
||||
image: ghcr.io/modsetter/surfsense-web:${SURFSENSE_VERSION:-latest}
|
||||
ports:
|
||||
- "${FRONTEND_PORT:-3000}:3000"
|
||||
environment:
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:${BACKEND_PORT:-8000}}
|
||||
NEXT_PUBLIC_ELECTRIC_URL: ${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:${ELECTRIC_PORT:-5133}}
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${AUTH_TYPE:-LOCAL}
|
||||
NEXT_PUBLIC_ETL_SERVICE: ${ETL_SERVICE:-DOCLING}
|
||||
NEXT_PUBLIC_DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted}
|
||||
NEXT_PUBLIC_ELECTRIC_AUTH_MODE: ${NEXT_PUBLIC_ELECTRIC_AUTH_MODE:-insecure}
|
||||
labels:
|
||||
- "com.centurylinklabs.watchtower.enable=true"
|
||||
depends_on:
|
||||
backend:
|
||||
condition: service_healthy
|
||||
electric:
|
||||
condition: service_healthy
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
name: surfsense-postgres
|
||||
redis_data:
|
||||
name: surfsense-redis
|
||||
shared_temp:
|
||||
name: surfsense-shared-temp
|
||||
|
|
@ -1,26 +1,9 @@
|
|||
#!/bin/sh
|
||||
# ============================================================================
|
||||
# Electric SQL User Initialization Script (docker-compose only)
|
||||
# ============================================================================
|
||||
# This script is ONLY used when running via docker-compose.
|
||||
#
|
||||
# How it works:
|
||||
# - docker-compose.yml mounts this script into the PostgreSQL container's
|
||||
# /docker-entrypoint-initdb.d/ directory
|
||||
# - PostgreSQL automatically executes scripts in that directory on first
|
||||
# container initialization
|
||||
#
|
||||
# For local PostgreSQL users (non-Docker), this script is NOT used.
|
||||
# Instead, the Electric user is created by Alembic migration 66
|
||||
# (66_add_notifications_table_and_electric_replication.py).
|
||||
#
|
||||
# Both approaches are idempotent (use IF NOT EXISTS), so running both
|
||||
# will not cause conflicts.
|
||||
# ============================================================================
|
||||
# Creates the Electric SQL replication user on first DB initialization.
|
||||
# Idempotent — safe to run alongside Alembic migration 66.
|
||||
|
||||
set -e
|
||||
|
||||
# Use environment variables with defaults
|
||||
ELECTRIC_DB_USER="${ELECTRIC_DB_USER:-electric}"
|
||||
ELECTRIC_DB_PASSWORD="${ELECTRIC_DB_PASSWORD:-electric_password}"
|
||||
|
||||
|
|
@ -43,7 +26,6 @@ psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-E
|
|||
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO $ELECTRIC_DB_USER;
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO $ELECTRIC_DB_USER;
|
||||
|
||||
-- Create the publication for Electric SQL (if not exists)
|
||||
DO \$\$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT FROM pg_publication WHERE pubname = 'electric_publication_default') THEN
|
||||
350
docker/scripts/install.ps1
Normal file
350
docker/scripts/install.ps1
Normal file
|
|
@ -0,0 +1,350 @@
|
|||
# =============================================================================
|
||||
# SurfSense — One-line Install Script (Windows / PowerShell)
|
||||
#
|
||||
#
|
||||
# Usage: irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.ps1 | iex
|
||||
#
|
||||
# To pass flags, save and run locally:
|
||||
# .\install.ps1 -NoWatchtower
|
||||
# .\install.ps1 -WatchtowerInterval 3600
|
||||
#
|
||||
# Handles two cases automatically:
|
||||
# 1. Fresh install — no prior SurfSense data detected
|
||||
# 2. Migration from the legacy all-in-one container (surfsense-data volume)
|
||||
# Downloads and runs migrate-database.sh --yes, then restores the dump
|
||||
# into the new PostgreSQL 17 stack. The user runs one command for both.
|
||||
# =============================================================================
|
||||
|
||||
param(
|
||||
[switch]$NoWatchtower,
|
||||
[int]$WatchtowerInterval = 86400
|
||||
)
|
||||
|
||||
$ErrorActionPreference = 'Stop'
|
||||
|
||||
# ── Configuration ───────────────────────────────────────────────────────────
|
||||
|
||||
$RepoRaw = "https://raw.githubusercontent.com/MODSetter/SurfSense/main"
|
||||
$InstallDir = ".\surfsense"
|
||||
$OldVolume = "surfsense-data"
|
||||
$DumpFile = ".\surfsense_migration_backup.sql"
|
||||
$KeyFile = ".\surfsense_migration_secret.key"
|
||||
$MigrationDoneFile = "$InstallDir\.migration_done"
|
||||
$MigrationMode = $false
|
||||
$SetupWatchtower = -not $NoWatchtower
|
||||
$WatchtowerContainer = "watchtower"
|
||||
|
||||
# ── Output helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
function Write-Info { param([string]$Msg) Write-Host "[SurfSense] " -ForegroundColor Cyan -NoNewline; Write-Host $Msg }
|
||||
function Write-Ok { param([string]$Msg) Write-Host "[SurfSense] " -ForegroundColor Green -NoNewline; Write-Host $Msg }
|
||||
function Write-Warn { param([string]$Msg) Write-Host "[SurfSense] " -ForegroundColor Yellow -NoNewline; Write-Host $Msg }
|
||||
function Write-Step { param([string]$Msg) Write-Host "`n-- $Msg" -ForegroundColor Cyan }
|
||||
function Write-Err { param([string]$Msg) Write-Host "[SurfSense] ERROR: $Msg" -ForegroundColor Red; exit 1 }
|
||||
|
||||
function Invoke-NativeSafe {
|
||||
param([scriptblock]$Command)
|
||||
$previousErrorActionPreference = $ErrorActionPreference
|
||||
try {
|
||||
$ErrorActionPreference = 'Continue'
|
||||
& $Command
|
||||
} finally {
|
||||
$ErrorActionPreference = $previousErrorActionPreference
|
||||
}
|
||||
}
|
||||
|
||||
# ── Pre-flight checks ──────────────────────────────────────────────────────
|
||||
|
||||
Write-Step "Checking prerequisites"
|
||||
|
||||
if (-not (Get-Command docker -ErrorAction SilentlyContinue)) {
|
||||
Write-Err "Docker is not installed. Install Docker Desktop: https://docs.docker.com/desktop/install/windows-install/"
|
||||
}
|
||||
Write-Ok "Docker found."
|
||||
|
||||
Invoke-NativeSafe { docker info *>$null } | Out-Null
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Err "Docker daemon is not running. Please start Docker Desktop and try again."
|
||||
}
|
||||
Write-Ok "Docker daemon is running."
|
||||
|
||||
Invoke-NativeSafe { docker compose version *>$null } | Out-Null
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Err "Docker Compose is not available. It should be bundled with Docker Desktop."
|
||||
}
|
||||
Write-Ok "Docker Compose found."
|
||||
|
||||
# ── Wait-for-postgres helper ────────────────────────────────────────────────
|
||||
|
||||
function Wait-ForPostgres {
|
||||
param([string]$DbUser)
|
||||
$maxAttempts = 45
|
||||
$attempt = 0
|
||||
|
||||
Write-Info "Waiting for PostgreSQL to accept connections..."
|
||||
do {
|
||||
$attempt++
|
||||
if ($attempt -ge $maxAttempts) {
|
||||
Write-Err "PostgreSQL did not become ready after $($maxAttempts * 2) seconds.`nCheck logs: cd $InstallDir; docker compose logs db"
|
||||
}
|
||||
Start-Sleep -Seconds 2
|
||||
Push-Location $InstallDir
|
||||
Invoke-NativeSafe { docker compose exec -T db pg_isready -U $DbUser -q *>$null } | Out-Null
|
||||
$ready = $LASTEXITCODE -eq 0
|
||||
Pop-Location
|
||||
} while (-not $ready)
|
||||
|
||||
Write-Ok "PostgreSQL is ready."
|
||||
}
|
||||
|
||||
# ── Download files ──────────────────────────────────────────────────────────
|
||||
|
||||
Write-Step "Downloading SurfSense files"
|
||||
Write-Info "Installation directory: $InstallDir"
|
||||
|
||||
New-Item -ItemType Directory -Path "$InstallDir\scripts" -Force | Out-Null
|
||||
|
||||
$Files = @(
|
||||
@{ Src = "docker/docker-compose.yml"; Dest = "docker-compose.yml" }
|
||||
@{ Src = "docker/.env.example"; Dest = ".env.example" }
|
||||
@{ Src = "docker/postgresql.conf"; Dest = "postgresql.conf" }
|
||||
@{ Src = "docker/scripts/init-electric-user.sh"; Dest = "scripts/init-electric-user.sh" }
|
||||
@{ Src = "docker/scripts/migrate-database.ps1"; Dest = "scripts/migrate-database.ps1" }
|
||||
)
|
||||
|
||||
foreach ($f in $Files) {
|
||||
$destPath = Join-Path $InstallDir $f.Dest
|
||||
Write-Info "Downloading $($f.Dest)..."
|
||||
try {
|
||||
Invoke-WebRequest -Uri "$RepoRaw/$($f.Src)" -OutFile $destPath -UseBasicParsing
|
||||
} catch {
|
||||
Write-Err "Failed to download $($f.Dest). Check your internet connection and try again."
|
||||
}
|
||||
}
|
||||
|
||||
Write-Ok "All files downloaded to $InstallDir/"
|
||||
|
||||
# ── Legacy all-in-one detection ─────────────────────────────────────────────
|
||||
|
||||
$volumeList = Invoke-NativeSafe { docker volume ls --format '{{.Name}}' 2>$null }
|
||||
if (($volumeList -split "`n") -contains $OldVolume -and -not (Test-Path $MigrationDoneFile)) {
|
||||
$MigrationMode = $true
|
||||
|
||||
if (Test-Path $DumpFile) {
|
||||
Write-Step "Migration mode - using existing dump (skipping extraction)"
|
||||
Write-Info "Found existing dump: $DumpFile"
|
||||
Write-Info "Skipping data extraction - proceeding directly to restore."
|
||||
Write-Info "To force a fresh extraction, remove the dump first: Remove-Item $DumpFile"
|
||||
} else {
|
||||
Write-Step "Migration mode - legacy all-in-one container detected"
|
||||
Write-Warn "Volume '$OldVolume' found. Your data will be migrated automatically."
|
||||
Write-Warn "PostgreSQL is being upgraded from version 14 to 17."
|
||||
Write-Warn "Your original data will NOT be deleted."
|
||||
Write-Host ""
|
||||
Write-Info "Running data extraction (migrate-database.ps1 -Yes)..."
|
||||
Write-Info "Full extraction log: ./surfsense-migration.log"
|
||||
Write-Host ""
|
||||
|
||||
$migrateScript = Join-Path $InstallDir "scripts/migrate-database.ps1"
|
||||
& $migrateScript -Yes
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Err "Data extraction failed. See ./surfsense-migration.log for details.`nYou can also run migrate-database.ps1 manually with custom flags."
|
||||
}
|
||||
|
||||
Write-Host ""
|
||||
Write-Ok "Data extraction complete. Proceeding with installation and restore."
|
||||
}
|
||||
}
|
||||
|
||||
# ── Set up .env ─────────────────────────────────────────────────────────────
|
||||
|
||||
Write-Step "Configuring environment"
|
||||
|
||||
$envPath = Join-Path $InstallDir ".env"
|
||||
$envExamplePath = Join-Path $InstallDir ".env.example"
|
||||
|
||||
if (-not (Test-Path $envPath)) {
|
||||
Copy-Item $envExamplePath $envPath
|
||||
|
||||
if ($MigrationMode -and (Test-Path $KeyFile)) {
|
||||
$SecretKey = (Get-Content $KeyFile -Raw).Trim()
|
||||
Write-Ok "Using SECRET_KEY recovered from legacy container."
|
||||
} else {
|
||||
$bytes = New-Object byte[] 32
|
||||
$rng = [System.Security.Cryptography.RNGCryptoServiceProvider]::new()
|
||||
$rng.GetBytes($bytes)
|
||||
$rng.Dispose()
|
||||
$SecretKey = [Convert]::ToBase64String($bytes)
|
||||
Write-Ok "Generated new random SECRET_KEY."
|
||||
}
|
||||
|
||||
$content = Get-Content $envPath -Raw
|
||||
$content = $content -replace 'SECRET_KEY=replace_me_with_a_random_string', "SECRET_KEY=$SecretKey"
|
||||
Set-Content -Path $envPath -Value $content -NoNewline
|
||||
|
||||
Write-Info "Created $envPath"
|
||||
} else {
|
||||
Write-Warn ".env already exists - keeping your existing configuration."
|
||||
}
|
||||
|
||||
# ── Start containers ────────────────────────────────────────────────────────
|
||||
|
||||
if ($MigrationMode) {
|
||||
$envContent = Get-Content $envPath
|
||||
$DbUser = ($envContent | Select-String '^DB_USER=' | ForEach-Object { ($_ -split '=',2)[1].Trim('"') }) | Select-Object -First 1
|
||||
$DbPass = ($envContent | Select-String '^DB_PASSWORD=' | ForEach-Object { ($_ -split '=',2)[1].Trim('"') }) | Select-Object -First 1
|
||||
$DbName = ($envContent | Select-String '^DB_NAME=' | ForEach-Object { ($_ -split '=',2)[1].Trim('"') }) | Select-Object -First 1
|
||||
if (-not $DbUser) { $DbUser = "surfsense" }
|
||||
if (-not $DbPass) { $DbPass = "surfsense" }
|
||||
if (-not $DbName) { $DbName = "surfsense" }
|
||||
|
||||
Write-Step "Starting PostgreSQL 17"
|
||||
Push-Location $InstallDir
|
||||
Invoke-NativeSafe { docker compose up -d db } | Out-Null
|
||||
Pop-Location
|
||||
Wait-ForPostgres -DbUser $DbUser
|
||||
|
||||
Write-Step "Restoring database"
|
||||
if (-not (Test-Path $DumpFile)) {
|
||||
Write-Err "Dump file '$DumpFile' not found. The migration script may have failed."
|
||||
}
|
||||
$DumpFilePath = (Resolve-Path $DumpFile).Path
|
||||
Write-Info "Restoring dump into PostgreSQL 17 - this may take a while for large databases..."
|
||||
|
||||
$restoreErrFile = Join-Path $env:TEMP "surfsense_restore_err.log"
|
||||
Push-Location $InstallDir
|
||||
Invoke-NativeSafe { Get-Content -LiteralPath $DumpFilePath | docker compose exec -T -e "PGPASSWORD=$DbPass" db psql -U $DbUser -d $DbName 2>$restoreErrFile | Out-Null } | Out-Null
|
||||
Pop-Location
|
||||
|
||||
$fatalErrors = @()
|
||||
if (Test-Path $restoreErrFile) {
|
||||
$fatalErrors = Get-Content $restoreErrFile |
|
||||
Where-Object { $_ -match '^ERROR:' } |
|
||||
Where-Object { $_ -notmatch 'already exists' } |
|
||||
Where-Object { $_ -notmatch 'multiple primary keys' }
|
||||
}
|
||||
|
||||
if ($fatalErrors.Count -gt 0) {
|
||||
Write-Warn "Restore completed with errors (may be harmless pg_dump header noise):"
|
||||
$fatalErrors | ForEach-Object { Write-Host $_ }
|
||||
Write-Warn "If SurfSense behaves incorrectly, inspect manually."
|
||||
} else {
|
||||
Write-Ok "Database restored with no fatal errors."
|
||||
}
|
||||
|
||||
# Smoke test
|
||||
Push-Location $InstallDir
|
||||
$tableCount = (Invoke-NativeSafe { docker compose exec -T -e "PGPASSWORD=$DbPass" db psql -U $DbUser -d $DbName -t -c "SELECT count(*) FROM information_schema.tables WHERE table_schema = 'public';" 2>$null }).Trim()
|
||||
Pop-Location
|
||||
|
||||
if (-not $tableCount -or $tableCount -eq "0") {
|
||||
Write-Warn "Smoke test: no tables found after restore."
|
||||
Write-Warn "The restore may have failed silently. Check: cd $InstallDir; docker compose logs db"
|
||||
} else {
|
||||
Write-Ok "Smoke test passed: $tableCount table(s) restored successfully."
|
||||
New-Item -Path $MigrationDoneFile -ItemType File -Force | Out-Null
|
||||
}
|
||||
|
||||
Write-Step "Starting all SurfSense services"
|
||||
Push-Location $InstallDir
|
||||
Invoke-NativeSafe { docker compose up -d }
|
||||
Pop-Location
|
||||
Write-Ok "All services started."
|
||||
|
||||
Remove-Item $KeyFile -ErrorAction SilentlyContinue
|
||||
|
||||
} else {
|
||||
Write-Step "Starting SurfSense"
|
||||
Push-Location $InstallDir
|
||||
Invoke-NativeSafe { docker compose up -d }
|
||||
Pop-Location
|
||||
Write-Ok "All services started."
|
||||
}
|
||||
|
||||
# ── Watchtower (auto-update) ────────────────────────────────────────────────
|
||||
|
||||
if ($SetupWatchtower) {
|
||||
$wtHours = [math]::Floor($WatchtowerInterval / 3600)
|
||||
Write-Step "Setting up Watchtower (auto-updates every ${wtHours}h)"
|
||||
|
||||
$wtState = Invoke-NativeSafe { docker inspect -f '{{.State.Running}}' $WatchtowerContainer 2>$null }
|
||||
if ($LASTEXITCODE -ne 0) { $wtState = "missing" }
|
||||
|
||||
if ($wtState -eq "true") {
|
||||
Write-Ok "Watchtower is already running - skipping."
|
||||
} else {
|
||||
if ($wtState -ne "missing") {
|
||||
Write-Info "Removing stopped Watchtower container..."
|
||||
Invoke-NativeSafe { docker rm -f $WatchtowerContainer *>$null } | Out-Null
|
||||
}
|
||||
Invoke-NativeSafe {
|
||||
docker run -d `
|
||||
--name $WatchtowerContainer `
|
||||
--restart unless-stopped `
|
||||
-v /var/run/docker.sock:/var/run/docker.sock `
|
||||
nickfedor/watchtower `
|
||||
--label-enable `
|
||||
--interval $WatchtowerInterval *>$null
|
||||
} | Out-Null
|
||||
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Ok "Watchtower started - labeled SurfSense containers will auto-update."
|
||||
} else {
|
||||
Write-Warn "Could not start Watchtower. You can set it up manually or use: docker compose pull; docker compose up -d"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Write-Info "Skipping Watchtower setup (-NoWatchtower flag)."
|
||||
}
|
||||
|
||||
# ── Done ────────────────────────────────────────────────────────────────────
|
||||
|
||||
Write-Host ""
|
||||
Write-Host @"
|
||||
|
||||
|
||||
.d8888b. .d888 .d8888b.
|
||||
d88P Y88b d88P" d88P Y88b
|
||||
Y88b. 888 Y88b.
|
||||
"Y888b. 888 888 888d888 888888 "Y888b. .d88b. 88888b. .d8888b .d88b.
|
||||
"Y88b. 888 888 888P" 888 "Y88b. d8P Y8b 888 "88b 88K d8P Y8b
|
||||
"888 888 888 888 888 "888 88888888 888 888 "Y8888b. 88888888
|
||||
Y88b d88P Y88b 888 888 888 Y88b d88P Y8b. 888 888 X88 Y8b.
|
||||
"Y8888P" "Y88888 888 888 "Y8888P" "Y8888 888 888 88888P' "Y8888
|
||||
|
||||
|
||||
"@ -ForegroundColor White
|
||||
|
||||
$versionDisplay = (Get-Content $envPath | Select-String '^SURFSENSE_VERSION=' | ForEach-Object { ($_ -split '=',2)[1].Trim('"') }) | Select-Object -First 1
|
||||
if (-not $versionDisplay) { $versionDisplay = "latest" }
|
||||
Write-Host " OSS Alternative to NotebookLM for Teams [$versionDisplay]" -ForegroundColor Yellow
|
||||
Write-Host ("=" * 62) -ForegroundColor Cyan
|
||||
Write-Host ""
|
||||
|
||||
Write-Info " Frontend: http://localhost:3000"
|
||||
Write-Info " Backend: http://localhost:8000"
|
||||
Write-Info " API Docs: http://localhost:8000/docs"
|
||||
Write-Info ""
|
||||
Write-Info " Config: $InstallDir\.env"
|
||||
Write-Info " Logs: cd $InstallDir; docker compose logs -f"
|
||||
Write-Info " Stop: cd $InstallDir; docker compose down"
|
||||
Write-Info " Update: cd $InstallDir; docker compose pull; docker compose up -d"
|
||||
Write-Info ""
|
||||
|
||||
if ($SetupWatchtower) {
|
||||
Write-Info " Watchtower: auto-updates every ${wtHours}h (stop: docker rm -f $WatchtowerContainer)"
|
||||
} else {
|
||||
Write-Warn " Watchtower skipped. For auto-updates, re-run without -NoWatchtower."
|
||||
}
|
||||
Write-Info ""
|
||||
|
||||
if ($MigrationMode) {
|
||||
Write-Warn " Migration complete! Open frontend and verify your data."
|
||||
Write-Warn " Once verified, clean up the legacy volume and migration files:"
|
||||
Write-Warn " docker volume rm $OldVolume"
|
||||
Write-Warn " Remove-Item $DumpFile"
|
||||
Write-Warn " Remove-Item $MigrationDoneFile"
|
||||
} else {
|
||||
Write-Warn " First startup may take a few minutes while images are pulled."
|
||||
Write-Warn " Edit $InstallDir\.env to configure API keys, OAuth, etc."
|
||||
}
|
||||
337
docker/scripts/install.sh
Normal file
337
docker/scripts/install.sh
Normal file
|
|
@ -0,0 +1,337 @@
|
|||
#!/usr/bin/env bash
|
||||
# =============================================================================
|
||||
# SurfSense — One-line Install Script
|
||||
#
|
||||
#
|
||||
# Usage: curl -fsSL https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.sh | bash
|
||||
#
|
||||
# Flags:
|
||||
# --no-watchtower Skip automatic Watchtower setup
|
||||
# --watchtower-interval=SECS Check interval in seconds (default: 86400 = 24h)
|
||||
#
|
||||
# Handles two cases automatically:
|
||||
# 1. Fresh install — no prior SurfSense data detected
|
||||
# 2. Migration from the legacy all-in-one container (surfsense-data volume)
|
||||
# Downloads and runs migrate-database.sh --yes, then restores the dump
|
||||
# into the new PostgreSQL 17 stack. The user runs one command for both.
|
||||
#
|
||||
# If you used custom database credentials in the old all-in-one container, run
|
||||
# migrate-database.sh manually first (with --db-user / --db-password flags),
|
||||
# then re-run this script:
|
||||
# curl -fsSL .../docker/scripts/migrate-database.sh | bash -s -- --db-user X --db-password Y
|
||||
# =============================================================================
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
main() {
|
||||
|
||||
REPO_RAW="https://raw.githubusercontent.com/MODSetter/SurfSense/main"
|
||||
INSTALL_DIR="./surfsense"
|
||||
OLD_VOLUME="surfsense-data"
|
||||
DUMP_FILE="./surfsense_migration_backup.sql"
|
||||
KEY_FILE="./surfsense_migration_secret.key"
|
||||
MIGRATION_DONE_FILE="${INSTALL_DIR}/.migration_done"
|
||||
MIGRATION_MODE=false
|
||||
SETUP_WATCHTOWER=true
|
||||
WATCHTOWER_INTERVAL=86400
|
||||
WATCHTOWER_CONTAINER="watchtower"
|
||||
|
||||
# ── Parse flags ─────────────────────────────────────────────────────────────
|
||||
for arg in "$@"; do
|
||||
case "$arg" in
|
||||
--no-watchtower) SETUP_WATCHTOWER=false ;;
|
||||
--watchtower-interval=*) WATCHTOWER_INTERVAL="${arg#*=}" ;;
|
||||
esac
|
||||
done
|
||||
|
||||
CYAN='\033[1;36m'
|
||||
YELLOW='\033[1;33m'
|
||||
GREEN='\033[0;32m'
|
||||
RED='\033[0;31m'
|
||||
BOLD='\033[1m'
|
||||
NC='\033[0m'
|
||||
|
||||
info() { printf "${CYAN}[SurfSense]${NC} %s\n" "$1"; }
|
||||
success() { printf "${GREEN}[SurfSense]${NC} %s\n" "$1"; }
|
||||
warn() { printf "${YELLOW}[SurfSense]${NC} %s\n" "$1"; }
|
||||
error() { printf "${RED}[SurfSense]${NC} ERROR: %s\n" "$1" >&2; exit 1; }
|
||||
step() { printf "\n${BOLD}${CYAN}── %s${NC}\n" "$1"; }
|
||||
|
||||
# ── Pre-flight checks ────────────────────────────────────────────────────────
|
||||
|
||||
step "Checking prerequisites"
|
||||
|
||||
command -v docker >/dev/null 2>&1 \
|
||||
|| error "Docker is not installed. Install it at: https://docs.docker.com/get-docker/"
|
||||
success "Docker found."
|
||||
|
||||
docker info >/dev/null 2>&1 < /dev/null \
|
||||
|| error "Docker daemon is not running. Please start Docker and try again."
|
||||
success "Docker daemon is running."
|
||||
|
||||
if docker compose version >/dev/null 2>&1 < /dev/null; then
|
||||
DC="docker compose"
|
||||
elif command -v docker-compose >/dev/null 2>&1; then
|
||||
DC="docker-compose"
|
||||
else
|
||||
error "Docker Compose is not installed. Install it at: https://docs.docker.com/compose/install/"
|
||||
fi
|
||||
success "Docker Compose found ($DC)."
|
||||
|
||||
# ── Wait-for-postgres helper ─────────────────────────────────────────────────
|
||||
wait_for_pg() {
|
||||
local db_user="$1"
|
||||
local max_attempts=45
|
||||
local attempt=0
|
||||
|
||||
info "Waiting for PostgreSQL to accept connections..."
|
||||
until (cd "${INSTALL_DIR}" && ${DC} exec -T db pg_isready -U "${db_user}" -q 2>/dev/null) < /dev/null; do
|
||||
attempt=$((attempt + 1))
|
||||
if [[ $attempt -ge $max_attempts ]]; then
|
||||
error "PostgreSQL did not become ready after $((max_attempts * 2)) seconds.\nCheck logs: cd ${INSTALL_DIR} && ${DC} logs db"
|
||||
fi
|
||||
printf "."
|
||||
sleep 2
|
||||
done
|
||||
printf "\n"
|
||||
success "PostgreSQL is ready."
|
||||
}
|
||||
|
||||
# ── Download files ───────────────────────────────────────────────────────────
|
||||
|
||||
step "Downloading SurfSense files"
|
||||
info "Installation directory: ${INSTALL_DIR}"
|
||||
mkdir -p "${INSTALL_DIR}/scripts"
|
||||
|
||||
FILES=(
|
||||
"docker/docker-compose.yml:docker-compose.yml"
|
||||
"docker/.env.example:.env.example"
|
||||
"docker/postgresql.conf:postgresql.conf"
|
||||
"docker/scripts/init-electric-user.sh:scripts/init-electric-user.sh"
|
||||
"docker/scripts/migrate-database.sh:scripts/migrate-database.sh"
|
||||
)
|
||||
|
||||
for entry in "${FILES[@]}"; do
|
||||
src="${entry%%:*}"
|
||||
dest="${entry##*:}"
|
||||
info "Downloading ${dest}..."
|
||||
curl -fsSL "${REPO_RAW}/${src}" -o "${INSTALL_DIR}/${dest}" \
|
||||
|| error "Failed to download ${dest}. Check your internet connection and try again."
|
||||
done
|
||||
|
||||
chmod +x "${INSTALL_DIR}/scripts/init-electric-user.sh"
|
||||
chmod +x "${INSTALL_DIR}/scripts/migrate-database.sh"
|
||||
success "All files downloaded to ${INSTALL_DIR}/"
|
||||
|
||||
# ── Legacy all-in-one detection ──────────────────────────────────────────────
|
||||
# Detect surfsense-data volume → migration mode.
|
||||
# If a dump already exists (from a previous partial run) skip extraction and
|
||||
# go straight to restore — this makes re-runs safe and idempotent.
|
||||
|
||||
if docker volume ls --format '{{.Name}}' 2>/dev/null < /dev/null | grep -q "^${OLD_VOLUME}$" \
|
||||
&& [[ ! -f "${MIGRATION_DONE_FILE}" ]]; then
|
||||
MIGRATION_MODE=true
|
||||
|
||||
if [[ -f "${DUMP_FILE}" ]]; then
|
||||
step "Migration mode — using existing dump (skipping extraction)"
|
||||
info "Found existing dump: ${DUMP_FILE}"
|
||||
info "Skipping data extraction — proceeding directly to restore."
|
||||
info "To force a fresh extraction, remove the dump first: rm ${DUMP_FILE}"
|
||||
else
|
||||
step "Migration mode — legacy all-in-one container detected"
|
||||
warn "Volume '${OLD_VOLUME}' found. Your data will be migrated automatically."
|
||||
warn "PostgreSQL is being upgraded from version 14 to 17."
|
||||
warn "Your original data will NOT be deleted."
|
||||
printf "\n"
|
||||
info "Running data extraction (migrate-database.sh --yes)..."
|
||||
info "Full extraction log: ./surfsense-migration.log"
|
||||
printf "\n"
|
||||
|
||||
# Run extraction non-interactively. On failure the error from
|
||||
# migrate-database.sh is printed and install.sh exits here.
|
||||
bash "${INSTALL_DIR}/scripts/migrate-database.sh" --yes < /dev/null \
|
||||
|| error "Data extraction failed. See ./surfsense-migration.log for details.\nYou can also run migrate-database.sh manually with custom flags:\n bash ${INSTALL_DIR}/scripts/migrate-database.sh --db-user X --db-password Y"
|
||||
|
||||
printf "\n"
|
||||
success "Data extraction complete. Proceeding with installation and restore."
|
||||
fi
|
||||
fi
|
||||
|
||||
# ── Set up .env ──────────────────────────────────────────────────────────────
|
||||
|
||||
step "Configuring environment"
|
||||
|
||||
if [ ! -f "${INSTALL_DIR}/.env" ]; then
|
||||
cp "${INSTALL_DIR}/.env.example" "${INSTALL_DIR}/.env"
|
||||
|
||||
if $MIGRATION_MODE && [[ -f "${KEY_FILE}" ]]; then
|
||||
SECRET_KEY=$(cat "${KEY_FILE}" | tr -d '[:space:]')
|
||||
success "Using SECRET_KEY recovered from legacy container."
|
||||
else
|
||||
SECRET_KEY=$(openssl rand -base64 32 2>/dev/null \
|
||||
|| head -c 32 /dev/urandom | base64 | tr -d '\n')
|
||||
success "Generated new random SECRET_KEY."
|
||||
fi
|
||||
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
sed -i '' "s|SECRET_KEY=replace_me_with_a_random_string|SECRET_KEY=${SECRET_KEY}|" "${INSTALL_DIR}/.env"
|
||||
else
|
||||
sed -i "s|SECRET_KEY=replace_me_with_a_random_string|SECRET_KEY=${SECRET_KEY}|" "${INSTALL_DIR}/.env"
|
||||
fi
|
||||
info "Created ${INSTALL_DIR}/.env"
|
||||
else
|
||||
warn ".env already exists — keeping your existing configuration."
|
||||
fi
|
||||
|
||||
# ── Start containers ─────────────────────────────────────────────────────────
|
||||
|
||||
if $MIGRATION_MODE; then
|
||||
# Read DB credentials from .env (fall back to defaults from docker-compose.yml)
|
||||
DB_USER=$(grep '^DB_USER=' "${INSTALL_DIR}/.env" 2>/dev/null | cut -d= -f2 | tr -d '"' | head -1 || true)
|
||||
DB_PASS=$(grep '^DB_PASSWORD=' "${INSTALL_DIR}/.env" 2>/dev/null | cut -d= -f2 | tr -d '"' | head -1 || true)
|
||||
DB_NAME=$(grep '^DB_NAME=' "${INSTALL_DIR}/.env" 2>/dev/null | cut -d= -f2 | tr -d '"' | head -1 || true)
|
||||
DB_USER="${DB_USER:-surfsense}"
|
||||
DB_PASS="${DB_PASS:-surfsense}"
|
||||
DB_NAME="${DB_NAME:-surfsense}"
|
||||
|
||||
step "Starting PostgreSQL 17"
|
||||
(cd "${INSTALL_DIR}" && ${DC} up -d db) < /dev/null
|
||||
wait_for_pg "${DB_USER}"
|
||||
|
||||
step "Restoring database"
|
||||
[[ -f "${DUMP_FILE}" ]] \
|
||||
|| error "Dump file '${DUMP_FILE}' not found. The migration script may have failed.\n Check: ./surfsense-migration.log\n Or run manually: bash ${INSTALL_DIR}/scripts/migrate-database.sh --yes"
|
||||
info "Restoring dump into PostgreSQL 17 — this may take a while for large databases..."
|
||||
|
||||
RESTORE_ERR="/tmp/surfsense_restore_err.log"
|
||||
(cd "${INSTALL_DIR}" && ${DC} exec -T \
|
||||
-e PGPASSWORD="${DB_PASS}" \
|
||||
db psql -U "${DB_USER}" -d "${DB_NAME}" \
|
||||
>/dev/null 2>"${RESTORE_ERR}") < "${DUMP_FILE}" || true
|
||||
|
||||
# Surface real errors; ignore benign "already exists" noise from pg_dump headers
|
||||
FATAL_ERRORS=$(grep -i "^ERROR:" "${RESTORE_ERR}" \
|
||||
| grep -iv "already exists" \
|
||||
| grep -iv "multiple primary keys" \
|
||||
|| true)
|
||||
|
||||
if [[ -n "${FATAL_ERRORS}" ]]; then
|
||||
warn "Restore completed with errors (may be harmless pg_dump header noise):"
|
||||
printf "%s\n" "${FATAL_ERRORS}"
|
||||
warn "If SurfSense behaves incorrectly, inspect manually:"
|
||||
warn " cd ${INSTALL_DIR} && ${DC} exec db psql -U ${DB_USER} -d ${DB_NAME} < ${DUMP_FILE}"
|
||||
else
|
||||
success "Database restored with no fatal errors."
|
||||
fi
|
||||
|
||||
# Smoke test — verify tables are present
|
||||
TABLE_COUNT=$(
|
||||
cd "${INSTALL_DIR}" && ${DC} exec -T \
|
||||
-e PGPASSWORD="${DB_PASS}" \
|
||||
db psql -U "${DB_USER}" -d "${DB_NAME}" -t \
|
||||
-c "SELECT count(*) FROM information_schema.tables WHERE table_schema = 'public';" \
|
||||
2>/dev/null < /dev/null | tr -d ' \n' || echo "0"
|
||||
)
|
||||
if [[ "${TABLE_COUNT}" == "0" || -z "${TABLE_COUNT}" ]]; then
|
||||
warn "Smoke test: no tables found after restore."
|
||||
warn "The restore may have failed silently. Check: cd ${INSTALL_DIR} && ${DC} logs db"
|
||||
else
|
||||
success "Smoke test passed: ${TABLE_COUNT} table(s) restored successfully."
|
||||
touch "${MIGRATION_DONE_FILE}"
|
||||
fi
|
||||
|
||||
step "Starting all SurfSense services"
|
||||
(cd "${INSTALL_DIR}" && ${DC} up -d) < /dev/null
|
||||
success "All services started."
|
||||
|
||||
# Key file is no longer needed — SECRET_KEY is now in .env
|
||||
rm -f "${KEY_FILE}"
|
||||
|
||||
else
|
||||
step "Starting SurfSense"
|
||||
(cd "${INSTALL_DIR}" && ${DC} up -d) < /dev/null
|
||||
success "All services started."
|
||||
fi
|
||||
|
||||
# ── Watchtower (auto-update) ─────────────────────────────────────────────────
|
||||
|
||||
if $SETUP_WATCHTOWER; then
|
||||
step "Setting up Watchtower (auto-updates every $((WATCHTOWER_INTERVAL / 3600))h)"
|
||||
|
||||
WT_STATE=$(docker inspect -f '{{.State.Running}}' "${WATCHTOWER_CONTAINER}" 2>/dev/null < /dev/null || echo "missing")
|
||||
|
||||
if [[ "${WT_STATE}" == "true" ]]; then
|
||||
success "Watchtower is already running — skipping."
|
||||
else
|
||||
if [[ "${WT_STATE}" != "missing" ]]; then
|
||||
info "Removing stopped Watchtower container..."
|
||||
docker rm -f "${WATCHTOWER_CONTAINER}" >/dev/null 2>&1 < /dev/null || true
|
||||
fi
|
||||
docker run -d \
|
||||
--name "${WATCHTOWER_CONTAINER}" \
|
||||
--restart unless-stopped \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
nickfedor/watchtower \
|
||||
--label-enable \
|
||||
--interval "${WATCHTOWER_INTERVAL}" >/dev/null 2>&1 < /dev/null \
|
||||
&& success "Watchtower started — labeled SurfSense containers will auto-update." \
|
||||
|| warn "Could not start Watchtower. You can set it up manually or use: docker compose pull && docker compose up -d"
|
||||
fi
|
||||
else
|
||||
info "Skipping Watchtower setup (--no-watchtower flag)."
|
||||
fi
|
||||
|
||||
# ── Done ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
echo ""
|
||||
printf '\033[1;37m'
|
||||
cat << 'EOF'
|
||||
|
||||
|
||||
.d8888b. .d888 .d8888b.
|
||||
d88P Y88b d88P" d88P Y88b
|
||||
Y88b. 888 Y88b.
|
||||
"Y888b. 888 888 888d888 888888 "Y888b. .d88b. 88888b. .d8888b .d88b.
|
||||
"Y88b. 888 888 888P" 888 "Y88b. d8P Y8b 888 "88b 88K d8P Y8b
|
||||
"888 888 888 888 888 "888 88888888 888 888 "Y8888b. 88888888
|
||||
Y88b d88P Y88b 888 888 888 Y88b d88P Y8b. 888 888 X88 Y8b.
|
||||
"Y8888P" "Y88888 888 888 "Y8888P" "Y8888 888 888 88888P' "Y8888
|
||||
|
||||
|
||||
EOF
|
||||
_version_display=$(grep '^SURFSENSE_VERSION=' "${INSTALL_DIR}/.env" 2>/dev/null | cut -d= -f2 | tr -d '"' | head -1 || true)
|
||||
_version_display="${_version_display:-latest}"
|
||||
printf " OSS Alternative to NotebookLM for Teams ${YELLOW}[%s]${NC}\n" "${_version_display}"
|
||||
printf "${CYAN}══════════════════════════════════════════════════════════════${NC}\n\n"
|
||||
|
||||
info " Frontend: http://localhost:3000"
|
||||
info " Backend: http://localhost:8000"
|
||||
info " API Docs: http://localhost:8000/docs"
|
||||
info ""
|
||||
info " Config: ${INSTALL_DIR}/.env"
|
||||
info " Logs: cd ${INSTALL_DIR} && ${DC} logs -f"
|
||||
info " Stop: cd ${INSTALL_DIR} && ${DC} down"
|
||||
info " Update: cd ${INSTALL_DIR} && ${DC} pull && ${DC} up -d"
|
||||
info ""
|
||||
|
||||
if $SETUP_WATCHTOWER; then
|
||||
info " Watchtower: auto-updates every $((WATCHTOWER_INTERVAL / 3600))h (stop: docker rm -f ${WATCHTOWER_CONTAINER})"
|
||||
else
|
||||
warn " Watchtower skipped. For auto-updates, re-run without --no-watchtower."
|
||||
fi
|
||||
info ""
|
||||
|
||||
if $MIGRATION_MODE; then
|
||||
warn " Migration complete! Open frontend and verify your data."
|
||||
warn " Once verified, clean up the legacy volume and migration files:"
|
||||
warn " docker volume rm ${OLD_VOLUME}"
|
||||
warn " rm ${DUMP_FILE}"
|
||||
warn " rm ${MIGRATION_DONE_FILE}"
|
||||
else
|
||||
warn " First startup may take a few minutes while images are pulled."
|
||||
warn " Edit ${INSTALL_DIR}/.env to configure API keys, OAuth, etc."
|
||||
fi
|
||||
|
||||
} # end main()
|
||||
|
||||
main "$@"
|
||||
343
docker/scripts/migrate-database.ps1
Normal file
343
docker/scripts/migrate-database.ps1
Normal file
|
|
@ -0,0 +1,343 @@
|
|||
# =============================================================================
|
||||
# SurfSense — Database Migration Script (Windows / PowerShell)
|
||||
#
|
||||
# Extracts data from the legacy all-in-one surfsense-data volume (PostgreSQL 14)
|
||||
# and saves it as a SQL dump + SECRET_KEY file ready for install.ps1 to restore.
|
||||
#
|
||||
# Usage:
|
||||
# .\migrate-database.ps1 [options]
|
||||
#
|
||||
# Options:
|
||||
# -DbUser USER Old PostgreSQL username (default: surfsense)
|
||||
# -DbPassword PASS Old PostgreSQL password (default: surfsense)
|
||||
# -DbName NAME Old PostgreSQL database (default: surfsense)
|
||||
# -Yes Skip all confirmation prompts
|
||||
#
|
||||
# Prerequisites:
|
||||
# - Docker Desktop installed and running
|
||||
# - The legacy surfsense-data volume must exist
|
||||
# - ~500 MB free disk space for the dump file
|
||||
#
|
||||
# What this script does:
|
||||
# 1. Stops any container using surfsense-data (to prevent corruption)
|
||||
# 2. Starts a temporary PG14 container against the old volume
|
||||
# 3. Dumps the database to .\surfsense_migration_backup.sql
|
||||
# 4. Recovers the SECRET_KEY to .\surfsense_migration_secret.key
|
||||
# 5. Exits — leaving installation to install.ps1
|
||||
#
|
||||
# What this script does NOT do:
|
||||
# - Delete the original surfsense-data volume (do this manually after verifying)
|
||||
# - Install the new SurfSense stack (install.ps1 handles that automatically)
|
||||
#
|
||||
# Note:
|
||||
# install.ps1 downloads and runs this script automatically when it detects the
|
||||
# legacy surfsense-data volume. You only need to run this script manually if
|
||||
# you have custom database credentials (-DbUser / -DbPassword / -DbName)
|
||||
# or if the automatic migration inside install.ps1 fails at the extraction step.
|
||||
# =============================================================================
|
||||
|
||||
param(
|
||||
[string]$DbUser = "surfsense",
|
||||
[string]$DbPassword = "surfsense",
|
||||
[string]$DbName = "surfsense",
|
||||
[switch]$Yes
|
||||
)
|
||||
|
||||
$ErrorActionPreference = 'Stop'
|
||||
|
||||
# ── Constants ────────────────────────────────────────────────────────────────
|
||||
|
||||
$OldVolume = "surfsense-data"
|
||||
$TempContainer = "surfsense-pg14-migration"
|
||||
$DumpFile = ".\surfsense_migration_backup.sql"
|
||||
$KeyFile = ".\surfsense_migration_secret.key"
|
||||
$PG14Image = "pgvector/pgvector:pg14"
|
||||
$LogFile = ".\surfsense-migration.log"
|
||||
|
||||
# ── Output helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
function Write-Info { param([string]$Msg) Write-Host "[SurfSense] " -ForegroundColor Cyan -NoNewline; Write-Host $Msg }
|
||||
function Write-Ok { param([string]$Msg) Write-Host "[SurfSense] " -ForegroundColor Green -NoNewline; Write-Host $Msg }
|
||||
function Write-Warn { param([string]$Msg) Write-Host "[SurfSense] " -ForegroundColor Yellow -NoNewline; Write-Host $Msg }
|
||||
function Write-Step { param([string]$Step, [string]$Msg) Write-Host "`n-- Step ${Step}: $Msg" -ForegroundColor Cyan }
|
||||
function Write-Err { param([string]$Msg) Write-Host "[SurfSense] ERROR: $Msg" -ForegroundColor Red; exit 1 }
|
||||
|
||||
function Log { param([string]$Msg) Add-Content -Path $LogFile -Value $Msg }
|
||||
|
||||
function Invoke-NativeSafe {
|
||||
param([scriptblock]$Command)
|
||||
$previousErrorActionPreference = $ErrorActionPreference
|
||||
try {
|
||||
$ErrorActionPreference = 'Continue'
|
||||
& $Command
|
||||
} finally {
|
||||
$ErrorActionPreference = $previousErrorActionPreference
|
||||
}
|
||||
}
|
||||
|
||||
function Confirm-Action {
|
||||
param([string]$Prompt)
|
||||
if ($Yes) { return }
|
||||
$reply = Read-Host "[SurfSense] $Prompt [y/N]"
|
||||
if ($reply -notmatch '^[Yy]$') {
|
||||
Write-Warn "Aborted."
|
||||
exit 0
|
||||
}
|
||||
}
|
||||
|
||||
# ── Cleanup helper ───────────────────────────────────────────────────────────
|
||||
|
||||
function Remove-TempContainer {
|
||||
$containers = Invoke-NativeSafe { docker ps -a --format '{{.Names}}' 2>$null }
|
||||
if ($containers -and ($containers -split "`n") -contains $TempContainer) {
|
||||
Write-Info "Cleaning up temporary container '$TempContainer'..."
|
||||
Invoke-NativeSafe { docker stop $TempContainer *>$null } | Out-Null
|
||||
Invoke-NativeSafe { docker rm $TempContainer *>$null } | Out-Null
|
||||
}
|
||||
}
|
||||
|
||||
# Register cleanup on script exit
|
||||
Register-EngineEvent PowerShell.Exiting -Action {
|
||||
$containers = Invoke-NativeSafe { docker ps -a --format '{{.Names}}' 2>$null }
|
||||
if ($containers -and ($containers -split "`n") -contains "surfsense-pg14-migration") {
|
||||
Invoke-NativeSafe { docker stop "surfsense-pg14-migration" *>$null } | Out-Null
|
||||
Invoke-NativeSafe { docker rm "surfsense-pg14-migration" *>$null } | Out-Null
|
||||
}
|
||||
} | Out-Null
|
||||
|
||||
# ── Wait-for-postgres helper ────────────────────────────────────────────────
|
||||
|
||||
function Wait-ForPostgres {
|
||||
param(
|
||||
[string]$Container,
|
||||
[string]$User,
|
||||
[string]$Label = "PostgreSQL"
|
||||
)
|
||||
$maxAttempts = 45
|
||||
$attempt = 0
|
||||
|
||||
Write-Info "Waiting for $Label to accept connections..."
|
||||
do {
|
||||
$attempt++
|
||||
if ($attempt -ge $maxAttempts) {
|
||||
Write-Err "$Label did not become ready after $($maxAttempts * 2) seconds. Check: docker logs $Container"
|
||||
}
|
||||
Start-Sleep -Seconds 2
|
||||
Invoke-NativeSafe { docker exec $Container pg_isready -U $User -q 2>$null } | Out-Null
|
||||
} while ($LASTEXITCODE -ne 0)
|
||||
|
||||
Write-Ok "$Label is ready."
|
||||
}
|
||||
|
||||
Write-Info "Migrating data from legacy database (PostgreSQL 14 -> 17)"
|
||||
"Migration started at $(Get-Date)" | Out-File $LogFile
|
||||
|
||||
# ── Step 0: Pre-flight checks ───────────────────────────────────────────────
|
||||
|
||||
Write-Step "0" "Pre-flight checks"
|
||||
|
||||
if (-not (Get-Command docker -ErrorAction SilentlyContinue)) {
|
||||
Write-Err "Docker is not installed. Install Docker Desktop: https://docs.docker.com/desktop/install/windows-install/"
|
||||
}
|
||||
|
||||
Invoke-NativeSafe { docker info *>$null } | Out-Null
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Err "Docker daemon is not running. Please start Docker Desktop and try again."
|
||||
}
|
||||
|
||||
$volumeList = Invoke-NativeSafe { docker volume ls --format '{{.Name}}' 2>$null }
|
||||
if (-not (($volumeList -split "`n") -contains $OldVolume)) {
|
||||
Write-Err "Legacy volume '$OldVolume' not found. Are you sure you ran the old all-in-one SurfSense container?"
|
||||
}
|
||||
Write-Ok "Found legacy volume: $OldVolume"
|
||||
|
||||
$oldContainer = (Invoke-NativeSafe { docker ps --filter "volume=$OldVolume" --format '{{.Names}}' 2>$null } | Select-Object -First 1)
|
||||
if ($oldContainer) {
|
||||
Write-Warn "Container '$oldContainer' is running and using the '$OldVolume' volume."
|
||||
Write-Warn "It must be stopped before migration to prevent data file corruption."
|
||||
Confirm-Action "Stop '$oldContainer' now and proceed with data extraction?"
|
||||
Invoke-NativeSafe { docker stop $oldContainer *>$null } | Out-Null
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Err "Failed to stop '$oldContainer'. Try: docker stop $oldContainer"
|
||||
}
|
||||
Write-Ok "Container '$oldContainer' stopped."
|
||||
}
|
||||
|
||||
if (Test-Path $DumpFile) {
|
||||
Write-Warn "Dump file '$DumpFile' already exists."
|
||||
Write-Warn "If a previous extraction succeeded, just run install.ps1 now."
|
||||
Write-Warn "To re-extract, remove the file first: Remove-Item $DumpFile"
|
||||
Write-Err "Aborting to avoid overwriting an existing dump."
|
||||
}
|
||||
|
||||
$staleContainers = Invoke-NativeSafe { docker ps -a --format '{{.Names}}' 2>$null }
|
||||
if ($staleContainers -and ($staleContainers -split "`n") -contains $TempContainer) {
|
||||
Write-Warn "Stale migration container '$TempContainer' found - removing it."
|
||||
Invoke-NativeSafe { docker stop $TempContainer *>$null } | Out-Null
|
||||
Invoke-NativeSafe { docker rm $TempContainer *>$null } | Out-Null
|
||||
}
|
||||
|
||||
$drive = (Get-Item .).PSDrive
|
||||
$freeMB = [math]::Floor($drive.Free / 1MB)
|
||||
if ($freeMB -lt 500) {
|
||||
Write-Warn "Low disk space: $freeMB MB free. At least 500 MB recommended for the dump."
|
||||
Confirm-Action "Continue anyway?"
|
||||
} else {
|
||||
Write-Ok "Disk space: $freeMB MB free."
|
||||
}
|
||||
|
||||
Write-Ok "All pre-flight checks passed."
|
||||
|
||||
# ── Confirmation prompt ──────────────────────────────────────────────────────
|
||||
|
||||
Write-Host ""
|
||||
Write-Host "Extraction plan:" -ForegroundColor White
|
||||
Write-Host " Source volume : " -NoNewline; Write-Host "$OldVolume" -ForegroundColor Yellow -NoNewline; Write-Host " (PG14 data at /data/postgres)"
|
||||
Write-Host " Old credentials : user=" -NoNewline; Write-Host "$DbUser" -ForegroundColor Yellow -NoNewline; Write-Host " db=" -NoNewline; Write-Host "$DbName" -ForegroundColor Yellow
|
||||
Write-Host " Dump saved to : " -NoNewline; Write-Host "$DumpFile" -ForegroundColor Yellow
|
||||
Write-Host " SECRET_KEY to : " -NoNewline; Write-Host "$KeyFile" -ForegroundColor Yellow
|
||||
Write-Host " Log file : " -NoNewline; Write-Host "$LogFile" -ForegroundColor Yellow
|
||||
Write-Host ""
|
||||
Confirm-Action "Start data extraction? (Your original data will not be deleted or modified.)"
|
||||
|
||||
# ── Step 1: Start temporary PostgreSQL 14 container ──────────────────────────
|
||||
|
||||
Write-Step "1" "Starting temporary PostgreSQL 14 container"
|
||||
|
||||
Write-Info "Pulling $PG14Image..."
|
||||
Invoke-NativeSafe { docker pull $PG14Image *>$null } | Out-Null
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Warn "Could not pull $PG14Image - using cached image if available."
|
||||
}
|
||||
|
||||
$dataUid = Invoke-NativeSafe { docker run --rm -v "${OldVolume}:/data" alpine stat -c '%u' /data/postgres 2>$null }
|
||||
if (-not $dataUid -or $dataUid -eq "0") {
|
||||
Write-Warn "Could not detect data directory UID - falling back to default (may chown files)."
|
||||
$userFlag = @()
|
||||
} else {
|
||||
Write-Info "Data directory owned by UID $dataUid - starting temp container as that user."
|
||||
$userFlag = @("--user", $dataUid)
|
||||
}
|
||||
|
||||
$dockerRunArgs = @(
|
||||
"run", "-d",
|
||||
"--name", $TempContainer,
|
||||
"-v", "${OldVolume}:/data",
|
||||
"-e", "PGDATA=/data/postgres",
|
||||
"-e", "POSTGRES_USER=$DbUser",
|
||||
"-e", "POSTGRES_PASSWORD=$DbPassword",
|
||||
"-e", "POSTGRES_DB=$DbName"
|
||||
) + $userFlag + @($PG14Image)
|
||||
|
||||
Invoke-NativeSafe { docker @dockerRunArgs *>$null } | Out-Null
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Err "Failed to start temporary PostgreSQL 14 container."
|
||||
}
|
||||
|
||||
Write-Ok "Temporary container '$TempContainer' started."
|
||||
Wait-ForPostgres -Container $TempContainer -User $DbUser -Label "PostgreSQL 14"
|
||||
|
||||
# ── Step 2: Dump the database ────────────────────────────────────────────────
|
||||
|
||||
Write-Step "2" "Dumping PostgreSQL 14 database"
|
||||
|
||||
Write-Info "Running pg_dump - this may take a while for large databases..."
|
||||
|
||||
$pgDumpErrFile = Join-Path $env:TEMP "surfsense_pgdump_err.log"
|
||||
Invoke-NativeSafe { docker exec -e "PGPASSWORD=$DbPassword" $TempContainer pg_dump -U $DbUser --no-password $DbName > $DumpFile 2>$pgDumpErrFile } | Out-Null
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
if (Test-Path $pgDumpErrFile) { Get-Content $pgDumpErrFile | Write-Host -ForegroundColor Red }
|
||||
Remove-TempContainer
|
||||
Write-Err "pg_dump failed. See above for details."
|
||||
}
|
||||
|
||||
if (-not (Test-Path $DumpFile) -or (Get-Item $DumpFile).Length -eq 0) {
|
||||
Remove-TempContainer
|
||||
Write-Err "Dump file '$DumpFile' is empty. Something went wrong with pg_dump."
|
||||
}
|
||||
|
||||
$dumpContent = (Get-Content $DumpFile -TotalCount 5) -join "`n"
|
||||
if ($dumpContent -notmatch "PostgreSQL database dump") {
|
||||
Remove-TempContainer
|
||||
Write-Err "Dump file does not contain a valid PostgreSQL dump header - the file may be corrupt."
|
||||
}
|
||||
|
||||
$dumpLines = (Get-Content $DumpFile | Measure-Object -Line).Lines
|
||||
if ($dumpLines -lt 10) {
|
||||
Remove-TempContainer
|
||||
Write-Err "Dump has only $dumpLines lines - suspiciously small. Aborting."
|
||||
}
|
||||
|
||||
$dumpSize = "{0:N1} MB" -f ((Get-Item $DumpFile).Length / 1MB)
|
||||
Write-Ok "Dump complete: $dumpSize ($dumpLines lines) -> $DumpFile"
|
||||
|
||||
Write-Info "Stopping temporary PostgreSQL 14 container..."
|
||||
Invoke-NativeSafe { docker stop $TempContainer *>$null } | Out-Null
|
||||
Invoke-NativeSafe { docker rm $TempContainer *>$null } | Out-Null
|
||||
Write-Ok "Temporary container removed."
|
||||
|
||||
# ── Step 3: Recover SECRET_KEY ───────────────────────────────────────────────
|
||||
|
||||
Write-Step "3" "Recovering SECRET_KEY"
|
||||
|
||||
$recoveredKey = ""
|
||||
|
||||
$keyCheck = Invoke-NativeSafe { docker run --rm -v "${OldVolume}:/data" alpine sh -c 'test -f /data/.secret_key && cat /data/.secret_key' 2>$null }
|
||||
if ($LASTEXITCODE -eq 0 -and $keyCheck) {
|
||||
$recoveredKey = $keyCheck.Trim()
|
||||
Write-Ok "Recovered SECRET_KEY from '$OldVolume'."
|
||||
} else {
|
||||
Write-Warn "No SECRET_KEY file found at /data/.secret_key in '$OldVolume'."
|
||||
Write-Warn "This means the all-in-one container was launched with SECRET_KEY set as an explicit env var."
|
||||
|
||||
if ($Yes) {
|
||||
$bytes = New-Object byte[] 32
|
||||
$rng = [System.Security.Cryptography.RNGCryptoServiceProvider]::new()
|
||||
$rng.GetBytes($bytes)
|
||||
$rng.Dispose()
|
||||
$recoveredKey = [Convert]::ToBase64String($bytes)
|
||||
Write-Warn "Non-interactive mode: generated a new SECRET_KEY automatically."
|
||||
Write-Warn "All active browser sessions will be logged out after migration."
|
||||
Write-Warn "To restore your original key, update SECRET_KEY in .\surfsense\.env afterwards."
|
||||
} else {
|
||||
Write-Warn "Enter the SECRET_KEY from your old container's environment"
|
||||
$recoveredKey = Read-Host "[SurfSense] (press Enter to generate a new one - existing sessions will be invalidated)"
|
||||
if (-not $recoveredKey) {
|
||||
$bytes = New-Object byte[] 32
|
||||
$rng = [System.Security.Cryptography.RNGCryptoServiceProvider]::new()
|
||||
$rng.GetBytes($bytes)
|
||||
$rng.Dispose()
|
||||
$recoveredKey = [Convert]::ToBase64String($bytes)
|
||||
Write-Warn "Generated a new SECRET_KEY. All active browser sessions will be logged out after migration."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Set-Content -Path $KeyFile -Value $recoveredKey -NoNewline
|
||||
Write-Ok "SECRET_KEY saved to $KeyFile"
|
||||
|
||||
# ── Done ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
Write-Host ""
|
||||
Write-Host ("=" * 62) -ForegroundColor Green
|
||||
Write-Host " Data extraction complete!" -ForegroundColor Green
|
||||
Write-Host ("=" * 62) -ForegroundColor Green
|
||||
Write-Host ""
|
||||
|
||||
Write-Ok "Dump file : $DumpFile ($dumpSize)"
|
||||
Write-Ok "Secret key: $KeyFile"
|
||||
Write-Host ""
|
||||
Write-Info "Next step - run install.ps1 from this same directory:"
|
||||
Write-Host ""
|
||||
Write-Host " irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.ps1 | iex" -ForegroundColor Cyan
|
||||
Write-Host ""
|
||||
Write-Info "install.ps1 will detect the dump, restore your data into PostgreSQL 17,"
|
||||
Write-Info "and start the full SurfSense stack automatically."
|
||||
Write-Host ""
|
||||
Write-Warn "Keep both files until you have verified the migration:"
|
||||
Write-Warn " $DumpFile"
|
||||
Write-Warn " $KeyFile"
|
||||
Write-Warn "Full log saved to: $LogFile"
|
||||
Write-Host ""
|
||||
|
||||
Log "Migration extraction completed successfully at $(Get-Date)"
|
||||
335
docker/scripts/migrate-database.sh
Executable file
335
docker/scripts/migrate-database.sh
Executable file
|
|
@ -0,0 +1,335 @@
|
|||
#!/usr/bin/env bash
|
||||
# =============================================================================
|
||||
# SurfSense — Database Migration Script
|
||||
#
|
||||
# Extracts data from the legacy all-in-one surfsense-data volume (PostgreSQL 14)
|
||||
# and saves it as a SQL dump + SECRET_KEY file ready for install.sh to restore.
|
||||
#
|
||||
# Usage:
|
||||
# bash migrate-database.sh [options]
|
||||
#
|
||||
# Options:
|
||||
# --db-user USER Old PostgreSQL username (default: surfsense)
|
||||
# --db-password PASS Old PostgreSQL password (default: surfsense)
|
||||
# --db-name NAME Old PostgreSQL database (default: surfsense)
|
||||
# --yes / -y Skip all confirmation prompts
|
||||
# --help / -h Show this help
|
||||
#
|
||||
# Prerequisites:
|
||||
# - Docker installed and running
|
||||
# - The legacy surfsense-data volume must exist
|
||||
# - ~500 MB free disk space for the dump file
|
||||
#
|
||||
# What this script does:
|
||||
# 1. Stops any container using surfsense-data (to prevent corruption)
|
||||
# 2. Starts a temporary PG14 container against the old volume
|
||||
# 3. Dumps the database to ./surfsense_migration_backup.sql
|
||||
# 4. Recovers the SECRET_KEY to ./surfsense_migration_secret.key
|
||||
# 5. Exits — leaving installation to install.sh
|
||||
#
|
||||
# What this script does NOT do:
|
||||
# - Delete the original surfsense-data volume (do this manually after verifying)
|
||||
# - Install the new SurfSense stack (install.sh handles that automatically)
|
||||
#
|
||||
# Note:
|
||||
# install.sh downloads and runs this script automatically when it detects the
|
||||
# legacy surfsense-data volume. You only need to run this script manually if
|
||||
# you have custom database credentials (--db-user / --db-password / --db-name)
|
||||
# or if the automatic migration inside install.sh fails at the extraction step.
|
||||
# =============================================================================
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# ── Colours ──────────────────────────────────────────────────────────────────
|
||||
CYAN='\033[1;36m'
|
||||
YELLOW='\033[1;33m'
|
||||
GREEN='\033[0;32m'
|
||||
RED='\033[0;31m'
|
||||
BOLD='\033[1m'
|
||||
NC='\033[0m'
|
||||
|
||||
# ── Logging — tee everything to a log file ───────────────────────────────────
|
||||
LOG_FILE="./surfsense-migration.log"
|
||||
exec > >(tee -a "${LOG_FILE}") 2>&1
|
||||
|
||||
# ── Output helpers ────────────────────────────────────────────────────────────
|
||||
info() { printf "${CYAN}[SurfSense]${NC} %s\n" "$1"; }
|
||||
success() { printf "${GREEN}[SurfSense]${NC} %s\n" "$1"; }
|
||||
warn() { printf "${YELLOW}[SurfSense]${NC} %s\n" "$1"; }
|
||||
error() { printf "${RED}[SurfSense]${NC} ERROR: %s\n" "$1" >&2; exit 1; }
|
||||
step() { printf "\n${BOLD}${CYAN}── Step %s: %s${NC}\n" "$1" "$2"; }
|
||||
|
||||
# ── Constants ─────────────────────────────────────────────────────────────────
|
||||
OLD_VOLUME="surfsense-data"
|
||||
TEMP_CONTAINER="surfsense-pg14-migration"
|
||||
DUMP_FILE="./surfsense_migration_backup.sql"
|
||||
KEY_FILE="./surfsense_migration_secret.key"
|
||||
PG14_IMAGE="pgvector/pgvector:pg14"
|
||||
|
||||
# ── Defaults ──────────────────────────────────────────────────────────────────
|
||||
OLD_DB_USER="surfsense"
|
||||
OLD_DB_PASSWORD="surfsense"
|
||||
OLD_DB_NAME="surfsense"
|
||||
AUTO_YES=false
|
||||
|
||||
# ── Argument parsing ──────────────────────────────────────────────────────────
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--db-user) OLD_DB_USER="$2"; shift 2 ;;
|
||||
--db-password) OLD_DB_PASSWORD="$2"; shift 2 ;;
|
||||
--db-name) OLD_DB_NAME="$2"; shift 2 ;;
|
||||
--yes|-y) AUTO_YES=true; shift ;;
|
||||
--help|-h)
|
||||
grep '^#' "$0" | grep -v '^#!/' | sed 's/^# \{0,1\}//'
|
||||
exit 0
|
||||
;;
|
||||
*) error "Unknown option: $1 — run with --help for usage." ;;
|
||||
esac
|
||||
done
|
||||
|
||||
# ── Confirmation helper ───────────────────────────────────────────────────────
|
||||
confirm() {
|
||||
if $AUTO_YES; then return 0; fi
|
||||
printf "${YELLOW}[SurfSense]${NC} %s [y/N] " "$1"
|
||||
read -r reply
|
||||
[[ "$reply" =~ ^[Yy]$ ]] || { warn "Aborted."; exit 0; }
|
||||
}
|
||||
|
||||
# ── Cleanup trap — always remove the temp container ──────────────────────────
|
||||
cleanup() {
|
||||
local exit_code=$?
|
||||
if docker ps -a --format '{{.Names}}' 2>/dev/null < /dev/null | grep -q "^${TEMP_CONTAINER}$"; then
|
||||
info "Cleaning up temporary container '${TEMP_CONTAINER}'..."
|
||||
docker stop "${TEMP_CONTAINER}" >/dev/null 2>&1 < /dev/null || true
|
||||
docker rm "${TEMP_CONTAINER}" >/dev/null 2>&1 < /dev/null || true
|
||||
fi
|
||||
if [[ $exit_code -ne 0 ]]; then
|
||||
printf "\n${RED}[SurfSense]${NC} Migration data extraction failed (exit code %s).\n" "${exit_code}" >&2
|
||||
printf "${RED}[SurfSense]${NC} Full log: %s\n" "${LOG_FILE}" >&2
|
||||
printf "${YELLOW}[SurfSense]${NC} Your original data in '${OLD_VOLUME}' is untouched.\n" >&2
|
||||
fi
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
# ── Wait-for-postgres helper ──────────────────────────────────────────────────
|
||||
wait_for_pg() {
|
||||
local container="$1"
|
||||
local user="$2"
|
||||
local label="${3:-PostgreSQL}"
|
||||
local max_attempts=45
|
||||
local attempt=0
|
||||
|
||||
info "Waiting for ${label} to accept connections..."
|
||||
until docker exec "${container}" pg_isready -U "${user}" -q 2>/dev/null < /dev/null; do
|
||||
attempt=$((attempt + 1))
|
||||
if [[ $attempt -ge $max_attempts ]]; then
|
||||
error "${label} did not become ready after $((max_attempts * 2)) seconds. Check: docker logs ${container}"
|
||||
fi
|
||||
printf "."
|
||||
sleep 2
|
||||
done
|
||||
printf "\n"
|
||||
success "${label} is ready."
|
||||
}
|
||||
|
||||
info "Migrating data from legacy database (PostgreSQL 14 → 17)"
|
||||
|
||||
# ── Step 0: Pre-flight checks ─────────────────────────────────────────────────
|
||||
step "0" "Pre-flight checks"
|
||||
|
||||
# Docker CLI
|
||||
command -v docker >/dev/null 2>&1 \
|
||||
|| error "Docker is not installed. Install it at: https://docs.docker.com/get-docker/"
|
||||
|
||||
# Docker daemon
|
||||
docker info >/dev/null 2>&1 < /dev/null \
|
||||
|| error "Docker daemon is not running. Please start Docker and try again."
|
||||
|
||||
# Old volume must exist
|
||||
docker volume ls --format '{{.Name}}' < /dev/null | grep -q "^${OLD_VOLUME}$" \
|
||||
|| error "Legacy volume '${OLD_VOLUME}' not found.\n Are you sure you ran the old all-in-one SurfSense container?"
|
||||
success "Found legacy volume: ${OLD_VOLUME}"
|
||||
|
||||
# Detect and stop any container currently using the old volume
|
||||
# (mounting a live PG volume into a second container causes the new container's
|
||||
# entrypoint to chown the data files, breaking the running container's access)
|
||||
OLD_CONTAINER=$(docker ps --filter "volume=${OLD_VOLUME}" --format '{{.Names}}' < /dev/null | head -n1 || true)
|
||||
if [[ -n "${OLD_CONTAINER}" ]]; then
|
||||
warn "Container '${OLD_CONTAINER}' is running and using the '${OLD_VOLUME}' volume."
|
||||
warn "It must be stopped before migration to prevent data file corruption."
|
||||
confirm "Stop '${OLD_CONTAINER}' now and proceed with data extraction?"
|
||||
docker stop "${OLD_CONTAINER}" >/dev/null 2>&1 < /dev/null \
|
||||
|| error "Failed to stop '${OLD_CONTAINER}'. Try: docker stop ${OLD_CONTAINER}"
|
||||
success "Container '${OLD_CONTAINER}' stopped."
|
||||
fi
|
||||
|
||||
# Bail out if a dump already exists — don't overwrite a previous successful run
|
||||
if [[ -f "${DUMP_FILE}" ]]; then
|
||||
warn "Dump file '${DUMP_FILE}' already exists."
|
||||
warn "If a previous extraction succeeded, just run install.sh now."
|
||||
warn "To re-extract, remove the file first: rm ${DUMP_FILE}"
|
||||
error "Aborting to avoid overwriting an existing dump."
|
||||
fi
|
||||
|
||||
# Clean up any stale temp container from a previous failed run
|
||||
if docker ps -a --format '{{.Names}}' < /dev/null | grep -q "^${TEMP_CONTAINER}$"; then
|
||||
warn "Stale migration container '${TEMP_CONTAINER}' found — removing it."
|
||||
docker stop "${TEMP_CONTAINER}" >/dev/null 2>&1 < /dev/null || true
|
||||
docker rm "${TEMP_CONTAINER}" >/dev/null 2>&1 < /dev/null || true
|
||||
fi
|
||||
|
||||
# Disk space (warn if < 500 MB free)
|
||||
if command -v df >/dev/null 2>&1; then
|
||||
FREE_KB=$(df -k . | awk 'NR==2 {print $4}')
|
||||
FREE_MB=$(( FREE_KB / 1024 ))
|
||||
if [[ $FREE_MB -lt 500 ]]; then
|
||||
warn "Low disk space: ${FREE_MB} MB free. At least 500 MB recommended for the dump."
|
||||
confirm "Continue anyway?"
|
||||
else
|
||||
success "Disk space: ${FREE_MB} MB free."
|
||||
fi
|
||||
fi
|
||||
|
||||
success "All pre-flight checks passed."
|
||||
|
||||
# ── Confirmation prompt ───────────────────────────────────────────────────────
|
||||
printf "\n${BOLD}Extraction plan:${NC}\n"
|
||||
printf " Source volume : ${YELLOW}%s${NC} (PG14 data at /data/postgres)\n" "${OLD_VOLUME}"
|
||||
printf " Old credentials : user=${YELLOW}%s${NC} db=${YELLOW}%s${NC}\n" "${OLD_DB_USER}" "${OLD_DB_NAME}"
|
||||
printf " Dump saved to : ${YELLOW}%s${NC}\n" "${DUMP_FILE}"
|
||||
printf " SECRET_KEY to : ${YELLOW}%s${NC}\n" "${KEY_FILE}"
|
||||
printf " Log file : ${YELLOW}%s${NC}\n\n" "${LOG_FILE}"
|
||||
confirm "Start data extraction? (Your original data will not be deleted or modified.)"
|
||||
|
||||
# ── Step 1: Start temporary PostgreSQL 14 container ──────────────────────────
|
||||
step "1" "Starting temporary PostgreSQL 14 container"
|
||||
|
||||
info "Pulling ${PG14_IMAGE}..."
|
||||
docker pull "${PG14_IMAGE}" >/dev/null 2>&1 < /dev/null \
|
||||
|| warn "Could not pull ${PG14_IMAGE} — using cached image if available."
|
||||
|
||||
# Detect the UID that owns the existing data files and run the temp container
|
||||
# as that user. This prevents the official postgres image entrypoint from
|
||||
# running as root and doing `chown -R postgres /data/postgres`, which would
|
||||
# re-own the files to UID 999 and break any subsequent access by the original
|
||||
# container's postgres process (which may run as a different UID).
|
||||
DATA_UID=$(docker run --rm -v "${OLD_VOLUME}:/data" alpine \
|
||||
stat -c '%u' /data/postgres 2>/dev/null < /dev/null || echo "")
|
||||
if [[ -z "${DATA_UID}" || "${DATA_UID}" == "0" ]]; then
|
||||
warn "Could not detect data directory UID — falling back to default (may chown files)."
|
||||
USER_FLAG=""
|
||||
else
|
||||
info "Data directory owned by UID ${DATA_UID} — starting temp container as that user."
|
||||
USER_FLAG="--user ${DATA_UID}"
|
||||
fi
|
||||
|
||||
docker run -d \
|
||||
--name "${TEMP_CONTAINER}" \
|
||||
-v "${OLD_VOLUME}:/data" \
|
||||
-e PGDATA=/data/postgres \
|
||||
-e POSTGRES_USER="${OLD_DB_USER}" \
|
||||
-e POSTGRES_PASSWORD="${OLD_DB_PASSWORD}" \
|
||||
-e POSTGRES_DB="${OLD_DB_NAME}" \
|
||||
${USER_FLAG} \
|
||||
"${PG14_IMAGE}" >/dev/null < /dev/null
|
||||
|
||||
success "Temporary container '${TEMP_CONTAINER}' started."
|
||||
wait_for_pg "${TEMP_CONTAINER}" "${OLD_DB_USER}" "PostgreSQL 14"
|
||||
|
||||
# ── Step 2: Dump the database ─────────────────────────────────────────────────
|
||||
step "2" "Dumping PostgreSQL 14 database"
|
||||
|
||||
info "Running pg_dump — this may take a while for large databases..."
|
||||
|
||||
if ! docker exec \
|
||||
-e PGPASSWORD="${OLD_DB_PASSWORD}" \
|
||||
"${TEMP_CONTAINER}" \
|
||||
pg_dump -U "${OLD_DB_USER}" --no-password "${OLD_DB_NAME}" \
|
||||
> "${DUMP_FILE}" 2>/tmp/surfsense_pgdump_err < /dev/null; then
|
||||
cat /tmp/surfsense_pgdump_err >&2
|
||||
error "pg_dump failed. See above for details."
|
||||
fi
|
||||
|
||||
# Validate: non-empty file
|
||||
[[ -s "${DUMP_FILE}" ]] \
|
||||
|| error "Dump file '${DUMP_FILE}' is empty. Something went wrong with pg_dump."
|
||||
|
||||
# Validate: looks like a real PG dump
|
||||
grep -q "PostgreSQL database dump" "${DUMP_FILE}" \
|
||||
|| error "Dump file does not contain a valid PostgreSQL dump header — the file may be corrupt."
|
||||
|
||||
# Validate: sanity-check line count
|
||||
DUMP_LINES=$(wc -l < "${DUMP_FILE}" | tr -d ' ')
|
||||
[[ $DUMP_LINES -ge 10 ]] \
|
||||
|| error "Dump has only ${DUMP_LINES} lines — suspiciously small. Aborting."
|
||||
|
||||
DUMP_SIZE=$(du -sh "${DUMP_FILE}" 2>/dev/null | cut -f1)
|
||||
success "Dump complete: ${DUMP_SIZE} (${DUMP_LINES} lines) → ${DUMP_FILE}"
|
||||
|
||||
# Stop the temp container (trap will also handle it on unexpected exit)
|
||||
info "Stopping temporary PostgreSQL 14 container..."
|
||||
docker stop "${TEMP_CONTAINER}" >/dev/null 2>&1 < /dev/null || true
|
||||
docker rm "${TEMP_CONTAINER}" >/dev/null 2>&1 < /dev/null || true
|
||||
success "Temporary container removed."
|
||||
|
||||
# ── Step 3: Recover SECRET_KEY ────────────────────────────────────────────────
|
||||
step "3" "Recovering SECRET_KEY"
|
||||
|
||||
RECOVERED_KEY=""
|
||||
|
||||
if docker run --rm -v "${OLD_VOLUME}:/data" alpine \
|
||||
sh -c 'test -f /data/.secret_key && cat /data/.secret_key' \
|
||||
2>/dev/null < /dev/null | grep -q .; then
|
||||
RECOVERED_KEY=$(
|
||||
docker run --rm -v "${OLD_VOLUME}:/data" alpine \
|
||||
cat /data/.secret_key 2>/dev/null < /dev/null | tr -d '[:space:]'
|
||||
)
|
||||
success "Recovered SECRET_KEY from '${OLD_VOLUME}'."
|
||||
else
|
||||
warn "No SECRET_KEY file found at /data/.secret_key in '${OLD_VOLUME}'."
|
||||
warn "This means the all-in-one container was launched with SECRET_KEY set as an explicit env var."
|
||||
if $AUTO_YES; then
|
||||
# Non-interactive (called from install.sh) — auto-generate rather than hanging on read
|
||||
RECOVERED_KEY=$(openssl rand -base64 32 2>/dev/null \
|
||||
|| head -c 32 /dev/urandom | base64 | tr -d '\n')
|
||||
warn "Non-interactive mode: generated a new SECRET_KEY automatically."
|
||||
warn "All active browser sessions will be logged out after migration."
|
||||
warn "To restore your original key, update SECRET_KEY in ./surfsense/.env afterwards."
|
||||
else
|
||||
printf "${YELLOW}[SurfSense]${NC} Enter the SECRET_KEY from your old container's environment\n"
|
||||
printf "${YELLOW}[SurfSense]${NC} (press Enter to generate a new one — existing sessions will be invalidated): "
|
||||
read -r RECOVERED_KEY
|
||||
if [[ -z "${RECOVERED_KEY}" ]]; then
|
||||
RECOVERED_KEY=$(openssl rand -base64 32 2>/dev/null \
|
||||
|| head -c 32 /dev/urandom | base64 | tr -d '\n')
|
||||
warn "Generated a new SECRET_KEY. All active browser sessions will be logged out after migration."
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
# Save SECRET_KEY to a file for install.sh to pick up
|
||||
printf '%s' "${RECOVERED_KEY}" > "${KEY_FILE}"
|
||||
success "SECRET_KEY saved to ${KEY_FILE}"
|
||||
|
||||
# ── Done ──────────────────────────────────────────────────────────────────────
|
||||
printf "\n${GREEN}${BOLD}"
|
||||
printf "══════════════════════════════════════════════════════════════\n"
|
||||
printf " Data extraction complete!\n"
|
||||
printf "══════════════════════════════════════════════════════════════\n"
|
||||
printf "${NC}\n"
|
||||
|
||||
success "Dump file : ${DUMP_FILE} (${DUMP_SIZE})"
|
||||
success "Secret key: ${KEY_FILE}"
|
||||
printf "\n"
|
||||
info "Next step — run install.sh from this same directory:"
|
||||
printf "\n"
|
||||
printf "${CYAN} curl -fsSL https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/install.sh | bash${NC}\n"
|
||||
printf "\n"
|
||||
info "install.sh will detect the dump, restore your data into PostgreSQL 17,"
|
||||
info "and start the full SurfSense stack automatically."
|
||||
printf "\n"
|
||||
warn "Keep both files until you have verified the migration:"
|
||||
warn " ${DUMP_FILE}"
|
||||
warn " ${KEY_FILE}"
|
||||
warn "Full log saved to: ${LOG_FILE}"
|
||||
printf "\n"
|
||||
|
|
@ -1,243 +0,0 @@
|
|||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "==========================================="
|
||||
echo " 🏄 SurfSense All-in-One Container"
|
||||
echo "==========================================="
|
||||
|
||||
# Create log directory
|
||||
mkdir -p /var/log/supervisor
|
||||
|
||||
# ================================================
|
||||
# Ensure data directory exists
|
||||
# ================================================
|
||||
mkdir -p /data
|
||||
|
||||
# ================================================
|
||||
# Generate SECRET_KEY if not provided
|
||||
# ================================================
|
||||
if [ -z "$SECRET_KEY" ]; then
|
||||
# Generate a random secret key and persist it
|
||||
if [ -f /data/.secret_key ]; then
|
||||
export SECRET_KEY=$(cat /data/.secret_key)
|
||||
echo "✅ Using existing SECRET_KEY from persistent storage"
|
||||
else
|
||||
export SECRET_KEY=$(python3 -c "import secrets; print(secrets.token_urlsafe(32))")
|
||||
echo "$SECRET_KEY" > /data/.secret_key
|
||||
chmod 600 /data/.secret_key
|
||||
echo "✅ Generated new SECRET_KEY (saved for persistence)"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ================================================
|
||||
# Set default TTS/STT services if not provided
|
||||
# ================================================
|
||||
if [ -z "$TTS_SERVICE" ]; then
|
||||
export TTS_SERVICE="local/kokoro"
|
||||
echo "✅ Using default TTS_SERVICE: local/kokoro"
|
||||
fi
|
||||
|
||||
if [ -z "$STT_SERVICE" ]; then
|
||||
export STT_SERVICE="local/base"
|
||||
echo "✅ Using default STT_SERVICE: local/base"
|
||||
fi
|
||||
|
||||
# ================================================
|
||||
# Set Electric SQL configuration
|
||||
# ================================================
|
||||
export ELECTRIC_DB_USER="${ELECTRIC_DB_USER:-electric}"
|
||||
export ELECTRIC_DB_PASSWORD="${ELECTRIC_DB_PASSWORD:-electric_password}"
|
||||
if [ -z "$ELECTRIC_DATABASE_URL" ]; then
|
||||
export ELECTRIC_DATABASE_URL="postgresql://${ELECTRIC_DB_USER}:${ELECTRIC_DB_PASSWORD}@localhost:5432/${POSTGRES_DB:-surfsense}?sslmode=disable"
|
||||
echo "✅ Electric SQL URL configured dynamically"
|
||||
else
|
||||
# Ensure sslmode=disable is in the URL if not already present
|
||||
if [[ "$ELECTRIC_DATABASE_URL" != *"sslmode="* ]]; then
|
||||
# Add sslmode=disable (handle both cases: with or without existing query params)
|
||||
if [[ "$ELECTRIC_DATABASE_URL" == *"?"* ]]; then
|
||||
export ELECTRIC_DATABASE_URL="${ELECTRIC_DATABASE_URL}&sslmode=disable"
|
||||
else
|
||||
export ELECTRIC_DATABASE_URL="${ELECTRIC_DATABASE_URL}?sslmode=disable"
|
||||
fi
|
||||
fi
|
||||
echo "✅ Electric SQL URL configured from environment"
|
||||
fi
|
||||
|
||||
# Set Electric SQL port
|
||||
export ELECTRIC_PORT="${ELECTRIC_PORT:-5133}"
|
||||
export PORT="${ELECTRIC_PORT}"
|
||||
|
||||
# ================================================
|
||||
# Initialize PostgreSQL if needed
|
||||
# ================================================
|
||||
if [ ! -f /data/postgres/PG_VERSION ]; then
|
||||
echo "📦 Initializing PostgreSQL database..."
|
||||
|
||||
# Initialize PostgreSQL data directory
|
||||
chown -R postgres:postgres /data/postgres
|
||||
chmod 700 /data/postgres
|
||||
|
||||
# Initialize with UTF8 encoding (required for proper text handling)
|
||||
su - postgres -c "/usr/lib/postgresql/14/bin/initdb -D /data/postgres --encoding=UTF8 --locale=C.UTF-8"
|
||||
|
||||
# Configure PostgreSQL for connections
|
||||
echo "host all all 0.0.0.0/0 md5" >> /data/postgres/pg_hba.conf
|
||||
echo "local all all trust" >> /data/postgres/pg_hba.conf
|
||||
echo "listen_addresses='*'" >> /data/postgres/postgresql.conf
|
||||
|
||||
# Enable logical replication for Electric SQL
|
||||
echo "wal_level = logical" >> /data/postgres/postgresql.conf
|
||||
echo "max_replication_slots = 10" >> /data/postgres/postgresql.conf
|
||||
echo "max_wal_senders = 10" >> /data/postgres/postgresql.conf
|
||||
|
||||
# Start PostgreSQL temporarily to create database and user
|
||||
su - postgres -c "/usr/lib/postgresql/14/bin/pg_ctl -D /data/postgres -l /tmp/postgres_init.log start"
|
||||
|
||||
# Wait for PostgreSQL to be ready
|
||||
sleep 5
|
||||
|
||||
# Create user and database
|
||||
su - postgres -c "psql -c \"CREATE USER ${POSTGRES_USER:-surfsense} WITH PASSWORD '${POSTGRES_PASSWORD:-surfsense}' SUPERUSER;\""
|
||||
su - postgres -c "psql -c \"CREATE DATABASE ${POSTGRES_DB:-surfsense} OWNER ${POSTGRES_USER:-surfsense};\""
|
||||
|
||||
# Enable pgvector extension
|
||||
su - postgres -c "psql -d ${POSTGRES_DB:-surfsense} -c 'CREATE EXTENSION IF NOT EXISTS vector;'"
|
||||
|
||||
# Create Electric SQL replication user (idempotent - uses IF NOT EXISTS)
|
||||
echo "📡 Creating Electric SQL replication user..."
|
||||
su - postgres -c "psql -d ${POSTGRES_DB:-surfsense} <<-EOSQL
|
||||
DO \\\$\\\$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT FROM pg_user WHERE usename = '${ELECTRIC_DB_USER}') THEN
|
||||
CREATE USER ${ELECTRIC_DB_USER} WITH REPLICATION PASSWORD '${ELECTRIC_DB_PASSWORD}';
|
||||
END IF;
|
||||
END
|
||||
\\\$\\\$;
|
||||
|
||||
GRANT CONNECT ON DATABASE ${POSTGRES_DB:-surfsense} TO ${ELECTRIC_DB_USER};
|
||||
GRANT USAGE ON SCHEMA public TO ${ELECTRIC_DB_USER};
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO ${ELECTRIC_DB_USER};
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO ${ELECTRIC_DB_USER};
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO ${ELECTRIC_DB_USER};
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO ${ELECTRIC_DB_USER};
|
||||
|
||||
-- Create the publication for Electric SQL (if not exists)
|
||||
DO \\\$\\\$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT FROM pg_publication WHERE pubname = 'electric_publication_default') THEN
|
||||
CREATE PUBLICATION electric_publication_default;
|
||||
END IF;
|
||||
END
|
||||
\\\$\\\$;
|
||||
EOSQL"
|
||||
echo "✅ Electric SQL user '${ELECTRIC_DB_USER}' created"
|
||||
|
||||
# Stop temporary PostgreSQL
|
||||
su - postgres -c "/usr/lib/postgresql/14/bin/pg_ctl -D /data/postgres stop"
|
||||
|
||||
echo "✅ PostgreSQL initialized successfully"
|
||||
else
|
||||
echo "✅ PostgreSQL data directory already exists"
|
||||
fi
|
||||
|
||||
# ================================================
|
||||
# Initialize Redis data directory
|
||||
# ================================================
|
||||
mkdir -p /data/redis
|
||||
chmod 755 /data/redis
|
||||
echo "✅ Redis data directory ready"
|
||||
|
||||
# ================================================
|
||||
# Copy frontend build to runtime location
|
||||
# ================================================
|
||||
if [ -d /app/frontend/.next/standalone ]; then
|
||||
cp -r /app/frontend/.next/standalone/* /app/frontend/ 2>/dev/null || true
|
||||
cp -r /app/frontend/.next/static /app/frontend/.next/static 2>/dev/null || true
|
||||
fi
|
||||
|
||||
# ================================================
|
||||
# Runtime Environment Variable Replacement
|
||||
# ================================================
|
||||
# Next.js NEXT_PUBLIC_* vars are baked in at build time.
|
||||
# This replaces placeholder values with actual runtime env vars.
|
||||
echo "🔧 Applying runtime environment configuration..."
|
||||
|
||||
# Set defaults if not provided
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL="${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000}"
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE="${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL}"
|
||||
NEXT_PUBLIC_ETL_SERVICE="${NEXT_PUBLIC_ETL_SERVICE:-DOCLING}"
|
||||
NEXT_PUBLIC_ELECTRIC_URL="${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:5133}"
|
||||
NEXT_PUBLIC_ELECTRIC_AUTH_MODE="${NEXT_PUBLIC_ELECTRIC_AUTH_MODE:-insecure}"
|
||||
NEXT_PUBLIC_DEPLOYMENT_MODE="${NEXT_PUBLIC_DEPLOYMENT_MODE:-self-hosted}"
|
||||
|
||||
# Replace placeholders in all JS files
|
||||
find /app/frontend -type f \( -name "*.js" -o -name "*.json" \) -exec sed -i \
|
||||
-e "s|__NEXT_PUBLIC_FASTAPI_BACKEND_URL__|${NEXT_PUBLIC_FASTAPI_BACKEND_URL}|g" \
|
||||
-e "s|__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__|${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE}|g" \
|
||||
-e "s|__NEXT_PUBLIC_ETL_SERVICE__|${NEXT_PUBLIC_ETL_SERVICE}|g" \
|
||||
-e "s|__NEXT_PUBLIC_ELECTRIC_URL__|${NEXT_PUBLIC_ELECTRIC_URL}|g" \
|
||||
-e "s|__NEXT_PUBLIC_ELECTRIC_AUTH_MODE__|${NEXT_PUBLIC_ELECTRIC_AUTH_MODE}|g" \
|
||||
-e "s|__NEXT_PUBLIC_DEPLOYMENT_MODE__|${NEXT_PUBLIC_DEPLOYMENT_MODE}|g" \
|
||||
{} +
|
||||
|
||||
echo "✅ Environment configuration applied"
|
||||
echo " Backend URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL}"
|
||||
echo " Auth Type: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE}"
|
||||
echo " ETL Service: ${NEXT_PUBLIC_ETL_SERVICE}"
|
||||
echo " Electric URL: ${NEXT_PUBLIC_ELECTRIC_URL}"
|
||||
echo " Deployment Mode: ${NEXT_PUBLIC_DEPLOYMENT_MODE}"
|
||||
|
||||
# ================================================
|
||||
# Run database migrations
|
||||
# ================================================
|
||||
run_migrations() {
|
||||
echo "🔄 Running database migrations..."
|
||||
|
||||
# Start PostgreSQL temporarily for migrations
|
||||
su - postgres -c "/usr/lib/postgresql/14/bin/pg_ctl -D /data/postgres -l /tmp/postgres_migrate.log start"
|
||||
sleep 5
|
||||
|
||||
# Start Redis temporarily for migrations (some might need it)
|
||||
redis-server --dir /data/redis --daemonize yes
|
||||
sleep 2
|
||||
|
||||
# Run alembic migrations
|
||||
cd /app/backend
|
||||
alembic upgrade head || echo "⚠️ Migrations may have already been applied"
|
||||
|
||||
# Stop temporary services
|
||||
redis-cli shutdown || true
|
||||
su - postgres -c "/usr/lib/postgresql/14/bin/pg_ctl -D /data/postgres stop"
|
||||
|
||||
echo "✅ Database migrations complete"
|
||||
}
|
||||
|
||||
# Always run migrations on startup - alembic upgrade head is safe to run
|
||||
# every time. It only applies pending migrations (never re-runs applied ones,
|
||||
# never calls downgrade). This ensures updates are applied automatically.
|
||||
run_migrations
|
||||
|
||||
# ================================================
|
||||
# Environment Variables Info
|
||||
# ================================================
|
||||
echo ""
|
||||
echo "==========================================="
|
||||
echo " 📋 Configuration"
|
||||
echo "==========================================="
|
||||
echo " Frontend URL: http://localhost:3000"
|
||||
echo " Backend API: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL}"
|
||||
echo " API Docs: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL}/docs"
|
||||
echo " Electric URL: ${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:5133}"
|
||||
echo " Auth Type: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE}"
|
||||
echo " ETL Service: ${NEXT_PUBLIC_ETL_SERVICE}"
|
||||
echo " TTS Service: ${TTS_SERVICE}"
|
||||
echo " STT Service: ${STT_SERVICE}"
|
||||
echo "==========================================="
|
||||
echo ""
|
||||
|
||||
# ================================================
|
||||
# Start Supervisor (manages all services)
|
||||
# ================================================
|
||||
echo "🚀 Starting all services..."
|
||||
exec /usr/local/bin/supervisord -c /etc/supervisor/conf.d/surfsense.conf
|
||||
|
||||
|
|
@ -1,77 +0,0 @@
|
|||
#!/bin/bash
|
||||
# PostgreSQL initialization script for SurfSense
|
||||
# This script is called during container startup if the database needs initialization
|
||||
|
||||
set -e
|
||||
|
||||
PGDATA=${PGDATA:-/data/postgres}
|
||||
POSTGRES_USER=${POSTGRES_USER:-surfsense}
|
||||
POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-surfsense}
|
||||
POSTGRES_DB=${POSTGRES_DB:-surfsense}
|
||||
|
||||
# Electric SQL user credentials (configurable)
|
||||
ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
|
||||
ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
|
||||
echo "Initializing PostgreSQL..."
|
||||
|
||||
# Check if PostgreSQL is already initialized
|
||||
if [ -f "$PGDATA/PG_VERSION" ]; then
|
||||
echo "PostgreSQL data directory already exists. Skipping initialization."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Initialize the database cluster
|
||||
/usr/lib/postgresql/14/bin/initdb -D "$PGDATA" --username=postgres
|
||||
|
||||
# Configure PostgreSQL
|
||||
cat >> "$PGDATA/postgresql.conf" << EOF
|
||||
listen_addresses = '*'
|
||||
max_connections = 200
|
||||
shared_buffers = 256MB
|
||||
|
||||
# Enable logical replication (required for Electric SQL)
|
||||
wal_level = logical
|
||||
max_replication_slots = 10
|
||||
max_wal_senders = 10
|
||||
|
||||
# Performance settings
|
||||
checkpoint_timeout = 10min
|
||||
max_wal_size = 1GB
|
||||
min_wal_size = 80MB
|
||||
EOF
|
||||
|
||||
cat >> "$PGDATA/pg_hba.conf" << EOF
|
||||
# Allow connections from anywhere with password
|
||||
host all all 0.0.0.0/0 md5
|
||||
host all all ::0/0 md5
|
||||
EOF
|
||||
|
||||
# Start PostgreSQL temporarily
|
||||
/usr/lib/postgresql/14/bin/pg_ctl -D "$PGDATA" -l /tmp/postgres_init.log start
|
||||
|
||||
# Wait for PostgreSQL to start
|
||||
sleep 3
|
||||
|
||||
# Create user and database
|
||||
psql -U postgres << EOF
|
||||
CREATE USER $POSTGRES_USER WITH PASSWORD '$POSTGRES_PASSWORD' SUPERUSER;
|
||||
CREATE DATABASE $POSTGRES_DB OWNER $POSTGRES_USER;
|
||||
\c $POSTGRES_DB
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
|
||||
-- Create Electric SQL replication user
|
||||
CREATE USER $ELECTRIC_DB_USER WITH REPLICATION PASSWORD '$ELECTRIC_DB_PASSWORD';
|
||||
GRANT CONNECT ON DATABASE $POSTGRES_DB TO $ELECTRIC_DB_USER;
|
||||
GRANT USAGE ON SCHEMA public TO $ELECTRIC_DB_USER;
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO $ELECTRIC_DB_USER;
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO $ELECTRIC_DB_USER;
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO $ELECTRIC_DB_USER;
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO $ELECTRIC_DB_USER;
|
||||
EOF
|
||||
|
||||
echo "PostgreSQL initialized successfully."
|
||||
|
||||
# Stop PostgreSQL (supervisor will start it)
|
||||
/usr/lib/postgresql/14/bin/pg_ctl -D "$PGDATA" stop
|
||||
|
||||
|
|
@ -1,121 +0,0 @@
|
|||
[supervisord]
|
||||
nodaemon=true
|
||||
logfile=/dev/stdout
|
||||
logfile_maxbytes=0
|
||||
pidfile=/var/run/supervisord.pid
|
||||
loglevel=info
|
||||
user=root
|
||||
|
||||
[unix_http_server]
|
||||
file=/var/run/supervisor.sock
|
||||
chmod=0700
|
||||
|
||||
[rpcinterface:supervisor]
|
||||
supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface
|
||||
|
||||
[supervisorctl]
|
||||
serverurl=unix:///var/run/supervisor.sock
|
||||
|
||||
# PostgreSQL
|
||||
[program:postgresql]
|
||||
command=/usr/lib/postgresql/14/bin/postgres -D /data/postgres
|
||||
user=postgres
|
||||
autostart=true
|
||||
autorestart=true
|
||||
priority=10
|
||||
stdout_logfile=/dev/stdout
|
||||
stdout_logfile_maxbytes=0
|
||||
stderr_logfile=/dev/stderr
|
||||
stderr_logfile_maxbytes=0
|
||||
environment=PGDATA="/data/postgres"
|
||||
|
||||
# Redis
|
||||
[program:redis]
|
||||
command=/usr/bin/redis-server --dir /data/redis --appendonly yes
|
||||
autostart=true
|
||||
autorestart=true
|
||||
priority=20
|
||||
stdout_logfile=/dev/stdout
|
||||
stdout_logfile_maxbytes=0
|
||||
stderr_logfile=/dev/stderr
|
||||
stderr_logfile_maxbytes=0
|
||||
|
||||
# Backend API
|
||||
[program:backend]
|
||||
command=python main.py
|
||||
directory=/app/backend
|
||||
autostart=true
|
||||
autorestart=true
|
||||
priority=30
|
||||
startsecs=10
|
||||
startretries=3
|
||||
stdout_logfile=/dev/stdout
|
||||
stdout_logfile_maxbytes=0
|
||||
stderr_logfile=/dev/stderr
|
||||
stderr_logfile_maxbytes=0
|
||||
environment=PYTHONPATH="/app/backend",UVICORN_LOOP="asyncio",UNSTRUCTURED_HAS_PATCHED_LOOP="1"
|
||||
|
||||
# Celery Worker
|
||||
[program:celery-worker]
|
||||
command=celery -A app.celery_app worker --loglevel=info --concurrency=2 --pool=solo --queues=surfsense,surfsense.connectors
|
||||
directory=/app/backend
|
||||
autostart=true
|
||||
autorestart=true
|
||||
priority=40
|
||||
startsecs=15
|
||||
startretries=3
|
||||
stdout_logfile=/dev/stdout
|
||||
stdout_logfile_maxbytes=0
|
||||
stderr_logfile=/dev/stderr
|
||||
stderr_logfile_maxbytes=0
|
||||
environment=PYTHONPATH="/app/backend"
|
||||
|
||||
# Celery Beat (scheduler)
|
||||
[program:celery-beat]
|
||||
command=celery -A app.celery_app beat --loglevel=info
|
||||
directory=/app/backend
|
||||
autostart=true
|
||||
autorestart=true
|
||||
priority=50
|
||||
startsecs=20
|
||||
startretries=3
|
||||
stdout_logfile=/dev/stdout
|
||||
stdout_logfile_maxbytes=0
|
||||
stderr_logfile=/dev/stderr
|
||||
stderr_logfile_maxbytes=0
|
||||
environment=PYTHONPATH="/app/backend"
|
||||
|
||||
# Electric SQL (real-time sync)
|
||||
[program:electric]
|
||||
command=/app/electric-release/bin/entrypoint start
|
||||
autostart=true
|
||||
autorestart=true
|
||||
priority=25
|
||||
startsecs=10
|
||||
startretries=3
|
||||
stdout_logfile=/dev/stdout
|
||||
stdout_logfile_maxbytes=0
|
||||
stderr_logfile=/dev/stderr
|
||||
stderr_logfile_maxbytes=0
|
||||
environment=DATABASE_URL="%(ENV_ELECTRIC_DATABASE_URL)s",ELECTRIC_INSECURE="%(ENV_ELECTRIC_INSECURE)s",ELECTRIC_WRITE_TO_PG_MODE="%(ENV_ELECTRIC_WRITE_TO_PG_MODE)s",RELEASE_COOKIE="surfsense_electric_cookie",PORT="%(ENV_ELECTRIC_PORT)s"
|
||||
|
||||
# Frontend
|
||||
[program:frontend]
|
||||
command=node server.js
|
||||
directory=/app/frontend
|
||||
autostart=true
|
||||
autorestart=true
|
||||
priority=60
|
||||
startsecs=5
|
||||
startretries=3
|
||||
stdout_logfile=/dev/stdout
|
||||
stdout_logfile_maxbytes=0
|
||||
stderr_logfile=/dev/stderr
|
||||
stderr_logfile_maxbytes=0
|
||||
environment=NODE_ENV="production",PORT="3000",HOSTNAME="0.0.0.0"
|
||||
|
||||
# Process Groups
|
||||
[group:surfsense]
|
||||
programs=postgresql,redis,electric,backend,celery-worker,celery-beat,frontend
|
||||
priority=999
|
||||
|
||||
|
|
@ -167,37 +167,12 @@ LANGSMITH_ENDPOINT=https://api.smith.langchain.com
|
|||
LANGSMITH_API_KEY=lsv2_pt_.....
|
||||
LANGSMITH_PROJECT=surfsense
|
||||
|
||||
# Uvicorn Server Configuration
|
||||
# Full documentation for Uvicorn options can be found at: https://www.uvicorn.org/#command-line-options
|
||||
UVICORN_HOST="0.0.0.0"
|
||||
UVICORN_PORT=8000
|
||||
UVICORN_LOG_LEVEL=info
|
||||
|
||||
# OPTIONAL: Advanced Uvicorn Options (uncomment to use)
|
||||
# UVICORN_PROXY_HEADERS=false
|
||||
# UVICORN_FORWARDED_ALLOW_IPS="127.0.0.1"
|
||||
# UVICORN_WORKERS=1
|
||||
# UVICORN_ACCESS_LOG=true
|
||||
# UVICORN_LOOP="auto"
|
||||
# UVICORN_HTTP="auto"
|
||||
# UVICORN_WS="auto"
|
||||
# UVICORN_LIFESPAN="auto"
|
||||
# UVICORN_LOG_CONFIG=""
|
||||
# UVICORN_SERVER_HEADER=true
|
||||
# UVICORN_DATE_HEADER=true
|
||||
# UVICORN_LIMIT_CONCURRENCY=
|
||||
# UVICORN_LIMIT_MAX_REQUESTS=
|
||||
# UVICORN_TIMEOUT_KEEP_ALIVE=5
|
||||
# UVICORN_TIMEOUT_NOTIFY=30
|
||||
# UVICORN_SSL_KEYFILE=""
|
||||
# UVICORN_SSL_CERTFILE=""
|
||||
# UVICORN_SSL_KEYFILE_PASSWORD=""
|
||||
# UVICORN_SSL_VERSION=""
|
||||
# UVICORN_SSL_CERT_REQS=""
|
||||
# UVICORN_SSL_CA_CERTS=""
|
||||
# UVICORN_SSL_CIPHERS=""
|
||||
# UVICORN_HEADERS=""
|
||||
# UVICORN_USE_COLORS=true
|
||||
# UVICORN_UDS=""
|
||||
# UVICORN_FD=""
|
||||
# UVICORN_ROOT_PATH=""
|
||||
# Agent Specific Configuration
|
||||
# Daytona Sandbox (secure cloud code execution for deep agent)
|
||||
# Set DAYTONA_SANDBOX_ENABLED=TRUE to give the agent an isolated execute tool
|
||||
DAYTONA_SANDBOX_ENABLED=TRUE
|
||||
DAYTONA_API_KEY=dtn_asdasfasfafas
|
||||
DAYTONA_API_URL=https://app.daytona.io/api
|
||||
DAYTONA_TARGET=us
|
||||
# Directory for locally-persisted sandbox files (after sandbox deletion)
|
||||
SANDBOX_FILES_DIR=sandbox_files
|
||||
|
|
|
|||
1
surfsense_backend/.gitignore
vendored
1
surfsense_backend/.gitignore
vendored
|
|
@ -6,6 +6,7 @@ __pycache__/
|
|||
.flashrank_cache
|
||||
surf_new_backend.egg-info/
|
||||
podcasts/
|
||||
sandbox_files/
|
||||
temp_audio/
|
||||
celerybeat-schedule*
|
||||
celerybeat-schedule.*
|
||||
|
|
|
|||
|
|
@ -88,6 +88,13 @@ ENV TMPDIR=/shared_tmp
|
|||
ENV PYTHONPATH=/app
|
||||
ENV UVICORN_LOOP=asyncio
|
||||
|
||||
# Tune glibc malloc to return freed memory to the OS more aggressively.
|
||||
# Without these, Python's gc.collect() frees objects but the underlying
|
||||
# C heap pages stay mapped (RSS never drops) due to sbrk fragmentation.
|
||||
ENV MALLOC_MMAP_THRESHOLD_=65536
|
||||
ENV MALLOC_TRIM_THRESHOLD_=131072
|
||||
ENV MALLOC_MMAP_MAX_=65536
|
||||
|
||||
# SERVICE_ROLE controls which process this container runs:
|
||||
# api – FastAPI backend only (runs migrations on startup)
|
||||
# worker – Celery worker only
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
"""102_add_enable_summary_to_connectors
|
||||
|
||||
Revision ID: 102
|
||||
Revises: 101
|
||||
Create Date: 2026-02-26
|
||||
|
||||
Adds enable_summary boolean column to search_source_connectors.
|
||||
Defaults to False for all existing and new connectors so LLM-based
|
||||
summary generation is opt-in.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "102"
|
||||
down_revision: str | None = "101"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
existing_columns = [
|
||||
col["name"] for col in sa.inspect(conn).get_columns("search_source_connectors")
|
||||
]
|
||||
|
||||
if "enable_summary" not in existing_columns:
|
||||
op.add_column(
|
||||
"search_source_connectors",
|
||||
sa.Column(
|
||||
"enable_summary",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("search_source_connectors", "enable_summary")
|
||||
|
|
@ -8,7 +8,7 @@ Creates notifications table and sets up Electric SQL replication
|
|||
search_source_connectors, and documents tables.
|
||||
|
||||
NOTE: Electric SQL user creation is idempotent (uses IF NOT EXISTS).
|
||||
- Docker deployments: user is pre-created by scripts/docker/init-electric-user.sh
|
||||
- Docker deployments: user is pre-created by docker/scripts/init-electric-user.sh
|
||||
- Local PostgreSQL: user is created here during migration
|
||||
Both approaches are safe to run together without conflicts as this migraiton is idempotent
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -6,10 +6,14 @@ with configurable tools via the tools registry and configurable prompts
|
|||
via NewLLMConfig.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from deepagents import create_deep_agent
|
||||
from deepagents.backends.protocol import SandboxBackendProtocol
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
|
@ -24,6 +28,9 @@ from app.agents.new_chat.system_prompt import (
|
|||
from app.agents.new_chat.tools.registry import build_tools_async
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
# =============================================================================
|
||||
# Connector Type Mapping
|
||||
|
|
@ -128,6 +135,7 @@ async def create_surfsense_deep_agent(
|
|||
additional_tools: Sequence[BaseTool] | None = None,
|
||||
firecrawl_api_key: str | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
sandbox_backend: SandboxBackendProtocol | None = None,
|
||||
):
|
||||
"""
|
||||
Create a SurfSense deep agent with configurable tools and prompts.
|
||||
|
|
@ -167,6 +175,9 @@ async def create_surfsense_deep_agent(
|
|||
These are always added regardless of enabled/disabled settings.
|
||||
firecrawl_api_key: Optional Firecrawl API key for premium web scraping.
|
||||
Falls back to Chromium/Trafilatura if not provided.
|
||||
sandbox_backend: Optional sandbox backend (e.g. DaytonaSandbox) for
|
||||
secure code execution. When provided, the agent gets an
|
||||
isolated ``execute`` tool for running shell commands.
|
||||
|
||||
Returns:
|
||||
CompiledStateGraph: The configured deep agent
|
||||
|
|
@ -205,32 +216,41 @@ async def create_surfsense_deep_agent(
|
|||
additional_tools=[my_custom_tool]
|
||||
)
|
||||
"""
|
||||
_t_agent_total = time.perf_counter()
|
||||
|
||||
# Discover available connectors and document types for this search space
|
||||
# This enables dynamic tool docstrings that inform the LLM about what's actually available
|
||||
available_connectors: list[str] | None = None
|
||||
available_document_types: list[str] | None = None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
try:
|
||||
# Get enabled search source connectors for this search space
|
||||
connector_types = await connector_service.get_available_connectors(
|
||||
search_space_id
|
||||
)
|
||||
if connector_types:
|
||||
# Convert enum values to strings and also include mapped document types
|
||||
available_connectors = _map_connectors_to_searchable_types(connector_types)
|
||||
|
||||
# Get document types that have at least one document indexed
|
||||
available_document_types = await connector_service.get_available_document_types(
|
||||
search_space_id
|
||||
)
|
||||
except Exception as e:
|
||||
# Log but don't fail - fall back to all connectors if discovery fails
|
||||
import logging
|
||||
|
||||
logging.warning(f"Failed to discover available connectors/document types: {e}")
|
||||
_perf_log.info(
|
||||
"[create_agent] Connector/doc-type discovery in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
# Build dependencies dict for the tools registry
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
|
||||
# Extract the model's context window so tools can size their output.
|
||||
_model_profile = getattr(llm, "profile", None)
|
||||
_max_input_tokens: int | None = (
|
||||
_model_profile.get("max_input_tokens")
|
||||
if isinstance(_model_profile, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
dependencies = {
|
||||
"search_space_id": search_space_id,
|
||||
"db_session": db_session,
|
||||
|
|
@ -241,6 +261,7 @@ async def create_surfsense_deep_agent(
|
|||
"thread_visibility": visibility,
|
||||
"available_connectors": available_connectors,
|
||||
"available_document_types": available_document_types,
|
||||
"max_input_tokens": _max_input_tokens,
|
||||
}
|
||||
|
||||
# Disable Notion action tools if no Notion connector is configured
|
||||
|
|
@ -269,35 +290,61 @@ async def create_surfsense_deep_agent(
|
|||
modified_disabled_tools.extend(linear_tools)
|
||||
|
||||
# Build tools using the async registry (includes MCP tools)
|
||||
_t0 = time.perf_counter()
|
||||
tools = await build_tools_async(
|
||||
dependencies=dependencies,
|
||||
enabled_tools=enabled_tools,
|
||||
disabled_tools=modified_disabled_tools,
|
||||
additional_tools=list(additional_tools) if additional_tools else None,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] build_tools_async in %.3fs (%d tools)",
|
||||
time.perf_counter() - _t0,
|
||||
len(tools),
|
||||
)
|
||||
|
||||
# Build system prompt based on agent_config
|
||||
_t0 = time.perf_counter()
|
||||
_sandbox_enabled = sandbox_backend is not None
|
||||
if agent_config is not None:
|
||||
# Use configurable prompt with settings from NewLLMConfig
|
||||
system_prompt = build_configurable_system_prompt(
|
||||
custom_system_instructions=agent_config.system_instructions,
|
||||
use_default_system_instructions=agent_config.use_default_system_instructions,
|
||||
citations_enabled=agent_config.citations_enabled,
|
||||
thread_visibility=thread_visibility,
|
||||
sandbox_enabled=_sandbox_enabled,
|
||||
)
|
||||
else:
|
||||
system_prompt = build_surfsense_system_prompt(
|
||||
thread_visibility=thread_visibility,
|
||||
sandbox_enabled=_sandbox_enabled,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
# Create the deep agent with system prompt and checkpointer
|
||||
# Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent
|
||||
agent = create_deep_agent(
|
||||
# Build optional kwargs for the deep agent
|
||||
deep_agent_kwargs: dict[str, Any] = {}
|
||||
if sandbox_backend is not None:
|
||||
deep_agent_kwargs["backend"] = sandbox_backend
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await asyncio.to_thread(
|
||||
create_deep_agent,
|
||||
model=llm,
|
||||
tools=tools,
|
||||
system_prompt=system_prompt,
|
||||
context_schema=SurfSenseContextSchema,
|
||||
checkpointer=checkpointer,
|
||||
**deep_agent_kwargs,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] Graph compiled (create_deep_agent) in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
_perf_log.info(
|
||||
"[create_agent] Total agent creation in %.3fs",
|
||||
time.perf_counter() - _t_agent_total,
|
||||
)
|
||||
return agent
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from app.services.llm_router_service import (
|
|||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
|
|
@ -389,7 +390,7 @@ def create_chat_litellm_from_agent_config(
|
|||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
try:
|
||||
return ChatLiteLLMRouter()
|
||||
return get_auto_mode_llm()
|
||||
except Exception as e:
|
||||
print(f"Error creating ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
|
|
|
|||
282
surfsense_backend/app/agents/new_chat/sandbox.py
Normal file
282
surfsense_backend/app/agents/new_chat/sandbox.py
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
"""
|
||||
Daytona sandbox provider for SurfSense deep agent.
|
||||
|
||||
Manages the lifecycle of sandboxed code execution environments.
|
||||
Each conversation thread gets its own isolated sandbox instance
|
||||
via the Daytona cloud API, identified by labels.
|
||||
|
||||
Files created during a session are persisted to local storage before
|
||||
the sandbox is deleted so they remain downloadable after cleanup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from daytona import (
|
||||
CreateSandboxFromSnapshotParams,
|
||||
Daytona,
|
||||
DaytonaConfig,
|
||||
SandboxState,
|
||||
)
|
||||
from daytona.common.errors import DaytonaError
|
||||
from deepagents.backends.protocol import ExecuteResponse
|
||||
from langchain_daytona import DaytonaSandbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _TimeoutAwareSandbox(DaytonaSandbox):
|
||||
"""DaytonaSandbox subclass that accepts the per-command *timeout*
|
||||
kwarg required by the deepagents middleware.
|
||||
|
||||
The upstream ``langchain-daytona`` ``execute()`` ignores timeout,
|
||||
so deepagents raises *"This sandbox backend does not support
|
||||
per-command timeout overrides"* on every first call. This thin
|
||||
wrapper forwards the parameter to the Daytona SDK.
|
||||
"""
|
||||
|
||||
def execute(self, command: str, *, timeout: int | None = None) -> ExecuteResponse:
|
||||
t = timeout if timeout is not None else self._timeout
|
||||
result = self._sandbox.process.exec(command, timeout=t)
|
||||
return ExecuteResponse(
|
||||
output=result.result,
|
||||
exit_code=result.exit_code,
|
||||
truncated=False,
|
||||
)
|
||||
|
||||
async def aexecute(
|
||||
self, command: str, *, timeout: int | None = None
|
||||
) -> ExecuteResponse: # type: ignore[override]
|
||||
return await asyncio.to_thread(self.execute, command, timeout=timeout)
|
||||
|
||||
|
||||
_daytona_client: Daytona | None = None
|
||||
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
|
||||
_SANDBOX_CACHE_MAX_SIZE = 20
|
||||
THREAD_LABEL_KEY = "surfsense_thread"
|
||||
|
||||
|
||||
def is_sandbox_enabled() -> bool:
|
||||
return os.environ.get("DAYTONA_SANDBOX_ENABLED", "FALSE").upper() == "TRUE"
|
||||
|
||||
|
||||
def _get_client() -> Daytona:
|
||||
global _daytona_client
|
||||
if _daytona_client is None:
|
||||
config = DaytonaConfig(
|
||||
api_key=os.environ.get("DAYTONA_API_KEY", ""),
|
||||
api_url=os.environ.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
|
||||
target=os.environ.get("DAYTONA_TARGET", "us"),
|
||||
)
|
||||
_daytona_client = Daytona(config)
|
||||
return _daytona_client
|
||||
|
||||
|
||||
def _find_or_create(thread_id: str) -> _TimeoutAwareSandbox:
|
||||
"""Find an existing sandbox for *thread_id*, or create a new one.
|
||||
|
||||
If an existing sandbox is found but is stopped/archived, it will be
|
||||
restarted automatically before returning.
|
||||
"""
|
||||
client = _get_client()
|
||||
labels = {THREAD_LABEL_KEY: thread_id}
|
||||
|
||||
try:
|
||||
sandbox = client.find_one(labels=labels)
|
||||
logger.info("Found existing sandbox %s (state=%s)", sandbox.id, sandbox.state)
|
||||
|
||||
if sandbox.state in (
|
||||
SandboxState.STOPPED,
|
||||
SandboxState.STOPPING,
|
||||
SandboxState.ARCHIVED,
|
||||
):
|
||||
logger.info("Starting stopped sandbox %s …", sandbox.id)
|
||||
sandbox.start(timeout=60)
|
||||
logger.info("Sandbox %s is now started", sandbox.id)
|
||||
elif sandbox.state in (
|
||||
SandboxState.ERROR,
|
||||
SandboxState.BUILD_FAILED,
|
||||
SandboxState.DESTROYED,
|
||||
):
|
||||
logger.warning(
|
||||
"Sandbox %s in unrecoverable state %s — creating a new one",
|
||||
sandbox.id,
|
||||
sandbox.state,
|
||||
)
|
||||
sandbox = client.create(
|
||||
CreateSandboxFromSnapshotParams(language="python", labels=labels)
|
||||
)
|
||||
logger.info("Created replacement sandbox: %s", sandbox.id)
|
||||
elif sandbox.state != SandboxState.STARTED:
|
||||
sandbox.wait_for_sandbox_start(timeout=60)
|
||||
|
||||
except Exception:
|
||||
logger.info("No existing sandbox for thread %s — creating one", thread_id)
|
||||
sandbox = client.create(
|
||||
CreateSandboxFromSnapshotParams(language="python", labels=labels)
|
||||
)
|
||||
logger.info("Created new sandbox: %s", sandbox.id)
|
||||
|
||||
return _TimeoutAwareSandbox(sandbox=sandbox)
|
||||
|
||||
|
||||
async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
|
||||
"""Get or create a sandbox for a conversation thread.
|
||||
|
||||
Uses an in-process cache keyed by thread_id so subsequent messages
|
||||
in the same conversation reuse the sandbox object without an API call.
|
||||
|
||||
Args:
|
||||
thread_id: The conversation thread identifier.
|
||||
|
||||
Returns:
|
||||
DaytonaSandbox connected to the sandbox.
|
||||
"""
|
||||
key = str(thread_id)
|
||||
cached = _sandbox_cache.get(key)
|
||||
if cached is not None:
|
||||
logger.info("Reusing cached sandbox for thread %s", key)
|
||||
return cached
|
||||
sandbox = await asyncio.to_thread(_find_or_create, key)
|
||||
_sandbox_cache[key] = sandbox
|
||||
|
||||
if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
|
||||
oldest_key = next(iter(_sandbox_cache))
|
||||
_sandbox_cache.pop(oldest_key, None)
|
||||
logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key)
|
||||
|
||||
return sandbox
|
||||
|
||||
|
||||
async def delete_sandbox(thread_id: int | str) -> None:
|
||||
"""Delete the sandbox for a conversation thread."""
|
||||
_sandbox_cache.pop(str(thread_id), None)
|
||||
|
||||
def _delete() -> None:
|
||||
client = _get_client()
|
||||
labels = {THREAD_LABEL_KEY: str(thread_id)}
|
||||
try:
|
||||
sandbox = client.find_one(labels=labels)
|
||||
except DaytonaError:
|
||||
logger.debug(
|
||||
"No sandbox to delete for thread %s (already removed)", thread_id
|
||||
)
|
||||
return
|
||||
try:
|
||||
client.delete(sandbox)
|
||||
logger.info("Sandbox deleted: %s", sandbox.id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to delete sandbox for thread %s",
|
||||
thread_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
await asyncio.to_thread(_delete)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local file persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_sandbox_files_dir() -> Path:
|
||||
return Path(os.environ.get("SANDBOX_FILES_DIR", "sandbox_files"))
|
||||
|
||||
|
||||
def _local_path_for(thread_id: int | str, sandbox_path: str) -> Path:
|
||||
"""Map a sandbox-internal absolute path to a local filesystem path."""
|
||||
relative = sandbox_path.lstrip("/")
|
||||
return _get_sandbox_files_dir() / str(thread_id) / relative
|
||||
|
||||
|
||||
def get_local_sandbox_file(thread_id: int | str, sandbox_path: str) -> bytes | None:
|
||||
"""Read a previously-persisted sandbox file from local storage.
|
||||
|
||||
Returns the file bytes, or *None* if the file does not exist locally.
|
||||
"""
|
||||
local = _local_path_for(thread_id, sandbox_path)
|
||||
if local.is_file():
|
||||
return local.read_bytes()
|
||||
return None
|
||||
|
||||
|
||||
def delete_local_sandbox_files(thread_id: int | str) -> None:
|
||||
"""Remove all locally-persisted sandbox files for a thread."""
|
||||
thread_dir = _get_sandbox_files_dir() / str(thread_id)
|
||||
if thread_dir.is_dir():
|
||||
shutil.rmtree(thread_dir, ignore_errors=True)
|
||||
logger.info("Deleted local sandbox files for thread %s", thread_id)
|
||||
|
||||
|
||||
async def persist_and_delete_sandbox(
|
||||
thread_id: int | str,
|
||||
sandbox_file_paths: list[str],
|
||||
) -> None:
|
||||
"""Download sandbox files to local storage, then delete the sandbox.
|
||||
|
||||
Each file in *sandbox_file_paths* is downloaded from the Daytona
|
||||
sandbox and saved under ``{SANDBOX_FILES_DIR}/{thread_id}/…``.
|
||||
Per-file errors are logged but do **not** prevent the sandbox from
|
||||
being deleted — freeing Daytona storage is the priority.
|
||||
"""
|
||||
_sandbox_cache.pop(str(thread_id), None)
|
||||
|
||||
def _persist_and_delete() -> None:
|
||||
client = _get_client()
|
||||
labels = {THREAD_LABEL_KEY: str(thread_id)}
|
||||
|
||||
try:
|
||||
sandbox = client.find_one(labels=labels)
|
||||
except Exception:
|
||||
logger.info(
|
||||
"No sandbox found for thread %s — nothing to persist", thread_id
|
||||
)
|
||||
return
|
||||
|
||||
# Ensure the sandbox is running so we can download files
|
||||
if sandbox.state != SandboxState.STARTED:
|
||||
try:
|
||||
sandbox.start(timeout=60)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Could not start sandbox %s for file download — deleting anyway",
|
||||
sandbox.id,
|
||||
exc_info=True,
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
client.delete(sandbox)
|
||||
return
|
||||
|
||||
for path in sandbox_file_paths:
|
||||
try:
|
||||
content: bytes = sandbox.fs.download_file(path)
|
||||
local = _local_path_for(thread_id, path)
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
local.write_bytes(content)
|
||||
logger.info("Persisted sandbox file %s → %s", path, local)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist sandbox file %s for thread %s",
|
||||
path,
|
||||
thread_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
client.delete(sandbox)
|
||||
logger.info("Sandbox deleted after file persistence: %s", sandbox.id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to delete sandbox %s after persistence",
|
||||
sandbox.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
await asyncio.to_thread(_persist_and_delete)
|
||||
|
|
@ -645,6 +645,87 @@ However, from your video learning, it's important to note that asyncio is not su
|
|||
</citation_instructions>
|
||||
"""
|
||||
|
||||
# Sandbox / code execution instructions — appended when sandbox backend is enabled.
|
||||
# Inspired by Claude's computer-use prompt, scoped to code execution & data analytics.
|
||||
SANDBOX_EXECUTION_INSTRUCTIONS = """
|
||||
<code_execution>
|
||||
You have access to a secure, isolated Linux sandbox environment for running code and shell commands.
|
||||
This gives you the `execute` tool alongside the standard filesystem tools (`ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`).
|
||||
|
||||
## CRITICAL — CODE-FIRST RULE
|
||||
|
||||
ALWAYS prefer executing code over giving a text-only response when the user's request involves ANY of the following:
|
||||
- **Creating a chart, plot, graph, or visualization** → Write Python code and generate the actual file. NEVER describe percentages or data in text and offer to "paste into Excel". Just produce the chart.
|
||||
- **Data analysis, statistics, or computation** → Write code to compute the answer. Do not do math by hand in text.
|
||||
- **Generating or transforming files** (CSV, PDF, images, etc.) → Write code to create the file.
|
||||
- **Running, testing, or debugging code** → Execute it in the sandbox.
|
||||
|
||||
This applies even when you first retrieve data from the knowledge base. After `search_knowledge_base` returns relevant data, **immediately proceed to write and execute code** if the user's request matches any of the categories above. Do NOT stop at a text summary and wait for the user to ask you to "use Python" — that extra round-trip is a poor experience.
|
||||
|
||||
Example (CORRECT):
|
||||
User: "Create a pie chart of my benefits"
|
||||
→ 1. search_knowledge_base → retrieve benefits data
|
||||
→ 2. Immediately execute Python code (matplotlib) to generate the pie chart
|
||||
→ 3. Return the downloadable file + brief description
|
||||
|
||||
Example (WRONG):
|
||||
User: "Create a pie chart of my benefits"
|
||||
→ 1. search_knowledge_base → retrieve benefits data
|
||||
→ 2. Print a text table with percentages and ask the user if they want a chart ← NEVER do this
|
||||
|
||||
## When to Use Code Execution
|
||||
|
||||
Use the sandbox when the task benefits from actually running code rather than just describing it:
|
||||
- **Data analysis**: Load CSVs/JSON, compute statistics, filter/aggregate data, pivot tables
|
||||
- **Visualization**: Generate charts and plots (matplotlib, plotly, seaborn)
|
||||
- **Calculations**: Math, financial modeling, unit conversions, simulations
|
||||
- **Code validation**: Run and test code snippets the user provides or asks about
|
||||
- **File processing**: Parse, transform, or convert data files
|
||||
- **Quick prototyping**: Demonstrate working code for the user's problem
|
||||
- **Package exploration**: Install and test libraries the user is evaluating
|
||||
|
||||
## When NOT to Use Code Execution
|
||||
|
||||
Do not use the sandbox for:
|
||||
- Answering factual questions from your own knowledge
|
||||
- Summarizing or explaining concepts
|
||||
- Simple formatting or text generation tasks
|
||||
- Tasks that don't require running code to answer
|
||||
|
||||
## Package Management
|
||||
|
||||
- Use `pip install <package>` to install Python packages as needed
|
||||
- Common data/analytics packages (pandas, numpy, matplotlib, scipy, scikit-learn) may need to be installed on first use
|
||||
- Always verify a package installed successfully before using it
|
||||
|
||||
## Working Guidelines
|
||||
|
||||
- **Working directory**: The shell starts in the sandbox user's home directory (e.g. `/home/daytona`). Use **relative paths** or `/tmp/` for all files you create. NEVER write directly to `/home/` — that is the parent directory and is not writable. Use `pwd` if you need to discover the current working directory.
|
||||
- **Iterative approach**: For complex tasks, break work into steps — write code, run it, check output, refine
|
||||
- **Error handling**: If code fails, read the error, fix the issue, and retry. Don't just report the error without attempting a fix.
|
||||
- **Show results**: When generating plots or outputs, present the key findings directly in your response. For plots, save to a file and describe the results.
|
||||
- **Be efficient**: Install packages once per session. Combine related commands when possible.
|
||||
- **Large outputs**: If command output is very large, use `head`, `tail`, or save to a file and read selectively.
|
||||
|
||||
## Sharing Generated Files
|
||||
|
||||
When your code creates output files (images, CSVs, PDFs, etc.) in the sandbox:
|
||||
- **Print the absolute path** at the end of your script so the user can download the file. Example: `print("SANDBOX_FILE: /tmp/chart.png")`
|
||||
- **DO NOT call `display_image`** for files created inside the sandbox. Sandbox files are not accessible via public URLs, so `display_image` will always show "Image not available". The frontend automatically renders a download button from the `SANDBOX_FILE:` marker.
|
||||
- You can output multiple files, one per line: `print("SANDBOX_FILE: /tmp/report.csv")`, `print("SANDBOX_FILE: /tmp/chart.png")`
|
||||
- Always describe what the file contains in your response text so the user knows what they are downloading.
|
||||
- IMPORTANT: Every `execute` call that saves a file MUST print the `SANDBOX_FILE: <path>` marker. Without it the user cannot download the file.
|
||||
|
||||
## Data Analytics Best Practices
|
||||
|
||||
When the user asks you to analyze data:
|
||||
1. First, inspect the data structure (`head`, `shape`, `dtypes`, `describe()`)
|
||||
2. Clean and validate before computing (handle nulls, check types)
|
||||
3. Perform the analysis and present results clearly
|
||||
4. Offer follow-up insights or visualizations when appropriate
|
||||
</code_execution>
|
||||
"""
|
||||
|
||||
# Anti-citation prompt - used when citations are disabled
|
||||
# This explicitly tells the model NOT to include citations
|
||||
SURFSENSE_NO_CITATION_INSTRUCTIONS = """
|
||||
|
|
@ -670,6 +751,7 @@ Your goal is to provide helpful, informative answers in a clean, readable format
|
|||
def build_surfsense_system_prompt(
|
||||
today: datetime | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
sandbox_enabled: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Build the SurfSense system prompt with default settings.
|
||||
|
|
@ -678,10 +760,12 @@ def build_surfsense_system_prompt(
|
|||
- Default system instructions
|
||||
- Tools instructions (always included)
|
||||
- Citation instructions enabled
|
||||
- Sandbox execution instructions (when sandbox_enabled=True)
|
||||
|
||||
Args:
|
||||
today: Optional datetime for today's date (defaults to current UTC date)
|
||||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||
sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).
|
||||
|
||||
Returns:
|
||||
Complete system prompt string
|
||||
|
|
@ -691,7 +775,13 @@ def build_surfsense_system_prompt(
|
|||
system_instructions = _get_system_instructions(visibility, today)
|
||||
tools_instructions = _get_tools_instructions(visibility)
|
||||
citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS
|
||||
return system_instructions + tools_instructions + citation_instructions
|
||||
sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
|
||||
return (
|
||||
system_instructions
|
||||
+ tools_instructions
|
||||
+ citation_instructions
|
||||
+ sandbox_instructions
|
||||
)
|
||||
|
||||
|
||||
def build_configurable_system_prompt(
|
||||
|
|
@ -700,14 +790,16 @@ def build_configurable_system_prompt(
|
|||
citations_enabled: bool = True,
|
||||
today: datetime | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
sandbox_enabled: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
||||
|
||||
The prompt is composed of three parts:
|
||||
The prompt is composed of up to four parts:
|
||||
1. System Instructions - either custom or default SURFSENSE_SYSTEM_INSTRUCTIONS
|
||||
2. Tools Instructions - always included (SURFSENSE_TOOLS_INSTRUCTIONS)
|
||||
3. Citation Instructions - either SURFSENSE_CITATION_INSTRUCTIONS or SURFSENSE_NO_CITATION_INSTRUCTIONS
|
||||
4. Sandbox Execution Instructions - when sandbox_enabled=True
|
||||
|
||||
Args:
|
||||
custom_system_instructions: Custom system instructions to use. If empty/None and
|
||||
|
|
@ -719,6 +811,7 @@ def build_configurable_system_prompt(
|
|||
anti-citation instructions (False).
|
||||
today: Optional datetime for today's date (defaults to current UTC date)
|
||||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||
sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).
|
||||
|
||||
Returns:
|
||||
Complete system prompt string
|
||||
|
|
@ -727,7 +820,6 @@ def build_configurable_system_prompt(
|
|||
|
||||
# Determine system instructions
|
||||
if custom_system_instructions and custom_system_instructions.strip():
|
||||
# Use custom instructions, injecting the date placeholder if present
|
||||
system_instructions = custom_system_instructions.format(
|
||||
resolved_today=resolved_today
|
||||
)
|
||||
|
|
@ -735,7 +827,6 @@ def build_configurable_system_prompt(
|
|||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
system_instructions = _get_system_instructions(visibility, today)
|
||||
else:
|
||||
# No system instructions (edge case)
|
||||
system_instructions = ""
|
||||
|
||||
# Tools instructions: conditional on thread_visibility (private vs shared memory wording)
|
||||
|
|
@ -748,7 +839,14 @@ def build_configurable_system_prompt(
|
|||
else SURFSENSE_NO_CITATION_INSTRUCTIONS
|
||||
)
|
||||
|
||||
return system_instructions + tools_instructions + citation_instructions
|
||||
sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
|
||||
|
||||
return (
|
||||
system_instructions
|
||||
+ tools_instructions
|
||||
+ citation_instructions
|
||||
+ sandbox_instructions
|
||||
)
|
||||
|
||||
|
||||
def get_default_system_instructions() -> str:
|
||||
|
|
|
|||
|
|
@ -58,7 +58,9 @@ def create_create_google_drive_file_tool(
|
|||
- "Create a Google Doc called 'Meeting Notes'"
|
||||
- "Create a spreadsheet named 'Budget 2026' with some sample data"
|
||||
"""
|
||||
logger.info(f"create_google_drive_file called: name='{name}', type='{file_type}'")
|
||||
logger.info(
|
||||
f"create_google_drive_file called: name='{name}', type='{file_type}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
|
|
@ -74,7 +76,9 @@ def create_create_google_drive_file_tool(
|
|||
|
||||
try:
|
||||
metadata_service = GoogleDriveToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(search_space_id, user_id)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
|
|
@ -100,8 +104,12 @@ def create_create_google_drive_file_tool(
|
|||
}
|
||||
)
|
||||
|
||||
decisions_raw = approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
decisions = decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
|
|
@ -183,7 +191,9 @@ def create_create_google_drive_file_tool(
|
|||
logger.info(
|
||||
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
|
||||
)
|
||||
client = GoogleDriveClient(session=db_session, connector_id=actual_connector_id)
|
||||
client = GoogleDriveClient(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
try:
|
||||
created = await client.create_file(
|
||||
name=final_name,
|
||||
|
|
@ -203,7 +213,9 @@ def create_create_google_drive_file_tool(
|
|||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Google Drive file created: id={created.get('id')}, name={created.get('name')}")
|
||||
logger.info(
|
||||
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"file_id": created.get("id"),
|
||||
|
|
|
|||
|
|
@ -52,7 +52,9 @@ def create_delete_google_drive_file_tool(
|
|||
- "Delete the 'Meeting Notes' file from Google Drive"
|
||||
- "Trash the 'Old Budget' spreadsheet"
|
||||
"""
|
||||
logger.info(f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}")
|
||||
logger.info(
|
||||
f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
|
|
@ -103,8 +105,12 @@ def create_delete_google_drive_file_tool(
|
|||
}
|
||||
)
|
||||
|
||||
decisions_raw = approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
decisions = decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
|
|
@ -130,11 +136,16 @@ def create_delete_google_drive_file_tool(
|
|||
final_params = decision["args"]
|
||||
|
||||
final_file_id = final_params.get("file_id", file_id)
|
||||
final_connector_id = final_params.get("connector_id", connector_id_from_context)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if not final_connector_id:
|
||||
return {"status": "error", "message": "No connector found for this file."}
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this file.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
@ -174,7 +185,9 @@ def create_delete_google_drive_file_tool(
|
|||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Google Drive file deleted (moved to trash): file_id={final_file_id}")
|
||||
logger.info(
|
||||
f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
|
||||
)
|
||||
|
||||
trash_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
|
|
@ -195,7 +208,9 @@ def create_delete_google_drive_file_tool(
|
|||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(f"Deleted document {document_id} from knowledge base")
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ This module provides:
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -17,8 +19,152 @@ from langchain_core.tools import StructuredTool
|
|||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import async_session_maker
|
||||
from app.db import shielded_async_session
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
# Connectors that call external live-search APIs (no local DB / embedding needed).
|
||||
# These are never filtered by available_document_types.
|
||||
_LIVE_SEARCH_CONNECTORS: set[str] = {
|
||||
"TAVILY_API",
|
||||
"SEARXNG_API",
|
||||
"LINKUP_API",
|
||||
"BAIDU_SEARCH_API",
|
||||
}
|
||||
|
||||
# Patterns that indicate the query has no meaningful search signal.
|
||||
# plainto_tsquery('english', '*') produces an empty tsquery and an embedding
|
||||
# of '*' is random noise, so both keyword and semantic search degrade to
|
||||
# arbitrary ordering — large documents (many chunks) dominate by chance.
|
||||
_DEGENERATE_QUERY_RE = re.compile(
|
||||
r"^[\s*?_.#@!\-/\\]+$" # only wildcards, punctuation, whitespace
|
||||
)
|
||||
|
||||
# Max chunks per document when doing a recency-based browse instead of
|
||||
# a real search. We want breadth (many docs) over depth (many chunks).
|
||||
_BROWSE_MAX_CHUNKS_PER_DOC = 5
|
||||
|
||||
|
||||
def _is_degenerate_query(query: str) -> bool:
|
||||
"""Return True when the query carries no meaningful search signal.
|
||||
|
||||
Catches wildcard patterns (``*``, ``**``), empty / whitespace-only
|
||||
strings, and single-character non-word tokens. These queries cause
|
||||
both keyword search (empty tsquery) and semantic search (meaningless
|
||||
embedding) to return effectively random results.
|
||||
"""
|
||||
stripped = query.strip()
|
||||
if not stripped:
|
||||
return True
|
||||
return bool(_DEGENERATE_QUERY_RE.match(stripped))
|
||||
|
||||
|
||||
async def _browse_recent_documents(
|
||||
search_space_id: int,
|
||||
document_type: str | None,
|
||||
top_k: int,
|
||||
start_date: datetime | None,
|
||||
end_date: datetime | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return the most-recent documents (recency-ordered, no search ranking).
|
||||
|
||||
Used as a fallback when the search query is degenerate (e.g. ``*``) and
|
||||
semantic / keyword search would produce arbitrary results. Returns
|
||||
document-grouped dicts in the same shape as ``_combined_rrf_search``
|
||||
so the rest of the pipeline works unchanged.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.db import Chunk, Document, DocumentType
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
base_conditions = [Document.search_space_id == search_space_id]
|
||||
|
||||
if document_type is not None:
|
||||
if isinstance(document_type, str):
|
||||
try:
|
||||
doc_type_enum = DocumentType[document_type]
|
||||
base_conditions.append(Document.document_type == doc_type_enum)
|
||||
except KeyError:
|
||||
return []
|
||||
else:
|
||||
base_conditions.append(Document.document_type == document_type)
|
||||
|
||||
if start_date is not None:
|
||||
base_conditions.append(Document.updated_at >= start_date)
|
||||
if end_date is not None:
|
||||
base_conditions.append(Document.updated_at <= end_date)
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
doc_query = (
|
||||
select(Document)
|
||||
.options(joinedload(Document.search_space))
|
||||
.where(*base_conditions)
|
||||
.order_by(Document.updated_at.desc())
|
||||
.limit(top_k)
|
||||
)
|
||||
result = await session.execute(doc_query)
|
||||
documents = result.scalars().unique().all()
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
doc_ids = [d.id for d in documents]
|
||||
|
||||
chunk_query = (
|
||||
select(Chunk)
|
||||
.where(Chunk.document_id.in_(doc_ids))
|
||||
.order_by(Chunk.document_id, Chunk.id)
|
||||
)
|
||||
chunk_result = await session.execute(chunk_query)
|
||||
raw_chunks = chunk_result.scalars().all()
|
||||
|
||||
doc_chunk_counts: dict[int, int] = {}
|
||||
doc_chunks: dict[int, list[dict]] = {d.id: [] for d in documents}
|
||||
for chunk in raw_chunks:
|
||||
did = chunk.document_id
|
||||
count = doc_chunk_counts.get(did, 0)
|
||||
if count < _BROWSE_MAX_CHUNKS_PER_DOC:
|
||||
doc_chunks[did].append({"chunk_id": chunk.id, "content": chunk.content})
|
||||
doc_chunk_counts[did] = count + 1
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
for doc in documents:
|
||||
chunks_list = doc_chunks.get(doc.id, [])
|
||||
results.append(
|
||||
{
|
||||
"document_id": doc.id,
|
||||
"content": "\n\n".join(
|
||||
c["content"] for c in chunks_list if c.get("content")
|
||||
),
|
||||
"score": 0.0,
|
||||
"chunks": chunks_list,
|
||||
"document": {
|
||||
"id": doc.id,
|
||||
"title": doc.title,
|
||||
"document_type": doc.document_type.value
|
||||
if getattr(doc, "document_type", None)
|
||||
else None,
|
||||
"metadata": doc.document_metadata or {},
|
||||
},
|
||||
"source": doc.document_type.value
|
||||
if getattr(doc, "document_type", None)
|
||||
else None,
|
||||
}
|
||||
)
|
||||
|
||||
perf.info(
|
||||
"[kb_browse] recency browse in %.3fs docs=%d space=%d type=%s",
|
||||
time.perf_counter() - t0,
|
||||
len(results),
|
||||
search_space_id,
|
||||
document_type,
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Connector Constants and Normalization
|
||||
|
|
@ -172,12 +318,72 @@ def _normalize_connectors(
|
|||
# =============================================================================
|
||||
|
||||
|
||||
def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
|
||||
# Fraction of the model's context window (in characters) that a single tool
|
||||
# result is allowed to occupy. The remainder is reserved for system prompt,
|
||||
# conversation history, and model output. With ~4 chars/token this gives a
|
||||
# tool result ≈ 25 % of the context budget in tokens.
|
||||
_TOOL_OUTPUT_CONTEXT_FRACTION = 0.25
|
||||
_CHARS_PER_TOKEN = 4
|
||||
|
||||
# Hard-floor / ceiling so the budget is always sensible regardless of what
|
||||
# the model reports.
|
||||
_MIN_TOOL_OUTPUT_CHARS = 20_000 # ~5K tokens
|
||||
_MAX_TOOL_OUTPUT_CHARS = 200_000 # ~50K tokens
|
||||
_MAX_CHUNK_CHARS = 8_000
|
||||
|
||||
# Rank-adaptive per-document budget allocation.
|
||||
# Top-ranked (most relevant) documents get a larger share of the budget so
|
||||
# we pack as much high-quality context as possible.
|
||||
#
|
||||
# fraction(rank) = _TOP_DOC_BUDGET_FRACTION / (1 + rank * _RANK_DECAY)
|
||||
#
|
||||
# Examples (128K budget, 8K chunk cap):
|
||||
# rank 0 → 40% → 6 chunks | rank 3 → 19% → 3 chunks
|
||||
# rank 1 → 30% → 4 chunks | rank 10 → 10% → 3 chunks (floor)
|
||||
# rank 2 → 24% → 3 chunks |
|
||||
_TOP_DOC_BUDGET_FRACTION = 0.40
|
||||
_RANK_DECAY = 0.35
|
||||
_MIN_CHUNKS_PER_DOC = 3
|
||||
|
||||
|
||||
def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
|
||||
"""Derive a character budget from the model's context window.
|
||||
|
||||
Uses ``litellm.get_model_info`` via the value already resolved by
|
||||
``ChatLiteLLMRouter`` / ``ChatLiteLLM`` and passed through the dependency
|
||||
chain as ``max_input_tokens``. Falls back to a conservative default when
|
||||
the value is unavailable.
|
||||
"""
|
||||
if max_input_tokens is None or max_input_tokens <= 0:
|
||||
return _MIN_TOOL_OUTPUT_CHARS # conservative fallback
|
||||
|
||||
budget = int(max_input_tokens * _CHARS_PER_TOKEN * _TOOL_OUTPUT_CONTEXT_FRACTION)
|
||||
return max(_MIN_TOOL_OUTPUT_CHARS, min(budget, _MAX_TOOL_OUTPUT_CHARS))
|
||||
|
||||
|
||||
def format_documents_for_context(
|
||||
documents: list[dict[str, Any]],
|
||||
*,
|
||||
max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
|
||||
max_chunk_chars: int = _MAX_CHUNK_CHARS,
|
||||
max_chunks_per_doc: int = 0,
|
||||
) -> str:
|
||||
"""
|
||||
Format retrieved documents into a readable context string for the LLM.
|
||||
|
||||
Documents are added in order (highest relevance first) until the character
|
||||
budget is reached. Individual chunks are capped at ``max_chunk_chars`` and
|
||||
each document is limited to a dynamically computed chunk cap so a single
|
||||
large document cannot monopolize the output while still maximising the use
|
||||
of available context space.
|
||||
|
||||
Args:
|
||||
documents: List of document dictionaries from connector search
|
||||
max_chars: Approximate character budget for the entire output.
|
||||
max_chunk_chars: Per-chunk character cap (content is tail-truncated).
|
||||
max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means
|
||||
auto-compute per document using a rank-adaptive formula so
|
||||
higher-ranked documents receive more chunks.
|
||||
|
||||
Returns:
|
||||
Formatted string with document contents and metadata
|
||||
|
|
@ -278,39 +484,85 @@ def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
|
|||
"BAIDU_SEARCH_API",
|
||||
}
|
||||
|
||||
# Render XML expected by citation instructions
|
||||
# Render XML expected by citation instructions, respecting the char budget.
|
||||
parts: list[str] = []
|
||||
for g in grouped.values():
|
||||
total_chars = 0
|
||||
total_docs = len(grouped)
|
||||
|
||||
for doc_idx, g in enumerate(grouped.values()):
|
||||
metadata_json = json.dumps(g["metadata"], ensure_ascii=False)
|
||||
is_live_search = g["document_type"] in live_search_connectors
|
||||
|
||||
parts.append("<document>")
|
||||
parts.append("<document_metadata>")
|
||||
parts.append(f" <document_id>{g['document_id']}</document_id>")
|
||||
parts.append(f" <document_type>{g['document_type']}</document_type>")
|
||||
parts.append(f" <title><![CDATA[{g['title']}]]></title>")
|
||||
parts.append(f" <url><![CDATA[{g['url']}]]></url>")
|
||||
parts.append(f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>")
|
||||
parts.append("</document_metadata>")
|
||||
parts.append("")
|
||||
parts.append("<document_content>")
|
||||
doc_lines: list[str] = [
|
||||
"<document>",
|
||||
"<document_metadata>",
|
||||
f" <document_id>{g['document_id']}</document_id>",
|
||||
f" <document_type>{g['document_type']}</document_type>",
|
||||
f" <title><![CDATA[{g['title']}]]></title>",
|
||||
f" <url><![CDATA[{g['url']}]]></url>",
|
||||
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
|
||||
"</document_metadata>",
|
||||
"",
|
||||
"<document_content>",
|
||||
]
|
||||
|
||||
for ch in g["chunks"]:
|
||||
# Rank-adaptive per-document chunk cap: top results get more chunks.
|
||||
if max_chunks_per_doc > 0:
|
||||
chunks_allowed = max_chunks_per_doc
|
||||
else:
|
||||
doc_fraction = _TOP_DOC_BUDGET_FRACTION / (1 + doc_idx * _RANK_DECAY)
|
||||
max_doc_chars = int(max_chars * doc_fraction)
|
||||
xml_overhead = 500
|
||||
chunks_allowed = max(
|
||||
(max_doc_chars - xml_overhead) // max(max_chunk_chars, 1),
|
||||
_MIN_CHUNKS_PER_DOC,
|
||||
)
|
||||
|
||||
chunks = g["chunks"]
|
||||
if len(chunks) > chunks_allowed:
|
||||
chunks = chunks[:chunks_allowed]
|
||||
|
||||
for ch in chunks:
|
||||
ch_content = ch["content"]
|
||||
# For live search connectors, use the document URL as the chunk id
|
||||
# so the LLM outputs [citation:https://...] which the frontend
|
||||
# renders as a clickable link.
|
||||
if max_chunk_chars and len(ch_content) > max_chunk_chars:
|
||||
ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
|
||||
ch_id = g["url"] if (is_live_search and g["url"]) else ch["chunk_id"]
|
||||
if ch_id is None:
|
||||
parts.append(f" <chunk><![CDATA[{ch_content}]]></chunk>")
|
||||
doc_lines.append(f" <chunk><![CDATA[{ch_content}]]></chunk>")
|
||||
else:
|
||||
parts.append(f" <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>")
|
||||
doc_lines.append(
|
||||
f" <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>"
|
||||
)
|
||||
|
||||
parts.append("</document_content>")
|
||||
parts.append("</document>")
|
||||
parts.append("")
|
||||
doc_lines.extend(["</document_content>", "</document>", ""])
|
||||
|
||||
return "\n".join(parts).strip()
|
||||
doc_xml = "\n".join(doc_lines)
|
||||
doc_len = len(doc_xml)
|
||||
|
||||
if total_chars + doc_len > max_chars:
|
||||
remaining = total_docs - doc_idx
|
||||
if doc_idx == 0:
|
||||
parts.append(doc_xml)
|
||||
total_chars += doc_len
|
||||
parts.append(
|
||||
f"<!-- Output truncated: {remaining} more document(s) omitted "
|
||||
f"(budget {max_chars} chars). Refine your query or reduce top_k "
|
||||
f"to retrieve different results. -->"
|
||||
)
|
||||
break
|
||||
|
||||
parts.append(doc_xml)
|
||||
total_chars += doc_len
|
||||
|
||||
result = "\n".join(parts).strip()
|
||||
|
||||
# Hard safety net: if the result is still over budget (e.g. a single massive
|
||||
# first document), forcibly truncate with a closing comment.
|
||||
if len(result) > max_chars:
|
||||
truncation_msg = "\n<!-- ...output forcibly truncated to fit context window -->"
|
||||
result = result[: max_chars - len(truncation_msg)] + truncation_msg
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -328,6 +580,8 @@ async def search_knowledge_base_async(
|
|||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
available_connectors: list[str] | None = None,
|
||||
available_document_types: list[str] | None = None,
|
||||
max_input_tokens: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Search the user's knowledge base for relevant documents.
|
||||
|
|
@ -345,10 +599,18 @@ async def search_knowledge_base_async(
|
|||
end_date: Optional end datetime (UTC) for filtering documents
|
||||
available_connectors: Optional list of connectors actually available in the search space.
|
||||
If provided, only these connectors will be searched.
|
||||
available_document_types: Optional list of document types that actually have indexed
|
||||
data. When provided, local connectors whose document type is
|
||||
absent are skipped entirely (no embedding / DB round-trip).
|
||||
max_input_tokens: Model context window size (tokens). Used to dynamically
|
||||
size the output so it fits within the model's limits.
|
||||
|
||||
Returns:
|
||||
Formatted string with search results
|
||||
"""
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
all_documents: list[dict[str, Any]] = []
|
||||
|
||||
# Resolve date range (default last 2 years)
|
||||
|
|
@ -361,88 +623,169 @@ async def search_knowledge_base_async(
|
|||
|
||||
connectors = _normalize_connectors(connectors_to_search, available_connectors)
|
||||
|
||||
connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
|
||||
"YOUTUBE_VIDEO": ("search_youtube", True, True, {}),
|
||||
"EXTENSION": ("search_extension", True, True, {}),
|
||||
"CRAWLED_URL": ("search_crawled_urls", True, True, {}),
|
||||
"FILE": ("search_files", True, True, {}),
|
||||
"SLACK_CONNECTOR": ("search_slack", True, True, {}),
|
||||
"TEAMS_CONNECTOR": ("search_teams", True, True, {}),
|
||||
"NOTION_CONNECTOR": ("search_notion", True, True, {}),
|
||||
"GITHUB_CONNECTOR": ("search_github", True, True, {}),
|
||||
"LINEAR_CONNECTOR": ("search_linear", True, True, {}),
|
||||
# --- Optimization 1: skip local connectors that have zero indexed documents ---
|
||||
if available_document_types:
|
||||
doc_types_set = set(available_document_types)
|
||||
before_count = len(connectors)
|
||||
connectors = [
|
||||
c for c in connectors if c in _LIVE_SEARCH_CONNECTORS or c in doc_types_set
|
||||
]
|
||||
skipped = before_count - len(connectors)
|
||||
if skipped:
|
||||
perf.info(
|
||||
"[kb_search] skipped %d empty connectors (had %d, now %d)",
|
||||
skipped,
|
||||
before_count,
|
||||
len(connectors),
|
||||
)
|
||||
|
||||
perf.info(
|
||||
"[kb_search] searching %d connectors: %s (space=%d, top_k=%d)",
|
||||
len(connectors),
|
||||
connectors[:5],
|
||||
search_space_id,
|
||||
top_k,
|
||||
)
|
||||
|
||||
# --- Fast-path: degenerate queries (*, **, empty, etc.) ---
|
||||
# Semantic embedding of '*' is noise and plainto_tsquery('english', '*')
|
||||
# yields an empty tsquery, so both retrieval signals are useless.
|
||||
# Fall back to a recency-ordered browse that returns diverse results.
|
||||
if _is_degenerate_query(query):
|
||||
perf.info(
|
||||
"[kb_search] degenerate query %r detected - falling back to recency browse",
|
||||
query,
|
||||
)
|
||||
local_connectors = [c for c in connectors if c not in _LIVE_SEARCH_CONNECTORS]
|
||||
if not local_connectors:
|
||||
local_connectors = [None] # type: ignore[list-item]
|
||||
|
||||
browse_results = await asyncio.gather(
|
||||
*[
|
||||
_browse_recent_documents(
|
||||
search_space_id=search_space_id,
|
||||
document_type=c,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
for c in local_connectors
|
||||
]
|
||||
)
|
||||
for docs in browse_results:
|
||||
all_documents.extend(docs)
|
||||
|
||||
# Skip dedup + formatting below (browse already returns unique docs)
|
||||
# but still cap output budget.
|
||||
output_budget = _compute_tool_output_budget(max_input_tokens)
|
||||
result = format_documents_for_context(
|
||||
all_documents,
|
||||
max_chars=output_budget,
|
||||
max_chunks_per_doc=_BROWSE_MAX_CHUNKS_PER_DOC,
|
||||
)
|
||||
perf.info(
|
||||
"[kb_search] TOTAL (browse) in %.3fs total_docs=%d output_chars=%d "
|
||||
"budget=%d space=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(all_documents),
|
||||
len(result),
|
||||
output_budget,
|
||||
search_space_id,
|
||||
)
|
||||
return result
|
||||
|
||||
# Specs for live-search connectors (external APIs, no local DB/embedding).
|
||||
live_connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
|
||||
"TAVILY_API": ("search_tavily", False, True, {}),
|
||||
"SEARXNG_API": ("search_searxng", False, True, {}),
|
||||
"LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
|
||||
"BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
|
||||
"DISCORD_CONNECTOR": ("search_discord", True, True, {}),
|
||||
"JIRA_CONNECTOR": ("search_jira", True, True, {}),
|
||||
"GOOGLE_CALENDAR_CONNECTOR": ("search_google_calendar", True, True, {}),
|
||||
"AIRTABLE_CONNECTOR": ("search_airtable", True, True, {}),
|
||||
"GOOGLE_GMAIL_CONNECTOR": ("search_google_gmail", True, True, {}),
|
||||
"GOOGLE_DRIVE_FILE": ("search_google_drive", True, True, {}),
|
||||
"CONFLUENCE_CONNECTOR": ("search_confluence", True, True, {}),
|
||||
"CLICKUP_CONNECTOR": ("search_clickup", True, True, {}),
|
||||
"LUMA_CONNECTOR": ("search_luma", True, True, {}),
|
||||
"ELASTICSEARCH_CONNECTOR": ("search_elasticsearch", True, True, {}),
|
||||
"NOTE": ("search_notes", True, True, {}),
|
||||
"BOOKSTACK_CONNECTOR": ("search_bookstack", True, True, {}),
|
||||
"CIRCLEBACK": ("search_circleback", True, True, {}),
|
||||
"OBSIDIAN_CONNECTOR": ("search_obsidian", True, True, {}),
|
||||
# Composio connectors
|
||||
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": (
|
||||
"search_composio_google_drive",
|
||||
True,
|
||||
True,
|
||||
{},
|
||||
),
|
||||
"COMPOSIO_GMAIL_CONNECTOR": ("search_composio_gmail", True, True, {}),
|
||||
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": (
|
||||
"search_composio_google_calendar",
|
||||
True,
|
||||
True,
|
||||
{},
|
||||
),
|
||||
}
|
||||
|
||||
# Keep a conservative cap to avoid overloading DB/external services.
|
||||
# --- Optimization 2: compute the query embedding once, share across all local searches ---
|
||||
precomputed_embedding: list[float] | None = None
|
||||
has_local_connectors = any(c not in _LIVE_SEARCH_CONNECTORS for c in connectors)
|
||||
if has_local_connectors:
|
||||
from app.config import config as app_config
|
||||
|
||||
t_embed = time.perf_counter()
|
||||
precomputed_embedding = app_config.embedding_model_instance.embed(query)
|
||||
perf.info(
|
||||
"[kb_search] shared embedding computed in %.3fs",
|
||||
time.perf_counter() - t_embed,
|
||||
)
|
||||
|
||||
max_parallel_searches = 4
|
||||
semaphore = asyncio.Semaphore(max_parallel_searches)
|
||||
|
||||
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
|
||||
spec = connector_specs.get(connector)
|
||||
if spec is None:
|
||||
return []
|
||||
is_live = connector in _LIVE_SEARCH_CONNECTORS
|
||||
|
||||
method_name, includes_date_range, includes_top_k, extra_kwargs = spec
|
||||
kwargs: dict[str, Any] = {
|
||||
"user_query": query,
|
||||
"search_space_id": search_space_id,
|
||||
**extra_kwargs,
|
||||
}
|
||||
if includes_top_k:
|
||||
kwargs["top_k"] = top_k
|
||||
if includes_date_range:
|
||||
kwargs["start_date"] = resolved_start_date
|
||||
kwargs["end_date"] = resolved_end_date
|
||||
if is_live:
|
||||
spec = live_connector_specs.get(connector)
|
||||
if spec is None:
|
||||
return []
|
||||
method_name, includes_date_range, includes_top_k, extra_kwargs = spec
|
||||
kwargs: dict[str, Any] = {
|
||||
"user_query": query,
|
||||
"search_space_id": search_space_id,
|
||||
**extra_kwargs,
|
||||
}
|
||||
if includes_top_k:
|
||||
kwargs["top_k"] = top_k
|
||||
if includes_date_range:
|
||||
kwargs["start_date"] = resolved_start_date
|
||||
kwargs["end_date"] = resolved_end_date
|
||||
|
||||
try:
|
||||
t_conn = time.perf_counter()
|
||||
async with semaphore, shielded_async_session() as isolated_session:
|
||||
svc = ConnectorService(isolated_session, search_space_id)
|
||||
_, chunks = await getattr(svc, method_name)(**kwargs)
|
||||
perf.info(
|
||||
"[kb_search] connector=%s results=%d in %.3fs",
|
||||
connector,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_conn,
|
||||
)
|
||||
return chunks
|
||||
except Exception as e:
|
||||
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
|
||||
return []
|
||||
|
||||
# --- Optimization 3: call _combined_rrf_search directly with shared embedding ---
|
||||
try:
|
||||
# Use isolated session per connector. Shared AsyncSession cannot safely
|
||||
# run concurrent DB operations.
|
||||
async with semaphore, async_session_maker() as isolated_session:
|
||||
isolated_connector_service = ConnectorService(
|
||||
isolated_session, search_space_id
|
||||
t_conn = time.perf_counter()
|
||||
async with semaphore, shielded_async_session() as isolated_session:
|
||||
svc = ConnectorService(isolated_session, search_space_id)
|
||||
chunks = await svc._combined_rrf_search(
|
||||
query_text=query,
|
||||
search_space_id=search_space_id,
|
||||
document_type=connector,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
query_embedding=precomputed_embedding,
|
||||
)
|
||||
perf.info(
|
||||
"[kb_search] connector=%s results=%d in %.3fs",
|
||||
connector,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_conn,
|
||||
)
|
||||
connector_method = getattr(isolated_connector_service, method_name)
|
||||
_, chunks = await connector_method(**kwargs)
|
||||
return chunks
|
||||
except Exception as e:
|
||||
print(f"Error searching connector {connector}: {e}")
|
||||
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
|
||||
return []
|
||||
|
||||
t_gather = time.perf_counter()
|
||||
connector_results = await asyncio.gather(
|
||||
*[_search_one_connector(connector) for connector in connectors]
|
||||
)
|
||||
perf.info(
|
||||
"[kb_search] all connectors gathered in %.3fs",
|
||||
time.perf_counter() - t_gather,
|
||||
)
|
||||
for chunks in connector_results:
|
||||
all_documents.extend(chunks)
|
||||
|
||||
|
|
@ -488,7 +831,29 @@ async def search_knowledge_base_async(
|
|||
|
||||
deduplicated.append(doc)
|
||||
|
||||
return format_documents_for_context(deduplicated)
|
||||
output_budget = _compute_tool_output_budget(max_input_tokens)
|
||||
result = format_documents_for_context(deduplicated, max_chars=output_budget)
|
||||
|
||||
if len(result) > output_budget:
|
||||
perf.warning(
|
||||
"[kb_search] output STILL exceeds budget after format (%d > %d), "
|
||||
"hard truncation should have fired",
|
||||
len(result),
|
||||
output_budget,
|
||||
)
|
||||
|
||||
perf.info(
|
||||
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
|
||||
"budget=%d max_input_tokens=%s space=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(all_documents),
|
||||
len(deduplicated),
|
||||
len(result),
|
||||
output_budget,
|
||||
max_input_tokens,
|
||||
search_space_id,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _build_connector_docstring(available_connectors: list[str] | None) -> str:
|
||||
|
|
@ -526,11 +891,15 @@ class SearchKnowledgeBaseInput(BaseModel):
|
|||
"""Input schema for the search_knowledge_base tool."""
|
||||
|
||||
query: str = Field(
|
||||
description="The search query - be specific and include key terms"
|
||||
description=(
|
||||
"The search query - use specific natural language terms. "
|
||||
"NEVER use wildcards like '*' or '**'; instead describe what you want "
|
||||
"(e.g. 'recent meeting notes' or 'project architecture overview')."
|
||||
),
|
||||
)
|
||||
top_k: int = Field(
|
||||
default=10,
|
||||
description="Number of results to retrieve (default: 10)",
|
||||
description="Number of results to retrieve (default: 10). Keep ≤20 for focused searches.",
|
||||
)
|
||||
start_date: str | None = Field(
|
||||
default=None,
|
||||
|
|
@ -552,6 +921,7 @@ def create_search_knowledge_base_tool(
|
|||
connector_service: ConnectorService,
|
||||
available_connectors: list[str] | None = None,
|
||||
available_document_types: list[str] | None = None,
|
||||
max_input_tokens: int | None = None,
|
||||
) -> StructuredTool:
|
||||
"""
|
||||
Factory function to create the search_knowledge_base tool with injected dependencies.
|
||||
|
|
@ -564,6 +934,8 @@ def create_search_knowledge_base_tool(
|
|||
Used to dynamically generate the tool docstring.
|
||||
available_document_types: Optional list of document types that have data in the search space.
|
||||
Used to inform the LLM about what data exists.
|
||||
max_input_tokens: Model context window (tokens) from litellm model info.
|
||||
Used to dynamically size tool output.
|
||||
|
||||
Returns:
|
||||
A configured StructuredTool instance
|
||||
|
|
@ -590,6 +962,10 @@ Focus searches on these types for best results."""
|
|||
Use this tool to find documents, notes, files, web pages, and other content that may help answer the user's question.
|
||||
|
||||
IMPORTANT:
|
||||
- Always craft specific, descriptive search queries using natural language keywords.
|
||||
Good: "quarterly sales report Q3", "Python API authentication design".
|
||||
Bad: "*", "**", "everything", single characters. Wildcard/empty queries yield poor results.
|
||||
- Prefer multiple focused searches over a single broad one with high top_k.
|
||||
- If the user requests a specific source type (e.g. "my notes", "Slack messages"), pass `connectors_to_search=[...]` using the enums below.
|
||||
- If `connectors_to_search` is omitted/empty, the system will search broadly.
|
||||
- Only connectors that are enabled/configured for this search space are available.{doc_types_info}
|
||||
|
|
@ -605,6 +981,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
|
|||
|
||||
# Capture for closure
|
||||
_available_connectors = available_connectors
|
||||
_available_document_types = available_document_types
|
||||
|
||||
async def _search_knowledge_base_impl(
|
||||
query: str,
|
||||
|
|
@ -634,6 +1011,8 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
|
|||
start_date=parsed_start,
|
||||
end_date=parsed_end,
|
||||
available_connectors=_available_connectors,
|
||||
available_document_types=_available_document_types,
|
||||
max_input_tokens=max_input_tokens,
|
||||
)
|
||||
|
||||
# Create StructuredTool with dynamic description
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ This implements real MCP protocol support similar to Cursor's implementation.
|
|||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
|
@ -25,6 +26,24 @@ from app.db import SearchSourceConnector, SearchSourceConnectorType
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
|
||||
_MCP_CACHE_MAX_SIZE = 50
|
||||
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
|
||||
|
||||
|
||||
def _evict_expired_mcp_cache() -> None:
|
||||
"""Remove expired entries from the MCP tools cache to prevent unbounded growth."""
|
||||
now = time.monotonic()
|
||||
expired = [
|
||||
k
|
||||
for k, (ts, _) in _mcp_tools_cache.items()
|
||||
if now - ts >= _MCP_CACHE_TTL_SECONDS
|
||||
]
|
||||
for k in expired:
|
||||
del _mcp_tools_cache[k]
|
||||
if expired:
|
||||
logger.debug("Evicted %d expired MCP cache entries", len(expired))
|
||||
|
||||
|
||||
def _create_dynamic_input_model_from_schema(
|
||||
tool_name: str,
|
||||
|
|
@ -355,6 +374,19 @@ async def _load_http_mcp_tools(
|
|||
return tools
|
||||
|
||||
|
||||
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
||||
"""Invalidate cached MCP tools.
|
||||
|
||||
Args:
|
||||
search_space_id: If provided, only invalidate for this search space.
|
||||
If None, invalidate all cached MCP tools.
|
||||
"""
|
||||
if search_space_id is not None:
|
||||
_mcp_tools_cache.pop(search_space_id, None)
|
||||
else:
|
||||
_mcp_tools_cache.clear()
|
||||
|
||||
|
||||
async def load_mcp_tools(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
|
|
@ -364,6 +396,9 @@ async def load_mcp_tools(
|
|||
This discovers tools dynamically from MCP servers using the protocol.
|
||||
Supports both stdio (local process) and HTTP (remote server) transports.
|
||||
|
||||
Results are cached per search space for up to 5 minutes to avoid
|
||||
re-spawning MCP server processes on every chat message.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
search_space_id: User's search space ID
|
||||
|
|
@ -372,8 +407,22 @@ async def load_mcp_tools(
|
|||
List of LangChain StructuredTool instances
|
||||
|
||||
"""
|
||||
_evict_expired_mcp_cache()
|
||||
|
||||
now = time.monotonic()
|
||||
cached = _mcp_tools_cache.get(search_space_id)
|
||||
if cached is not None:
|
||||
cached_at, cached_tools = cached
|
||||
if now - cached_at < _MCP_CACHE_TTL_SECONDS:
|
||||
logger.info(
|
||||
"Using cached MCP tools for search space %s (%d tools, age=%.0fs)",
|
||||
search_space_id,
|
||||
len(cached_tools),
|
||||
now - cached_at,
|
||||
)
|
||||
return list(cached_tools)
|
||||
|
||||
try:
|
||||
# Fetch all MCP connectors for this search space
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.connector_type
|
||||
|
|
@ -385,27 +434,22 @@ async def load_mcp_tools(
|
|||
tools: list[StructuredTool] = []
|
||||
for connector in result.scalars():
|
||||
try:
|
||||
# Early validation: Extract and validate connector config
|
||||
config = connector.config or {}
|
||||
server_config = config.get("server_config", {})
|
||||
|
||||
# Validate server_config exists and is a dict
|
||||
if not server_config or not isinstance(server_config, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
# Determine transport type
|
||||
transport = server_config.get("transport", "stdio")
|
||||
|
||||
if transport in ("streamable-http", "http", "sse"):
|
||||
# HTTP-based MCP server
|
||||
connector_tools = await _load_http_mcp_tools(
|
||||
connector.id, connector.name, server_config
|
||||
)
|
||||
else:
|
||||
# stdio-based MCP server (default)
|
||||
connector_tools = await _load_stdio_mcp_tools(
|
||||
connector.id, connector.name, server_config
|
||||
)
|
||||
|
|
@ -417,6 +461,12 @@ async def load_mcp_tools(
|
|||
f"Failed to load tools from MCP connector {connector.id}: {e!s}"
|
||||
)
|
||||
|
||||
_mcp_tools_cache[search_space_id] = (now, tools)
|
||||
|
||||
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
|
||||
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
|
||||
del _mcp_tools_cache[oldest_key]
|
||||
|
||||
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
|
||||
return tools
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,10 @@ from app.db import ChatVisibility
|
|||
|
||||
from .display_image import create_display_image_tool
|
||||
from .generate_image import create_generate_image_tool
|
||||
from .google_drive import (
|
||||
create_create_google_drive_file_tool,
|
||||
create_delete_google_drive_file_tool,
|
||||
)
|
||||
from .knowledge_base import create_search_knowledge_base_tool
|
||||
from .linear import (
|
||||
create_create_linear_issue_tool,
|
||||
|
|
@ -55,10 +59,6 @@ from .linear import (
|
|||
)
|
||||
from .link_preview import create_link_preview_tool
|
||||
from .mcp_tool import load_mcp_tools
|
||||
from .google_drive import (
|
||||
create_create_google_drive_file_tool,
|
||||
create_delete_google_drive_file_tool,
|
||||
)
|
||||
from .notion import (
|
||||
create_create_notion_page_tool,
|
||||
create_delete_notion_page_tool,
|
||||
|
|
@ -118,6 +118,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
# Optional: dynamically discovered connectors/document types
|
||||
available_connectors=deps.get("available_connectors"),
|
||||
available_document_types=deps.get("available_document_types"),
|
||||
max_input_tokens=deps.get("max_input_tokens"),
|
||||
),
|
||||
requires=["search_space_id", "db_session", "connector_service"],
|
||||
# Note: available_connectors and available_document_types are optional
|
||||
|
|
@ -144,10 +145,12 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
thread_id=deps["thread_id"],
|
||||
connector_service=deps.get("connector_service"),
|
||||
available_connectors=deps.get("available_connectors"),
|
||||
available_document_types=deps.get("available_document_types"),
|
||||
),
|
||||
requires=["search_space_id", "thread_id"],
|
||||
# connector_service and available_connectors are optional —
|
||||
# when missing, source_strategy="kb_search" degrades gracefully to "provided"
|
||||
# connector_service, available_connectors, and available_document_types
|
||||
# are optional — when missing, source_strategy="kb_search" degrades
|
||||
# gracefully to "provided"
|
||||
),
|
||||
# Link preview tool - fetches Open Graph metadata for URLs
|
||||
ToolDefinition(
|
||||
|
|
@ -444,8 +447,18 @@ async def build_tools_async(
|
|||
List of configured tool instances ready for the agent, including MCP tools.
|
||||
|
||||
"""
|
||||
# Build standard tools
|
||||
import time
|
||||
|
||||
_perf_log = logging.getLogger("surfsense.perf")
|
||||
_perf_log.setLevel(logging.DEBUG)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)
|
||||
_perf_log.info(
|
||||
"[build_tools_async] Built-in tools in %.3fs (%d tools)",
|
||||
time.perf_counter() - _t0,
|
||||
len(tools),
|
||||
)
|
||||
|
||||
# Load MCP tools if requested and dependencies are available
|
||||
if (
|
||||
|
|
@ -454,10 +467,16 @@ async def build_tools_async(
|
|||
and "search_space_id" in dependencies
|
||||
):
|
||||
try:
|
||||
_t0 = time.perf_counter()
|
||||
mcp_tools = await load_mcp_tools(
|
||||
dependencies["db_session"],
|
||||
dependencies["search_space_id"],
|
||||
)
|
||||
_perf_log.info(
|
||||
"[build_tools_async] MCP tools loaded in %.3fs (%d tools)",
|
||||
time.perf_counter() - _t0,
|
||||
len(mcp_tools),
|
||||
)
|
||||
tools.extend(mcp_tools)
|
||||
logging.info(
|
||||
f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from langchain_core.callbacks import dispatch_custom_event
|
|||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.db import Report, async_session_maker
|
||||
from app.db import Report, shielded_async_session
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.llm_service import get_document_summary_llm
|
||||
|
||||
|
|
@ -559,6 +559,7 @@ def create_generate_report_tool(
|
|||
thread_id: int | None = None,
|
||||
connector_service: ConnectorService | None = None,
|
||||
available_connectors: list[str] | None = None,
|
||||
available_document_types: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the generate_report tool with injected dependencies.
|
||||
|
|
@ -716,7 +717,7 @@ def create_generate_report_tool(
|
|||
async def _save_failed_report(error_msg: str) -> int | None:
|
||||
"""Persist a failed report row using a short-lived session."""
|
||||
try:
|
||||
async with async_session_maker() as session:
|
||||
async with shielded_async_session() as session:
|
||||
failed_report = Report(
|
||||
title=topic,
|
||||
content=None,
|
||||
|
|
@ -750,7 +751,7 @@ def create_generate_report_tool(
|
|||
# ── Phase 1: READ (short-lived session) ──────────────────────
|
||||
# Fetch parent report and LLM config, then close the session
|
||||
# so no DB connection is held during the long LLM call.
|
||||
async with async_session_maker() as read_session:
|
||||
async with shielded_async_session() as read_session:
|
||||
if parent_report_id:
|
||||
parent_report = await read_session.get(Report, parent_report_id)
|
||||
if parent_report:
|
||||
|
|
@ -827,7 +828,7 @@ def create_generate_report_tool(
|
|||
|
||||
# Run all queries in parallel, each with its own session
|
||||
async def _run_single_query(q: str) -> str:
|
||||
async with async_session_maker() as kb_session:
|
||||
async with shielded_async_session() as kb_session:
|
||||
kb_connector_svc = ConnectorService(
|
||||
kb_session, search_space_id
|
||||
)
|
||||
|
|
@ -838,6 +839,7 @@ def create_generate_report_tool(
|
|||
connector_service=kb_connector_svc,
|
||||
top_k=10,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
)
|
||||
|
||||
kb_results = await asyncio.gather(
|
||||
|
|
@ -1026,7 +1028,7 @@ def create_generate_report_tool(
|
|||
|
||||
# ── Phase 3: WRITE (short-lived session) ─────────────────────
|
||||
# Save the report to the database, then close the session.
|
||||
async with async_session_maker() as write_session:
|
||||
async with shielded_async_session() as write_session:
|
||||
report = Report(
|
||||
title=topic,
|
||||
content=report_content,
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument
|
||||
from app.utils.document_converters import embed_text
|
||||
|
||||
|
||||
def format_surfsense_docs_results(results: list[tuple]) -> str:
|
||||
|
|
@ -100,7 +100,7 @@ async def search_surfsense_docs_async(
|
|||
Formatted string with relevant documentation content
|
||||
"""
|
||||
# Get embedding for the query
|
||||
query_embedding = config.embedding_model_instance.embed(query)
|
||||
query_embedding = embed_text(query)
|
||||
|
||||
# Vector similarity search on chunks, joining with documents
|
||||
stmt = (
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import MemoryCategory, SharedMemory, User
|
||||
from app.utils.document_converters import embed_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -64,7 +64,7 @@ async def save_shared_memory(
|
|||
count = await get_shared_memory_count(db_session, search_space_id)
|
||||
if count >= MAX_MEMORIES_PER_SEARCH_SPACE:
|
||||
await delete_oldest_shared_memory(db_session, search_space_id)
|
||||
embedding = config.embedding_model_instance.embed(content)
|
||||
embedding = embed_text(content)
|
||||
row = SharedMemory(
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=_to_uuid(created_by_id),
|
||||
|
|
@ -108,7 +108,7 @@ async def recall_shared_memory(
|
|||
if category and category in valid_categories:
|
||||
stmt = stmt.where(SharedMemory.category == MemoryCategory(category))
|
||||
if query:
|
||||
query_embedding = config.embedding_model_instance.embed(query)
|
||||
query_embedding = embed_text(query)
|
||||
stmt = stmt.order_by(
|
||||
SharedMemory.embedding.op("<=>")(query_embedding)
|
||||
).limit(top_k)
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import MemoryCategory, UserMemory
|
||||
from app.utils.document_converters import embed_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -178,7 +178,7 @@ def create_save_memory_tool(
|
|||
await delete_oldest_memory(db_session, user_id, search_space_id)
|
||||
|
||||
# Generate embedding for the memory
|
||||
embedding = config.embedding_model_instance.embed(content)
|
||||
embedding = embed_text(content)
|
||||
|
||||
# Create new memory using ORM
|
||||
# The pgvector Vector column type handles embedding conversion automatically
|
||||
|
|
@ -268,7 +268,7 @@ def create_recall_memory_tool(
|
|||
|
||||
if query:
|
||||
# Semantic search using embeddings
|
||||
query_embedding = config.embedding_model_instance.embed(query)
|
||||
query_embedding = embed_text(query)
|
||||
|
||||
# Build query with vector similarity
|
||||
stmt = (
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from litellm import aspeech
|
|||
|
||||
from app.config import config as app_config
|
||||
from app.services.kokoro_tts_service import get_kokoro_tts_service
|
||||
from app.services.llm_service import get_document_summary_llm
|
||||
from app.services.llm_service import get_agent_llm
|
||||
|
||||
from .configuration import Configuration
|
||||
from .prompts import get_podcast_generation_prompt
|
||||
|
|
@ -31,7 +31,7 @@ async def create_podcast_transcript(
|
|||
user_prompt = configuration.user_prompt
|
||||
|
||||
# Get search space's document summary LLM
|
||||
llm = await get_document_summary_llm(state.db_session, search_space_id)
|
||||
llm = await get_agent_llm(state.db_session, search_space_id)
|
||||
if not llm:
|
||||
error_message = (
|
||||
f"No document summary LLM configured for search space {search_space_id}"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
|
@ -15,6 +16,9 @@ from slowapi.errors import RateLimitExceeded
|
|||
from slowapi.middleware import SlowAPIMiddleware
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
||||
|
||||
from app.agents.new_chat.checkpointer import (
|
||||
|
|
@ -28,6 +32,7 @@ from app.routes.auth_routes import router as auth_router
|
|||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
|
||||
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot
|
||||
|
||||
rate_limit_logger = logging.getLogger("surfsense.rate_limit")
|
||||
|
||||
|
|
@ -99,22 +104,24 @@ def _check_rate_limit_memory(
|
|||
now = time.monotonic()
|
||||
|
||||
with _memory_lock:
|
||||
# Evict timestamps outside the current window
|
||||
_memory_rate_limits[key] = [
|
||||
t for t in _memory_rate_limits[key] if now - t < window_seconds
|
||||
]
|
||||
timestamps = [t for t in _memory_rate_limits[key] if now - t < window_seconds]
|
||||
|
||||
if len(_memory_rate_limits[key]) >= max_requests:
|
||||
if not timestamps:
|
||||
_memory_rate_limits.pop(key, None)
|
||||
else:
|
||||
_memory_rate_limits[key] = timestamps
|
||||
|
||||
if len(timestamps) >= max_requests:
|
||||
rate_limit_logger.warning(
|
||||
f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} "
|
||||
f"({len(_memory_rate_limits[key])}/{max_requests} in {window_seconds}s)"
|
||||
f"({len(timestamps)}/{max_requests} in {window_seconds}s)"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="RATE_LIMIT_EXCEEDED",
|
||||
)
|
||||
|
||||
_memory_rate_limits[key].append(now)
|
||||
_memory_rate_limits[key] = [*timestamps, now]
|
||||
|
||||
|
||||
def _check_rate_limit(
|
||||
|
|
@ -175,18 +182,47 @@ def rate_limit_password_reset(request: Request):
|
|||
)
|
||||
|
||||
|
||||
def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None:
|
||||
"""Monkey-patch the event loop to warn whenever a callback blocks longer than *threshold_sec*.
|
||||
|
||||
This helps pinpoint synchronous code that freezes the entire FastAPI server.
|
||||
Only active when the PERF_DEBUG env var is set (to avoid overhead in production).
|
||||
"""
|
||||
import os
|
||||
|
||||
if not os.environ.get("PERF_DEBUG"):
|
||||
return
|
||||
|
||||
_slow_log = logging.getLogger("surfsense.perf.slow")
|
||||
_slow_log.setLevel(logging.WARNING)
|
||||
if not _slow_log.handlers:
|
||||
_h = logging.StreamHandler()
|
||||
_h.setFormatter(logging.Formatter("%(asctime)s [SLOW-CALLBACK] %(message)s"))
|
||||
_slow_log.addHandler(_h)
|
||||
_slow_log.propagate = False
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.slow_callback_duration = threshold_sec # type: ignore[attr-defined]
|
||||
loop.set_debug(True)
|
||||
_slow_log.warning(
|
||||
"Event-loop slow-callback detector ENABLED (threshold=%.1fs). "
|
||||
"Set PERF_DEBUG='' to disable.",
|
||||
threshold_sec,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Not needed if you setup a migration system like Alembic
|
||||
# Tune GC: lower gen-2 threshold so long-lived garbage is collected
|
||||
# sooner (default 700/10/10 → 700/10/5). This reduces peak RSS
|
||||
# with minimal CPU overhead.
|
||||
gc.set_threshold(700, 10, 5)
|
||||
|
||||
_enable_slow_callback_logging(threshold_sec=0.5)
|
||||
await create_db_and_tables()
|
||||
# Setup LangGraph checkpointer tables for conversation persistence
|
||||
await setup_checkpointer_tables()
|
||||
# Initialize LLM Router for Auto mode load balancing
|
||||
initialize_llm_router()
|
||||
# Initialize Image Generation Router for Auto mode load balancing
|
||||
initialize_image_gen_router()
|
||||
# Seed Surfsense documentation (with timeout so a slow embedding API
|
||||
# doesn't block startup indefinitely and make the container unresponsive)
|
||||
try:
|
||||
await asyncio.wait_for(seed_surfsense_docs(), timeout=120)
|
||||
except TimeoutError:
|
||||
|
|
@ -194,8 +230,11 @@ async def lifespan(app: FastAPI):
|
|||
"Surfsense docs seeding timed out after 120s — skipping. "
|
||||
"Docs will be indexed on the next restart."
|
||||
)
|
||||
|
||||
log_system_snapshot("startup_complete")
|
||||
|
||||
yield
|
||||
# Cleanup: close checkpointer connection on shutdown
|
||||
|
||||
await close_checkpointer()
|
||||
|
||||
|
||||
|
|
@ -213,6 +252,63 @@ app = FastAPI(lifespan=lifespan)
|
|||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request-level performance middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logs wall-clock time, method, path, and status for every request so we can
|
||||
# spot slow endpoints in production logs.
|
||||
|
||||
_PERF_SLOW_REQUEST_THRESHOLD = float(
|
||||
__import__("os").environ.get("PERF_SLOW_REQUEST_MS", "2000")
|
||||
)
|
||||
|
||||
|
||||
class RequestPerfMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware that logs per-request wall-clock time.
|
||||
|
||||
- ALL requests are logged at DEBUG level.
|
||||
- Requests exceeding PERF_SLOW_REQUEST_MS (default 2000ms) are logged at
|
||||
WARNING level with a system snapshot so we can correlate slow responses
|
||||
with CPU/memory usage at that moment.
|
||||
"""
|
||||
|
||||
async def dispatch(
|
||||
self, request: StarletteRequest, call_next: RequestResponseEndpoint
|
||||
) -> StarletteResponse:
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
response = await call_next(request)
|
||||
elapsed_ms = (time.perf_counter() - t0) * 1000
|
||||
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
status = response.status_code
|
||||
|
||||
perf.debug(
|
||||
"[request] %s %s -> %d in %.1fms",
|
||||
method,
|
||||
path,
|
||||
status,
|
||||
elapsed_ms,
|
||||
)
|
||||
|
||||
if elapsed_ms > _PERF_SLOW_REQUEST_THRESHOLD:
|
||||
perf.warning(
|
||||
"[SLOW_REQUEST] %s %s -> %d in %.1fms (threshold=%.0fms)",
|
||||
method,
|
||||
path,
|
||||
status,
|
||||
elapsed_ms,
|
||||
_PERF_SLOW_REQUEST_THRESHOLD,
|
||||
)
|
||||
log_system_snapshot("slow_request")
|
||||
|
||||
return response
|
||||
|
||||
|
||||
app.add_middleware(RequestPerfMiddleware)
|
||||
|
||||
# Add SlowAPI middleware for automatic rate limiting
|
||||
# Uses Starlette BaseHTTPMiddleware (not the raw ASGI variant) to avoid
|
||||
# corrupting StreamingResponse — SlowAPIASGIMiddleware re-sends
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.composio_connector import ComposioConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
|
||||
|
|
@ -27,6 +26,7 @@ from app.tasks.connector_indexers.base import (
|
|||
)
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -383,6 +383,7 @@ async def _process_gmail_messages_phase2(
|
|||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
|
|
@ -415,7 +416,7 @@ async def _process_gmail_messages_phase2(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"message_id": item["message_id"],
|
||||
"thread_id": item["thread_id"],
|
||||
|
|
@ -427,10 +428,8 @@ async def _process_gmail_messages_phase2(
|
|||
item["markdown_content"], user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = f"Gmail: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Gmail: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["markdown_content"])
|
||||
|
||||
|
|
@ -646,6 +645,7 @@ async def index_composio_gmail(
|
|||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=getattr(connector, "enable_summary", False),
|
||||
on_heartbeat_callback=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.composio_connector import ComposioConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
|
||||
|
|
@ -27,6 +26,7 @@ from app.tasks.connector_indexers.base import (
|
|||
)
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -440,7 +440,7 @@ async def index_composio_google_calendar(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"event_id": item["event_id"],
|
||||
"summary": item["summary"],
|
||||
|
|
@ -456,12 +456,10 @@ async def index_composio_google_calendar(
|
|||
document_metadata_for_summary,
|
||||
)
|
||||
else:
|
||||
summary_content = f"Calendar: {item['summary']}\n\nStart: {item['start_time']}\nEnd: {item['end_time']}"
|
||||
if item["location"]:
|
||||
summary_content += f"\nLocation: {item['location']}"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
summary_content = (
|
||||
f"Calendar: {item['summary']}\n\n{item['markdown_content']}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["markdown_content"])
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from app.tasks.connector_indexers.base import (
|
|||
)
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -714,6 +715,7 @@ async def index_composio_google_drive(
|
|||
max_items=max_items,
|
||||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
enable_summary=getattr(connector, "enable_summary", False),
|
||||
on_heartbeat_callback=on_heartbeat_callback,
|
||||
)
|
||||
else:
|
||||
|
|
@ -747,6 +749,7 @@ async def index_composio_google_drive(
|
|||
max_items=max_items,
|
||||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
enable_summary=getattr(connector, "enable_summary", False),
|
||||
on_heartbeat_callback=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
|
|
@ -829,6 +832,7 @@ async def _index_composio_drive_delta_sync(
|
|||
max_items: int,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry,
|
||||
enable_summary: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
"""Index Google Drive files using delta sync with real-time document status updates.
|
||||
|
|
@ -1079,7 +1083,7 @@ async def _index_composio_drive_delta_sync(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"file_id": item["file_id"],
|
||||
"file_name": item["file_name"],
|
||||
|
|
@ -1090,10 +1094,8 @@ async def _index_composio_drive_delta_sync(
|
|||
markdown_content, user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}\n\n{markdown_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(markdown_content)
|
||||
|
||||
|
|
@ -1155,6 +1157,7 @@ async def _index_composio_drive_full_scan(
|
|||
max_items: int,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry,
|
||||
enable_summary: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
"""Index Google Drive files using full scan with real-time document status updates."""
|
||||
|
|
@ -1488,7 +1491,7 @@ async def _index_composio_drive_full_scan(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"file_id": item["file_id"],
|
||||
"file_name": item["file_name"],
|
||||
|
|
@ -1499,10 +1502,8 @@ async def _index_composio_drive_full_scan(
|
|||
markdown_content, user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}\n\n{markdown_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(markdown_content)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
import anyio
|
||||
from fastapi import Depends
|
||||
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
|
@ -31,7 +33,7 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
DATABASE_URL = config.DATABASE_URL
|
||||
|
||||
|
||||
class DocumentType(str, Enum):
|
||||
class DocumentType(StrEnum):
|
||||
EXTENSION = "EXTENSION"
|
||||
CRAWLED_URL = "CRAWLED_URL"
|
||||
FILE = "FILE"
|
||||
|
|
@ -60,7 +62,7 @@ class DocumentType(str, Enum):
|
|||
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
|
||||
|
||||
|
||||
class SearchSourceConnectorType(str, Enum):
|
||||
class SearchSourceConnectorType(StrEnum):
|
||||
SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
|
||||
TAVILY_API = "TAVILY_API"
|
||||
SEARXNG_API = "SEARXNG_API"
|
||||
|
|
@ -93,7 +95,7 @@ class SearchSourceConnectorType(str, Enum):
|
|||
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
|
||||
|
||||
|
||||
class PodcastStatus(str, Enum):
|
||||
class PodcastStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
GENERATING = "generating"
|
||||
READY = "ready"
|
||||
|
|
@ -177,7 +179,7 @@ class DocumentStatus:
|
|||
return None
|
||||
|
||||
|
||||
class LiteLLMProvider(str, Enum):
|
||||
class LiteLLMProvider(StrEnum):
|
||||
"""
|
||||
Enum for LLM providers supported by LiteLLM.
|
||||
"""
|
||||
|
|
@ -215,7 +217,7 @@ class LiteLLMProvider(str, Enum):
|
|||
CUSTOM = "CUSTOM"
|
||||
|
||||
|
||||
class ImageGenProvider(str, Enum):
|
||||
class ImageGenProvider(StrEnum):
|
||||
"""
|
||||
Enum for image generation providers supported by LiteLLM.
|
||||
This is a subset of LLM providers — only those that support image generation.
|
||||
|
|
@ -233,7 +235,7 @@ class ImageGenProvider(str, Enum):
|
|||
NSCALE = "NSCALE"
|
||||
|
||||
|
||||
class LogLevel(str, Enum):
|
||||
class LogLevel(StrEnum):
|
||||
DEBUG = "DEBUG"
|
||||
INFO = "INFO"
|
||||
WARNING = "WARNING"
|
||||
|
|
@ -241,13 +243,13 @@ class LogLevel(str, Enum):
|
|||
CRITICAL = "CRITICAL"
|
||||
|
||||
|
||||
class LogStatus(str, Enum):
|
||||
class LogStatus(StrEnum):
|
||||
IN_PROGRESS = "IN_PROGRESS"
|
||||
SUCCESS = "SUCCESS"
|
||||
FAILED = "FAILED"
|
||||
|
||||
|
||||
class IncentiveTaskType(str, Enum):
|
||||
class IncentiveTaskType(StrEnum):
|
||||
"""
|
||||
Enum for incentive task types that users can complete to earn free pages.
|
||||
Each task can only be completed once per user.
|
||||
|
|
@ -298,7 +300,7 @@ INCENTIVE_TASKS_CONFIG = {
|
|||
}
|
||||
|
||||
|
||||
class Permission(str, Enum):
|
||||
class Permission(StrEnum):
|
||||
"""
|
||||
Granular permissions for search space resources.
|
||||
Use '*' (FULL_ACCESS) to grant all permissions.
|
||||
|
|
@ -471,7 +473,7 @@ class BaseModel(Base):
|
|||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
|
||||
class NewChatMessageRole(str, Enum):
|
||||
class NewChatMessageRole(StrEnum):
|
||||
"""Role enum for new chat messages."""
|
||||
|
||||
USER = "user"
|
||||
|
|
@ -479,7 +481,7 @@ class NewChatMessageRole(str, Enum):
|
|||
SYSTEM = "system"
|
||||
|
||||
|
||||
class ChatVisibility(str, Enum):
|
||||
class ChatVisibility(StrEnum):
|
||||
"""
|
||||
Visibility/sharing level for chat threads.
|
||||
|
||||
|
|
@ -788,7 +790,7 @@ class ChatSessionState(BaseModel):
|
|||
ai_responding_to_user = relationship("User")
|
||||
|
||||
|
||||
class MemoryCategory(str, Enum):
|
||||
class MemoryCategory(StrEnum):
|
||||
"""Categories for user memories."""
|
||||
|
||||
# Using lowercase keys to match PostgreSQL enum values
|
||||
|
|
@ -1317,6 +1319,12 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
|
|||
last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
config = Column(JSON, nullable=False)
|
||||
|
||||
# Summary generation (LLM-based) - disabled by default to save resources.
|
||||
# When enabled, improves hybrid search quality at the cost of LLM calls.
|
||||
enable_summary = Column(
|
||||
Boolean, nullable=False, default=False, server_default="false"
|
||||
)
|
||||
|
||||
# Periodic indexing fields
|
||||
periodic_indexing_enabled = Column(Boolean, nullable=False, default=False)
|
||||
indexing_frequency_minutes = Column(Integer, nullable=True)
|
||||
|
|
@ -1850,10 +1858,37 @@ class RefreshToken(Base, TimestampMixin):
|
|||
return not self.is_expired and not self.is_revoked
|
||||
|
||||
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
engine = create_async_engine(
|
||||
DATABASE_URL,
|
||||
pool_size=30,
|
||||
max_overflow=150,
|
||||
pool_recycle=1800,
|
||||
pool_pre_ping=True,
|
||||
pool_timeout=30,
|
||||
)
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def shielded_async_session():
|
||||
"""Cancellation-safe async session context manager.
|
||||
|
||||
Starlette's BaseHTTPMiddleware cancels the task via an anyio cancel
|
||||
scope when a client disconnects. A plain ``async with async_session_maker()``
|
||||
has its ``__aexit__`` (which awaits ``session.close()``) cancelled by the
|
||||
scope, orphaning the underlying database connection.
|
||||
|
||||
This wrapper ensures ``session.close()`` always completes by running it
|
||||
inside ``anyio.CancelScope(shield=True)``.
|
||||
"""
|
||||
session = async_session_maker()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
with anyio.CancelScope(shield=True):
|
||||
await session.close()
|
||||
|
||||
|
||||
async def setup_indexes():
|
||||
async with engine.begin() as conn:
|
||||
# Create indexes
|
||||
|
|
|
|||
0
surfsense_backend/app/indexing_pipeline/__init__.py
Normal file
0
surfsense_backend/app/indexing_pipeline/__init__.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
|
||||
class UploadDocumentAdapter:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
self._service = IndexingPipelineService(session)
|
||||
|
||||
async def index(
|
||||
self,
|
||||
markdown_content: str,
|
||||
filename: str,
|
||||
etl_service: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
llm,
|
||||
should_summarize: bool = False,
|
||||
) -> None:
|
||||
connector_doc = ConnectorDocument(
|
||||
title=filename,
|
||||
source_markdown=markdown_content,
|
||||
unique_id=filename,
|
||||
document_type=DocumentType.FILE,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
connector_id=None,
|
||||
should_summarize=should_summarize,
|
||||
should_use_code_chunker=False,
|
||||
fallback_summary=markdown_content[:4000],
|
||||
metadata={
|
||||
"FILE_NAME": filename,
|
||||
"ETL_SERVICE": etl_service,
|
||||
},
|
||||
)
|
||||
|
||||
documents = await self._service.prepare_for_indexing([connector_doc])
|
||||
|
||||
if not documents:
|
||||
raise RuntimeError("prepare_for_indexing returned no documents")
|
||||
|
||||
indexed = await self._service.index(documents[0], connector_doc, llm)
|
||||
|
||||
if not DocumentStatus.is_state(indexed.status, DocumentStatus.READY):
|
||||
raise RuntimeError(indexed.status.get("reason", "Indexing failed"))
|
||||
|
||||
indexed.content_needs_reindexing = False
|
||||
await self._session.commit()
|
||||
|
||||
async def reindex(self, document: Document, llm) -> None:
|
||||
"""Re-index an existing document after its source_markdown has been updated."""
|
||||
if not document.source_markdown:
|
||||
raise RuntimeError("Document has no source_markdown to reindex")
|
||||
|
||||
metadata = document.document_metadata or {}
|
||||
|
||||
connector_doc = ConnectorDocument(
|
||||
title=document.title,
|
||||
source_markdown=document.source_markdown,
|
||||
unique_id=document.title,
|
||||
document_type=document.document_type,
|
||||
search_space_id=document.search_space_id,
|
||||
created_by_id=str(document.created_by_id),
|
||||
connector_id=document.connector_id,
|
||||
should_summarize=True,
|
||||
should_use_code_chunker=False,
|
||||
fallback_summary=document.source_markdown[:4000],
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
document.content_hash = compute_content_hash(connector_doc)
|
||||
|
||||
indexed = await self._service.index(document, connector_doc, llm)
|
||||
|
||||
if not DocumentStatus.is_state(indexed.status, DocumentStatus.READY):
|
||||
raise RuntimeError(indexed.status.get("reason", "Reindexing failed"))
|
||||
|
||||
indexed.content_needs_reindexing = False
|
||||
await self._session.commit()
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.db import DocumentType
|
||||
|
||||
|
||||
class ConnectorDocument(BaseModel):
|
||||
"""Canonical data transfer object produced by connector adapters and consumed by the indexing pipeline."""
|
||||
|
||||
title: str
|
||||
source_markdown: str
|
||||
unique_id: str
|
||||
document_type: DocumentType
|
||||
search_space_id: int = Field(gt=0)
|
||||
should_summarize: bool = True
|
||||
should_use_code_chunker: bool = False
|
||||
fallback_summary: str | None = None
|
||||
metadata: dict = {}
|
||||
connector_id: int | None = None
|
||||
created_by_id: str
|
||||
|
||||
@field_validator("title", "source_markdown", "unique_id", "created_by_id")
|
||||
@classmethod
|
||||
def not_empty(cls, v: str, info) -> str:
|
||||
if not v.strip():
|
||||
raise ValueError(f"{info.field_name} must not be empty or whitespace")
|
||||
return v
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
from app.config import config
|
||||
|
||||
|
||||
def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
|
||||
"""Chunk a text string using the configured chunker and return the chunk texts."""
|
||||
chunker = (
|
||||
config.code_chunker_instance if use_code_chunker else config.chunker_instance
|
||||
)
|
||||
return [c.text for c in chunker.chunk(text)]
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from app.utils.document_converters import embed_text
|
||||
|
||||
__all__ = ["embed_text"]
|
||||
15
surfsense_backend/app/indexing_pipeline/document_hashing.py
Normal file
15
surfsense_backend/app/indexing_pipeline/document_hashing.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
import hashlib
|
||||
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
|
||||
|
||||
def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
|
||||
"""Return a stable SHA-256 hash identifying a document by its source identity."""
|
||||
combined = f"{doc.document_type.value}:{doc.unique_id}:{doc.search_space_id}"
|
||||
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def compute_content_hash(doc: ConnectorDocument) -> str:
|
||||
"""Return a SHA-256 hash of the document's content scoped to its search space."""
|
||||
combined = f"{doc.search_space_id}:{doc.source_markdown}"
|
||||
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy.orm.attributes import set_committed_value
|
||||
|
||||
from app.db import Document, DocumentStatus
|
||||
|
||||
|
||||
async def rollback_and_persist_failure(
|
||||
session: AsyncSession, document: Document, message: str
|
||||
) -> None:
|
||||
"""Roll back the current transaction and best-effort persist a failed status.
|
||||
|
||||
Called exclusively from except blocks — must never raise, or the new exception
|
||||
would chain with the original and mask it entirely.
|
||||
"""
|
||||
try:
|
||||
await session.rollback()
|
||||
except Exception:
|
||||
return # Session is completely dead; nothing further we can do.
|
||||
try:
|
||||
await session.refresh(document)
|
||||
document.updated_at = datetime.now(UTC)
|
||||
document.status = DocumentStatus.failed(message)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
pass # Best-effort; document will be retried on the next sync.
|
||||
|
||||
|
||||
def attach_chunks_to_document(document: Document, chunks: list) -> None:
|
||||
"""Assign chunks to a document without triggering SQLAlchemy async lazy loading."""
|
||||
set_committed_value(document, "chunks", chunks)
|
||||
session = object_session(document)
|
||||
if session is not None:
|
||||
if document.id is not None:
|
||||
for chunk in chunks:
|
||||
chunk.document_id = document.id
|
||||
session.add_all(chunks)
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||
from app.utils.document_converters import optimize_content_for_context_window
|
||||
|
||||
|
||||
async def summarize_document(
|
||||
source_markdown: str, llm, metadata: dict | None = None
|
||||
) -> str:
|
||||
"""Generate a text summary of a document using an LLM, prefixed with metadata when provided."""
|
||||
model_name = getattr(llm, "model", "gpt-3.5-turbo")
|
||||
optimized_content = optimize_content_for_context_window(
|
||||
source_markdown, metadata, model_name
|
||||
)
|
||||
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
|
||||
content_with_metadata = (
|
||||
f"<DOCUMENT><DOCUMENT_METADATA>\n\n{metadata}\n\n</DOCUMENT_METADATA>"
|
||||
f"\n\n<DOCUMENT_CONTENT>\n\n{optimized_content}\n\n</DOCUMENT_CONTENT></DOCUMENT>"
|
||||
)
|
||||
summary_result = await summary_chain.ainvoke({"document": content_with_metadata})
|
||||
summary_content = summary_result.content
|
||||
|
||||
if metadata:
|
||||
metadata_parts = ["# DOCUMENT METADATA"]
|
||||
for key, value in metadata.items():
|
||||
if value:
|
||||
metadata_parts.append(f"**{key.replace('_', ' ').title()}:** {value}")
|
||||
metadata_section = "\n".join(metadata_parts)
|
||||
return f"{metadata_section}\n\n# DOCUMENT SUMMARY\n\n{summary_content}"
|
||||
|
||||
return summary_content
|
||||
146
surfsense_backend/app/indexing_pipeline/exceptions.py
Normal file
146
surfsense_backend/app/indexing_pipeline/exceptions.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
from litellm.exceptions import (
|
||||
APIConnectionError,
|
||||
APIResponseValidationError,
|
||||
AuthenticationError,
|
||||
BadGatewayError,
|
||||
BadRequestError,
|
||||
InternalServerError,
|
||||
NotFoundError,
|
||||
PermissionDeniedError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
Timeout,
|
||||
UnprocessableEntityError,
|
||||
)
|
||||
from sqlalchemy.exc import IntegrityError as IntegrityError
|
||||
|
||||
# Tuples for use directly in except clauses.
|
||||
RETRYABLE_LLM_ERRORS = (
|
||||
RateLimitError,
|
||||
Timeout,
|
||||
ServiceUnavailableError,
|
||||
BadGatewayError,
|
||||
InternalServerError,
|
||||
APIConnectionError,
|
||||
)
|
||||
|
||||
PERMANENT_LLM_ERRORS = (
|
||||
AuthenticationError,
|
||||
PermissionDeniedError,
|
||||
NotFoundError,
|
||||
BadRequestError,
|
||||
UnprocessableEntityError,
|
||||
APIResponseValidationError,
|
||||
)
|
||||
|
||||
# (LiteLLMEmbeddings, CohereEmbeddings, GeminiEmbeddings all normalize to RuntimeError).
|
||||
EMBEDDING_ERRORS = (
|
||||
RuntimeError, # local device failure or API backend normalization
|
||||
OSError, # model files missing or corrupted (local backends)
|
||||
MemoryError, # document too large for available RAM
|
||||
OSError, # model files missing or corrupted (local backends)
|
||||
MemoryError, # document too large for available RAM
|
||||
)
|
||||
|
||||
|
||||
class PipelineMessages:
|
||||
RATE_LIMIT = "LLM rate limit exceeded. Will retry on next sync."
|
||||
LLM_TIMEOUT = "LLM request timed out. Will retry on next sync."
|
||||
LLM_UNAVAILABLE = "LLM service temporarily unavailable. Will retry on next sync."
|
||||
LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync."
|
||||
LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync."
|
||||
LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity."
|
||||
RATE_LIMIT = "LLM rate limit exceeded. Will retry on next sync."
|
||||
LLM_TIMEOUT = "LLM request timed out. Will retry on next sync."
|
||||
LLM_UNAVAILABLE = "LLM service temporarily unavailable. Will retry on next sync."
|
||||
LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync."
|
||||
LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync."
|
||||
LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity."
|
||||
|
||||
LLM_AUTH = "LLM authentication failed. Check your API key."
|
||||
LLM_PERMISSION = "LLM request denied. Check your account permissions."
|
||||
LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
|
||||
LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
|
||||
LLM_UNPROCESSABLE = (
|
||||
"Document exceeds the LLM context window even after optimization."
|
||||
)
|
||||
LLM_RESPONSE = "LLM returned an invalid response."
|
||||
LLM_AUTH = "LLM authentication failed. Check your API key."
|
||||
LLM_PERMISSION = "LLM request denied. Check your account permissions."
|
||||
LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
|
||||
LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
|
||||
LLM_UNPROCESSABLE = (
|
||||
"Document exceeds the LLM context window even after optimization."
|
||||
)
|
||||
LLM_RESPONSE = "LLM returned an invalid response."
|
||||
|
||||
EMBEDDING_FAILED = (
|
||||
"Embedding failed. Check your embedding model configuration or service."
|
||||
)
|
||||
EMBEDDING_MODEL = "Embedding model files are missing or corrupted."
|
||||
EMBEDDING_MEMORY = "Not enough memory to embed this document."
|
||||
EMBEDDING_FAILED = (
|
||||
"Embedding failed. Check your embedding model configuration or service."
|
||||
)
|
||||
EMBEDDING_MODEL = "Embedding model files are missing or corrupted."
|
||||
EMBEDDING_MEMORY = "Not enough memory to embed this document."
|
||||
|
||||
CHUNKING_OVERFLOW = "Document structure is too deeply nested to chunk."
|
||||
|
||||
|
||||
def safe_exception_message(exc: Exception) -> str:
|
||||
try:
|
||||
return str(exc)
|
||||
except Exception:
|
||||
return "Something went wrong during indexing. Error details could not be retrieved."
|
||||
|
||||
|
||||
def llm_retryable_message(exc: Exception) -> str:
|
||||
try:
|
||||
if isinstance(exc, RateLimitError):
|
||||
return PipelineMessages.RATE_LIMIT
|
||||
if isinstance(exc, Timeout):
|
||||
return PipelineMessages.LLM_TIMEOUT
|
||||
if isinstance(exc, ServiceUnavailableError):
|
||||
return PipelineMessages.LLM_UNAVAILABLE
|
||||
if isinstance(exc, BadGatewayError):
|
||||
return PipelineMessages.LLM_BAD_GATEWAY
|
||||
if isinstance(exc, InternalServerError):
|
||||
return PipelineMessages.LLM_SERVER_ERROR
|
||||
if isinstance(exc, APIConnectionError):
|
||||
return PipelineMessages.LLM_CONNECTION
|
||||
return safe_exception_message(exc)
|
||||
except Exception:
|
||||
return "Something went wrong when calling the LLM."
|
||||
|
||||
|
||||
def llm_permanent_message(exc: Exception) -> str:
|
||||
try:
|
||||
if isinstance(exc, AuthenticationError):
|
||||
return PipelineMessages.LLM_AUTH
|
||||
if isinstance(exc, PermissionDeniedError):
|
||||
return PipelineMessages.LLM_PERMISSION
|
||||
if isinstance(exc, NotFoundError):
|
||||
return PipelineMessages.LLM_NOT_FOUND
|
||||
if isinstance(exc, BadRequestError):
|
||||
return PipelineMessages.LLM_BAD_REQUEST
|
||||
if isinstance(exc, UnprocessableEntityError):
|
||||
return PipelineMessages.LLM_UNPROCESSABLE
|
||||
if isinstance(exc, APIResponseValidationError):
|
||||
return PipelineMessages.LLM_RESPONSE
|
||||
return safe_exception_message(exc)
|
||||
except Exception:
|
||||
return "Something went wrong when calling the LLM."
|
||||
|
||||
|
||||
def embedding_message(exc: Exception) -> str:
|
||||
try:
|
||||
if isinstance(exc, RuntimeError):
|
||||
return PipelineMessages.EMBEDDING_FAILED
|
||||
if isinstance(exc, OSError):
|
||||
return PipelineMessages.EMBEDDING_MODEL
|
||||
if isinstance(exc, MemoryError):
|
||||
return PipelineMessages.EMBEDDING_MEMORY
|
||||
return safe_exception_message(exc)
|
||||
except Exception:
|
||||
return "Something went wrong when generating the embedding."
|
||||
|
|
@ -0,0 +1,272 @@
|
|||
import contextlib
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_chunker import chunk_text
|
||||
from app.indexing_pipeline.document_embedder import embed_text
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash,
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
from app.indexing_pipeline.document_persistence import (
|
||||
attach_chunks_to_document,
|
||||
rollback_and_persist_failure,
|
||||
)
|
||||
from app.indexing_pipeline.document_summarizer import summarize_document
|
||||
from app.indexing_pipeline.exceptions import (
|
||||
EMBEDDING_ERRORS,
|
||||
PERMANENT_LLM_ERRORS,
|
||||
RETRYABLE_LLM_ERRORS,
|
||||
PipelineMessages,
|
||||
embedding_message,
|
||||
llm_permanent_message,
|
||||
llm_retryable_message,
|
||||
safe_exception_message,
|
||||
)
|
||||
from app.indexing_pipeline.pipeline_logger import (
|
||||
PipelineLogContext,
|
||||
log_batch_aborted,
|
||||
log_chunking_overflow,
|
||||
log_doc_skipped_unknown,
|
||||
log_document_queued,
|
||||
log_document_requeued,
|
||||
log_document_updated,
|
||||
log_embedding_error,
|
||||
log_index_started,
|
||||
log_index_success,
|
||||
log_permanent_llm_error,
|
||||
log_race_condition,
|
||||
log_retryable_llm_error,
|
||||
log_unexpected_error,
|
||||
)
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
|
||||
class IndexingPipelineService:
|
||||
"""Single pipeline for indexing connector documents. All connectors use this service."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def prepare_for_indexing(
|
||||
self, connector_docs: list[ConnectorDocument]
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Persist new documents and detect changes, returning only those that need indexing.
|
||||
"""
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
documents = []
|
||||
seen_hashes: set[str] = set()
|
||||
batch_ctx = PipelineLogContext(
|
||||
connector_id=connector_docs[0].connector_id if connector_docs else 0,
|
||||
search_space_id=connector_docs[0].search_space_id if connector_docs else 0,
|
||||
unique_id="batch",
|
||||
)
|
||||
|
||||
for connector_doc in connector_docs:
|
||||
ctx = PipelineLogContext(
|
||||
connector_id=connector_doc.connector_id,
|
||||
search_space_id=connector_doc.search_space_id,
|
||||
unique_id=connector_doc.unique_id,
|
||||
)
|
||||
try:
|
||||
unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
|
||||
content_hash = compute_content_hash(connector_doc)
|
||||
|
||||
if unique_identifier_hash in seen_hashes:
|
||||
continue
|
||||
seen_hashes.add(unique_identifier_hash)
|
||||
|
||||
result = await self.session.execute(
|
||||
select(Document).filter(
|
||||
Document.unique_identifier_hash == unique_identifier_hash
|
||||
)
|
||||
)
|
||||
existing = result.scalars().first()
|
||||
|
||||
if existing is not None:
|
||||
if existing.content_hash == content_hash:
|
||||
if existing.title != connector_doc.title:
|
||||
existing.title = connector_doc.title
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
if not DocumentStatus.is_state(
|
||||
existing.status, DocumentStatus.READY
|
||||
):
|
||||
existing.status = DocumentStatus.pending()
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
documents.append(existing)
|
||||
log_document_requeued(ctx)
|
||||
continue
|
||||
|
||||
existing.title = connector_doc.title
|
||||
existing.content_hash = content_hash
|
||||
existing.source_markdown = connector_doc.source_markdown
|
||||
existing.document_metadata = connector_doc.metadata
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
existing.status = DocumentStatus.pending()
|
||||
documents.append(existing)
|
||||
log_document_updated(ctx)
|
||||
continue
|
||||
|
||||
duplicate = await self.session.execute(
|
||||
select(Document).filter(Document.content_hash == content_hash)
|
||||
)
|
||||
if duplicate.scalars().first() is not None:
|
||||
continue
|
||||
|
||||
document = Document(
|
||||
title=connector_doc.title,
|
||||
document_type=connector_doc.document_type,
|
||||
content="Pending...",
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
source_markdown=connector_doc.source_markdown,
|
||||
document_metadata=connector_doc.metadata,
|
||||
search_space_id=connector_doc.search_space_id,
|
||||
connector_id=connector_doc.connector_id,
|
||||
created_by_id=connector_doc.created_by_id,
|
||||
updated_at=datetime.now(UTC),
|
||||
status=DocumentStatus.pending(),
|
||||
)
|
||||
self.session.add(document)
|
||||
documents.append(document)
|
||||
log_document_queued(ctx)
|
||||
|
||||
except Exception as e:
|
||||
log_doc_skipped_unknown(ctx, e)
|
||||
|
||||
try:
|
||||
await self.session.commit()
|
||||
perf.info(
|
||||
"[indexing] prepare_for_indexing in %.3fs input=%d output=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(connector_docs),
|
||||
len(documents),
|
||||
)
|
||||
return documents
|
||||
except IntegrityError:
|
||||
log_race_condition(batch_ctx)
|
||||
await self.session.rollback()
|
||||
return []
|
||||
except Exception as e:
|
||||
log_batch_aborted(batch_ctx, e)
|
||||
await self.session.rollback()
|
||||
return []
|
||||
|
||||
async def index(
|
||||
self, document: Document, connector_doc: ConnectorDocument, llm
|
||||
) -> Document:
|
||||
"""
|
||||
Run summarization, embedding, and chunking for a document and persist the results.
|
||||
"""
|
||||
ctx = PipelineLogContext(
|
||||
connector_id=connector_doc.connector_id,
|
||||
search_space_id=connector_doc.search_space_id,
|
||||
unique_id=connector_doc.unique_id,
|
||||
doc_id=document.id,
|
||||
)
|
||||
perf = get_perf_logger()
|
||||
t_index = time.perf_counter()
|
||||
try:
|
||||
log_index_started(ctx)
|
||||
document.status = DocumentStatus.processing()
|
||||
await self.session.commit()
|
||||
|
||||
t_step = time.perf_counter()
|
||||
if connector_doc.should_summarize and llm is not None:
|
||||
content = await summarize_document(
|
||||
connector_doc.source_markdown, llm, connector_doc.metadata
|
||||
)
|
||||
perf.info(
|
||||
"[indexing] summarize_document doc=%d in %.3fs",
|
||||
document.id,
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
elif connector_doc.should_summarize and connector_doc.fallback_summary:
|
||||
content = connector_doc.fallback_summary
|
||||
else:
|
||||
content = connector_doc.source_markdown
|
||||
|
||||
t_step = time.perf_counter()
|
||||
embedding = embed_text(content)
|
||||
perf.debug(
|
||||
"[indexing] embed_text (summary) doc=%d in %.3fs",
|
||||
document.id,
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
|
||||
await self.session.execute(
|
||||
delete(Chunk).where(Chunk.document_id == document.id)
|
||||
)
|
||||
|
||||
t_step = time.perf_counter()
|
||||
chunks = [
|
||||
Chunk(content=text, embedding=embed_text(text))
|
||||
for text in chunk_text(
|
||||
connector_doc.source_markdown,
|
||||
use_code_chunker=connector_doc.should_use_code_chunker,
|
||||
)
|
||||
]
|
||||
perf.info(
|
||||
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
|
||||
document.content = content
|
||||
document.embedding = embedding
|
||||
attach_chunks_to_document(document, chunks)
|
||||
document.updated_at = datetime.now(UTC)
|
||||
document.status = DocumentStatus.ready()
|
||||
await self.session.commit()
|
||||
perf.info(
|
||||
"[indexing] index TOTAL doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_index,
|
||||
)
|
||||
log_index_success(ctx, chunk_count=len(chunks))
|
||||
|
||||
except RETRYABLE_LLM_ERRORS as e:
|
||||
log_retryable_llm_error(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, llm_retryable_message(e)
|
||||
)
|
||||
|
||||
except PERMANENT_LLM_ERRORS as e:
|
||||
log_permanent_llm_error(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, llm_permanent_message(e)
|
||||
)
|
||||
|
||||
except RecursionError as e:
|
||||
log_chunking_overflow(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, PipelineMessages.CHUNKING_OVERFLOW
|
||||
)
|
||||
|
||||
except EMBEDDING_ERRORS as e:
|
||||
log_embedding_error(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, embedding_message(e)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log_unexpected_error(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, safe_exception_message(e)
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
await self.session.refresh(document)
|
||||
|
||||
return document
|
||||
126
surfsense_backend/app/indexing_pipeline/pipeline_logger.py
Normal file
126
surfsense_backend/app/indexing_pipeline/pipeline_logger.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineLogContext:
|
||||
connector_id: int | None
|
||||
search_space_id: int
|
||||
unique_id: str # always available from ConnectorDocument
|
||||
doc_id: int | None = None # set once the DB row exists (index phase only)
|
||||
|
||||
|
||||
class LogMessages:
|
||||
# prepare_for_indexing
|
||||
DOCUMENT_QUEUED = "New document queued for indexing."
|
||||
DOCUMENT_UPDATED = "Document content changed, re-queued for indexing."
|
||||
DOCUMENT_REQUEUED = "Stuck document re-queued for indexing."
|
||||
DOC_SKIPPED_UNKNOWN = "Unexpected error — document skipped."
|
||||
BATCH_ABORTED = "Fatal DB error — aborting prepare batch."
|
||||
RACE_CONDITION = "Concurrent worker beat us to the commit — rolling back batch."
|
||||
|
||||
# index
|
||||
INDEX_STARTED = "Document indexing started."
|
||||
INDEX_SUCCESS = "Document indexed successfully."
|
||||
LLM_RETRYABLE = (
|
||||
"Retryable LLM error — document marked failed, will retry on next sync."
|
||||
)
|
||||
LLM_PERMANENT = "Permanent LLM error — document marked failed."
|
||||
EMBEDDING_FAILED = "Embedding error — document marked failed."
|
||||
CHUNKING_OVERFLOW = "Chunking overflow — document marked failed."
|
||||
UNEXPECTED = "Unexpected error — document marked failed."
|
||||
|
||||
|
||||
def _format_context(ctx: PipelineLogContext) -> str:
|
||||
parts = [
|
||||
f"connector_id={ctx.connector_id}",
|
||||
f"search_space_id={ctx.search_space_id}",
|
||||
f"unique_id={ctx.unique_id}",
|
||||
]
|
||||
if ctx.doc_id is not None:
|
||||
parts.append(f"doc_id={ctx.doc_id}")
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _build_message(msg: str, ctx: PipelineLogContext, **extra) -> str:
|
||||
try:
|
||||
parts = [msg, _format_context(ctx)]
|
||||
for key, val in extra.items():
|
||||
parts.append(f"{key}={val}")
|
||||
return " ".join(parts)
|
||||
except Exception:
|
||||
return msg
|
||||
|
||||
|
||||
def _safe_log(
|
||||
level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extra
|
||||
) -> None:
|
||||
# Logging must never raise — a broken log call inside an except block would
|
||||
# chain with the original exception and mask it entirely.
|
||||
try:
|
||||
message = _build_message(msg, ctx, **extra)
|
||||
level_fn(message, exc_info=exc_info)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ── prepare_for_indexing ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def log_document_queued(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.info, LogMessages.DOCUMENT_QUEUED, ctx)
|
||||
|
||||
|
||||
def log_document_updated(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.info, LogMessages.DOCUMENT_UPDATED, ctx)
|
||||
|
||||
|
||||
def log_document_requeued(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.info, LogMessages.DOCUMENT_REQUEUED, ctx)
|
||||
|
||||
|
||||
def log_doc_skipped_unknown(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(
|
||||
logger.warning, LogMessages.DOC_SKIPPED_UNKNOWN, ctx, exc_info=exc, error=exc
|
||||
)
|
||||
|
||||
|
||||
def log_race_condition(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.warning, LogMessages.RACE_CONDITION, ctx)
|
||||
|
||||
|
||||
def log_batch_aborted(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.BATCH_ABORTED, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
# ── index ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def log_index_started(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.info, LogMessages.INDEX_STARTED, ctx)
|
||||
|
||||
|
||||
def log_index_success(ctx: PipelineLogContext, chunk_count: int) -> None:
|
||||
_safe_log(logger.info, LogMessages.INDEX_SUCCESS, ctx, chunk_count=chunk_count)
|
||||
|
||||
|
||||
def log_retryable_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.warning, LogMessages.LLM_RETRYABLE, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
def log_permanent_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.LLM_PERMANENT, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
def log_embedding_error(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.EMBEDDING_FAILED, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
def log_chunking_overflow(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.CHUNKING_OVERFLOW, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
def log_unexpected_error(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.UNEXPECTED, ctx, exc_info=exc, error=exc)
|
||||
|
|
@ -1,5 +1,10 @@
|
|||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_MAX_FETCH_CHUNKS_PER_DOC = 30
|
||||
|
||||
|
||||
class ChucksHybridSearchRetriever:
|
||||
def __init__(self, db_session):
|
||||
|
|
@ -38,9 +43,17 @@ class ChucksHybridSearchRetriever:
|
|||
from app.config import config
|
||||
from app.db import Chunk, Document
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
t_embed = time.perf_counter()
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
perf.debug(
|
||||
"[chunk_search] vector_search embedding in %.3fs",
|
||||
time.perf_counter() - t_embed,
|
||||
)
|
||||
|
||||
# Build the query filtered by search space
|
||||
query = (
|
||||
|
|
@ -60,8 +73,16 @@ class ChucksHybridSearchRetriever:
|
|||
query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k)
|
||||
|
||||
# Execute the query
|
||||
t_db = time.perf_counter()
|
||||
result = await self.db_session.execute(query)
|
||||
chunks = result.scalars().all()
|
||||
perf.info(
|
||||
"[chunk_search] vector_search DB query in %.3fs results=%d (total %.3fs) space=%d",
|
||||
time.perf_counter() - t_db,
|
||||
len(chunks),
|
||||
time.perf_counter() - t0,
|
||||
search_space_id,
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
|
|
@ -91,6 +112,9 @@ class ChucksHybridSearchRetriever:
|
|||
|
||||
from app.db import Chunk, Document
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector("english", Chunk.content)
|
||||
tsquery = func.plainto_tsquery("english", query_text)
|
||||
|
|
@ -118,6 +142,12 @@ class ChucksHybridSearchRetriever:
|
|||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
chunks = result.scalars().all()
|
||||
perf.info(
|
||||
"[chunk_search] full_text_search in %.3fs results=%d space=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(chunks),
|
||||
search_space_id,
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
|
|
@ -129,6 +159,7 @@ class ChucksHybridSearchRetriever:
|
|||
document_type: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
query_embedding: list | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Hybrid search that returns **documents** (not individual chunks).
|
||||
|
|
@ -143,6 +174,7 @@ class ChucksHybridSearchRetriever:
|
|||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
start_date: Optional start date for filtering documents by updated_at
|
||||
end_date: Optional end date for filtering documents by updated_at
|
||||
query_embedding: Pre-computed embedding vector. If None, will be computed here.
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing document data and relevance scores. Each dict contains:
|
||||
|
|
@ -157,9 +189,17 @@ class ChucksHybridSearchRetriever:
|
|||
from app.config import config
|
||||
from app.db import Chunk, Document, DocumentType
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
if query_embedding is None:
|
||||
embedding_model = config.embedding_model_instance
|
||||
t_embed = time.perf_counter()
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
perf.debug(
|
||||
"[chunk_search] hybrid_search embedding in %.3fs",
|
||||
time.perf_counter() - t_embed,
|
||||
)
|
||||
|
||||
# RRF constants
|
||||
k = 60
|
||||
|
|
@ -254,9 +294,17 @@ class ChucksHybridSearchRetriever:
|
|||
.limit(top_k)
|
||||
)
|
||||
|
||||
# Execute the query
|
||||
# Execute the RRF query
|
||||
t_rrf = time.perf_counter()
|
||||
result = await self.db_session.execute(final_query)
|
||||
chunks_with_scores = result.all()
|
||||
perf.info(
|
||||
"[chunk_search] hybrid_search RRF query in %.3fs results=%d space=%d type=%s",
|
||||
time.perf_counter() - t_rrf,
|
||||
len(chunks_with_scores),
|
||||
search_space_id,
|
||||
document_type,
|
||||
)
|
||||
|
||||
# If no results were found, return an empty list
|
||||
if not chunks_with_scores:
|
||||
|
|
@ -300,8 +348,9 @@ class ChucksHybridSearchRetriever:
|
|||
if not doc_ids:
|
||||
return []
|
||||
|
||||
# Fetch ALL chunks for selected documents in a single query so the final prompt can cite
|
||||
# any chunk from those documents.
|
||||
# Fetch chunks for selected documents. We cap per document to avoid
|
||||
# loading hundreds of chunks for a single large file while still
|
||||
# ensuring the chunks that matched the RRF query are always included.
|
||||
chunk_query = (
|
||||
select(Chunk)
|
||||
.options(joinedload(Chunk.document))
|
||||
|
|
@ -311,7 +360,20 @@ class ChucksHybridSearchRetriever:
|
|||
.order_by(Chunk.document_id, Chunk.id)
|
||||
)
|
||||
chunks_result = await self.db_session.execute(chunk_query)
|
||||
all_chunks = chunks_result.scalars().all()
|
||||
raw_chunks = chunks_result.scalars().all()
|
||||
|
||||
matched_chunk_ids: set[int] = {
|
||||
item["chunk_id"] for item in serialized_chunk_results
|
||||
}
|
||||
|
||||
doc_chunk_counts: dict[int, int] = {}
|
||||
all_chunks: list = []
|
||||
for chunk in raw_chunks:
|
||||
did = chunk.document_id
|
||||
count = doc_chunk_counts.get(did, 0)
|
||||
if chunk.id in matched_chunk_ids or count < _MAX_FETCH_CHUNKS_PER_DOC:
|
||||
all_chunks.append(chunk)
|
||||
doc_chunk_counts[did] = count + 1
|
||||
|
||||
# Assemble final doc-grouped results in the same order as doc_ids
|
||||
doc_map: dict[int, dict] = {
|
||||
|
|
@ -354,4 +416,11 @@ class ChucksHybridSearchRetriever:
|
|||
)
|
||||
final_docs.append(entry)
|
||||
|
||||
perf.info(
|
||||
"[chunk_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
|
||||
time.perf_counter() - t0,
|
||||
len(final_docs),
|
||||
search_space_id,
|
||||
document_type,
|
||||
)
|
||||
return final_docs
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_MAX_FETCH_CHUNKS_PER_DOC = 30
|
||||
|
||||
|
||||
class DocumentHybridSearchRetriever:
|
||||
def __init__(self, db_session):
|
||||
|
|
@ -38,6 +43,9 @@ class DocumentHybridSearchRetriever:
|
|||
from app.config import config
|
||||
from app.db import Document
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
|
@ -63,6 +71,12 @@ class DocumentHybridSearchRetriever:
|
|||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
documents = result.scalars().all()
|
||||
perf.info(
|
||||
"[doc_search] vector_search in %.3fs results=%d space=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(documents),
|
||||
search_space_id,
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
|
|
@ -92,6 +106,9 @@ class DocumentHybridSearchRetriever:
|
|||
|
||||
from app.db import Document
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector("english", Document.content)
|
||||
tsquery = func.plainto_tsquery("english", query_text)
|
||||
|
|
@ -118,6 +135,12 @@ class DocumentHybridSearchRetriever:
|
|||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
documents = result.scalars().all()
|
||||
perf.info(
|
||||
"[doc_search] full_text_search in %.3fs results=%d space=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(documents),
|
||||
search_space_id,
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
|
|
@ -129,6 +152,7 @@ class DocumentHybridSearchRetriever:
|
|||
document_type: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
query_embedding: list | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Hybrid search that returns **documents** (not individual chunks).
|
||||
|
|
@ -143,7 +167,7 @@ class DocumentHybridSearchRetriever:
|
|||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
start_date: Optional start date for filtering documents by updated_at
|
||||
end_date: Optional end date for filtering documents by updated_at
|
||||
|
||||
query_embedding: Pre-computed embedding vector. If None, will be computed here.
|
||||
"""
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
|
@ -151,9 +175,12 @@ class DocumentHybridSearchRetriever:
|
|||
from app.config import config
|
||||
from app.db import Chunk, Document, DocumentType
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
if query_embedding is None:
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
# RRF constants
|
||||
k = 60
|
||||
|
|
@ -254,7 +281,8 @@ class DocumentHybridSearchRetriever:
|
|||
# Collect document IDs for chunk fetching
|
||||
doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores]
|
||||
|
||||
# Fetch ALL chunks for these documents in a single query
|
||||
# Fetch chunks for these documents, capped per document to avoid
|
||||
# loading hundreds of chunks for a single large file.
|
||||
chunks_query = (
|
||||
select(Chunk)
|
||||
.options(joinedload(Chunk.document))
|
||||
|
|
@ -262,7 +290,16 @@ class DocumentHybridSearchRetriever:
|
|||
.order_by(Chunk.document_id, Chunk.id)
|
||||
)
|
||||
chunks_result = await self.db_session.execute(chunks_query)
|
||||
chunks = chunks_result.scalars().all()
|
||||
raw_chunks = chunks_result.scalars().all()
|
||||
|
||||
doc_chunk_counts: dict[int, int] = {}
|
||||
chunks: list = []
|
||||
for chunk in raw_chunks:
|
||||
did = chunk.document_id
|
||||
count = doc_chunk_counts.get(did, 0)
|
||||
if count < _MAX_FETCH_CHUNKS_PER_DOC:
|
||||
chunks.append(chunk)
|
||||
doc_chunk_counts[did] = count + 1
|
||||
|
||||
# Assemble doc-grouped results
|
||||
doc_map: dict[int, dict] = {
|
||||
|
|
@ -303,4 +340,11 @@ class DocumentHybridSearchRetriever:
|
|||
)
|
||||
final_docs.append(entry)
|
||||
|
||||
perf.info(
|
||||
"[doc_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
|
||||
time.perf_counter() - t0,
|
||||
len(final_docs),
|
||||
search_space_id,
|
||||
document_type,
|
||||
)
|
||||
return final_docs
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ from .podcasts_routes import router as podcasts_router
|
|||
from .public_chat_routes import router as public_chat_router
|
||||
from .rbac_routes import router as rbac_router
|
||||
from .reports_routes import router as reports_router
|
||||
from .sandbox_routes import router as sandbox_router
|
||||
from .search_source_connectors_routes import router as search_source_connectors_router
|
||||
from .search_spaces_routes import router as search_spaces_router
|
||||
from .slack_add_connector_route import router as slack_add_connector_router
|
||||
|
|
@ -50,6 +51,7 @@ router.include_router(editor_router)
|
|||
router.include_router(documents_router)
|
||||
router.include_router(notes_router)
|
||||
router.include_router(new_chat_router) # Chat with assistant-ui persistence
|
||||
router.include_router(sandbox_router) # Sandbox file downloads (Daytona)
|
||||
router.include_router(chat_comments_router)
|
||||
router.include_router(podcasts_router) # Podcast task status and audio
|
||||
router.include_router(reports_router) # Report CRUD and export (PDF/DOCX)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.db import User, get_async_session
|
||||
from app.schemas.chat_comments import (
|
||||
CommentBatchRequest,
|
||||
CommentBatchResponse,
|
||||
CommentCreateRequest,
|
||||
CommentListResponse,
|
||||
CommentReplyResponse,
|
||||
|
|
@ -19,6 +21,7 @@ from app.services.chat_comments_service import (
|
|||
create_reply,
|
||||
delete_comment,
|
||||
get_comments_for_message,
|
||||
get_comments_for_messages_batch,
|
||||
get_user_mentions,
|
||||
update_comment,
|
||||
)
|
||||
|
|
@ -27,6 +30,16 @@ from app.users import current_active_user
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/messages/comments/batch", response_model=CommentBatchResponse)
|
||||
async def batch_list_comments(
|
||||
request: CommentBatchRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Batch-fetch comments for multiple messages in one request."""
|
||||
return await get_comments_for_messages_batch(session, request.message_ids, user)
|
||||
|
||||
|
||||
@router.get("/messages/{message_id}/comments", response_model=CommentListResponse)
|
||||
async def list_comments(
|
||||
message_id: int,
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from app.schemas import (
|
|||
DocumentWithChunksRead,
|
||||
PaginatedResponse,
|
||||
)
|
||||
from app.services.task_dispatcher import TaskDispatcher, get_task_dispatcher
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
|
@ -44,6 +45,10 @@ os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
MAX_FILES_PER_UPLOAD = 10
|
||||
MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB per file
|
||||
MAX_TOTAL_SIZE_BYTES = 200 * 1024 * 1024 # 200 MB total
|
||||
|
||||
|
||||
@router.post("/documents")
|
||||
async def create_documents(
|
||||
|
|
@ -114,8 +119,10 @@ async def create_documents(
|
|||
async def create_documents_file_upload(
|
||||
files: list[UploadFile],
|
||||
search_space_id: int = Form(...),
|
||||
should_summarize: bool = Form(False),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
dispatcher: TaskDispatcher = Depends(get_task_dispatcher),
|
||||
):
|
||||
"""
|
||||
Upload files as documents with real-time status tracking.
|
||||
|
|
@ -126,6 +133,8 @@ async def create_documents_file_upload(
|
|||
|
||||
Requires DOCUMENTS_CREATE permission.
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
|
||||
from app.db import DocumentStatus
|
||||
|
|
@ -136,7 +145,6 @@ async def create_documents_file_upload(
|
|||
from app.utils.document_converters import generate_unique_identifier_hash
|
||||
|
||||
try:
|
||||
# Check permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
|
|
@ -148,51 +156,88 @@ async def create_documents_file_upload(
|
|||
if not files:
|
||||
raise HTTPException(status_code=400, detail="No files provided")
|
||||
|
||||
if len(files) > MAX_FILES_PER_UPLOAD:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Too many files. Maximum {MAX_FILES_PER_UPLOAD} files per upload.",
|
||||
)
|
||||
|
||||
total_size = 0
|
||||
for file in files:
|
||||
file_size = file.size or 0
|
||||
if file_size > MAX_FILE_SIZE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
|
||||
)
|
||||
total_size += file_size
|
||||
|
||||
if total_size > MAX_TOTAL_SIZE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Total upload size ({total_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
|
||||
)
|
||||
|
||||
# ===== Read all files concurrently to avoid blocking the event loop =====
|
||||
async def _read_and_save(file: UploadFile) -> tuple[str, str, int]:
|
||||
"""Read upload content and write to temp file off the event loop."""
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
filename = file.filename or "unknown"
|
||||
|
||||
if file_size > MAX_FILE_SIZE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File '{filename}' ({file_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
|
||||
)
|
||||
|
||||
def _write_temp() -> str:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=os.path.splitext(filename)[1]
|
||||
) as tmp:
|
||||
tmp.write(content)
|
||||
return tmp.name
|
||||
|
||||
temp_path = await asyncio.to_thread(_write_temp)
|
||||
return temp_path, filename, file_size
|
||||
|
||||
saved_files = await asyncio.gather(*(_read_and_save(f) for f in files))
|
||||
|
||||
actual_total_size = sum(size for _, _, size in saved_files)
|
||||
if actual_total_size > MAX_TOTAL_SIZE_BYTES:
|
||||
for temp_path, _, _ in saved_files:
|
||||
os.unlink(temp_path)
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
|
||||
)
|
||||
|
||||
# ===== PHASE 1: Create pending documents for all files =====
|
||||
created_documents: list[Document] = []
|
||||
files_to_process: list[
|
||||
tuple[Document, str, str]
|
||||
] = [] # (document, temp_path, filename)
|
||||
files_to_process: list[tuple[Document, str, str]] = []
|
||||
skipped_duplicates = 0
|
||||
duplicate_document_ids: list[int] = []
|
||||
|
||||
# ===== PHASE 1: Create pending documents for all files =====
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
for file in files:
|
||||
for temp_path, filename, file_size in saved_files:
|
||||
try:
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
# Save file to temp location
|
||||
with tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=os.path.splitext(file.filename or "")[1]
|
||||
) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
|
||||
content = await file.read()
|
||||
with open(temp_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
file_size = len(content)
|
||||
|
||||
# Generate unique identifier for deduplication check
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.FILE, file.filename or "unknown", search_space_id
|
||||
DocumentType.FILE, filename, search_space_id
|
||||
)
|
||||
|
||||
# Check if document already exists (by unique identifier)
|
||||
existing = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
)
|
||||
if existing:
|
||||
if DocumentStatus.is_state(existing.status, DocumentStatus.READY):
|
||||
# True duplicate — content already indexed, skip
|
||||
os.unlink(temp_path)
|
||||
skipped_duplicates += 1
|
||||
duplicate_document_ids.append(existing.id)
|
||||
continue
|
||||
|
||||
# Existing document is stuck (failed/pending/processing)
|
||||
# Reset it to pending and re-dispatch for processing
|
||||
existing.status = DocumentStatus.pending()
|
||||
existing.content = "Processing..."
|
||||
existing.document_metadata = {
|
||||
|
|
@ -202,61 +247,53 @@ async def create_documents_file_upload(
|
|||
}
|
||||
existing.updated_at = get_current_timestamp()
|
||||
created_documents.append(existing)
|
||||
files_to_process.append(
|
||||
(existing, temp_path, file.filename or "unknown")
|
||||
)
|
||||
files_to_process.append((existing, temp_path, filename))
|
||||
continue
|
||||
|
||||
# Create pending document (visible immediately in UI via ElectricSQL)
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=file.filename or "Uploaded File",
|
||||
title=filename if filename != "unknown" else "Uploaded File",
|
||||
document_type=DocumentType.FILE,
|
||||
document_metadata={
|
||||
"FILE_NAME": file.filename,
|
||||
"FILE_NAME": filename,
|
||||
"file_size": file_size,
|
||||
"upload_time": datetime.now().isoformat(),
|
||||
},
|
||||
content="Processing...", # Placeholder until processed
|
||||
content_hash=unique_identifier_hash, # Temporary, updated when ready
|
||||
content="Processing...",
|
||||
content_hash=unique_identifier_hash,
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
embedding=None,
|
||||
status=DocumentStatus.pending(), # Shows "pending" in UI
|
||||
status=DocumentStatus.pending(),
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=str(user.id),
|
||||
)
|
||||
session.add(document)
|
||||
created_documents.append(document)
|
||||
files_to_process.append(
|
||||
(document, temp_path, file.filename or "unknown")
|
||||
)
|
||||
files_to_process.append((document, temp_path, filename))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
os.unlink(temp_path)
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Failed to process file {file.filename}: {e!s}",
|
||||
detail=f"Failed to process file {filename}: {e!s}",
|
||||
) from e
|
||||
|
||||
# Commit all pending documents - they appear in UI immediately via ElectricSQL
|
||||
if created_documents:
|
||||
await session.commit()
|
||||
# Refresh to get generated IDs
|
||||
for doc in created_documents:
|
||||
await session.refresh(doc)
|
||||
|
||||
# ===== PHASE 2: Dispatch Celery tasks for each file =====
|
||||
# Each task will update document status: pending → processing → ready/failed
|
||||
from app.tasks.celery_tasks.document_tasks import (
|
||||
process_file_upload_with_document_task,
|
||||
)
|
||||
|
||||
# ===== PHASE 2: Dispatch tasks for each file =====
|
||||
for document, temp_path, filename in files_to_process:
|
||||
process_file_upload_with_document_task.delay(
|
||||
await dispatcher.dispatch_file_processing(
|
||||
document_id=document.id,
|
||||
temp_path=temp_path,
|
||||
filename=filename,
|
||||
search_space_id=search_space_id,
|
||||
user_id=str(user.id),
|
||||
should_summarize=should_summarize,
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
|
|||
- POST /threads/{thread_id}/messages - Append message
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
|
@ -30,6 +32,7 @@ from app.db import (
|
|||
SearchSpace,
|
||||
User,
|
||||
get_async_session,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.schemas.new_chat import (
|
||||
NewChatMessageAppend,
|
||||
|
|
@ -52,9 +55,50 @@ from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
|||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _try_delete_sandbox(thread_id: int) -> None:
|
||||
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
|
||||
from app.agents.new_chat.sandbox import (
|
||||
delete_local_sandbox_files,
|
||||
delete_sandbox,
|
||||
is_sandbox_enabled,
|
||||
)
|
||||
|
||||
if not is_sandbox_enabled():
|
||||
return
|
||||
|
||||
async def _bg() -> None:
|
||||
try:
|
||||
await delete_sandbox(thread_id)
|
||||
except Exception:
|
||||
_logger.warning(
|
||||
"Background sandbox delete failed for thread %s",
|
||||
thread_id,
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
delete_local_sandbox_files(thread_id)
|
||||
except Exception:
|
||||
_logger.warning(
|
||||
"Local sandbox file cleanup failed for thread %s",
|
||||
thread_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = loop.create_task(_bg())
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
async def check_thread_access(
|
||||
session: AsyncSession,
|
||||
thread: NewChatThread,
|
||||
|
|
@ -648,6 +692,9 @@ async def delete_thread(
|
|||
|
||||
await session.delete(db_thread)
|
||||
await session.commit()
|
||||
|
||||
_try_delete_sandbox(thread_id)
|
||||
|
||||
return {"message": "Thread deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -1046,13 +1093,18 @@ async def handle_new_chat(
|
|||
# on searchspaces/documents for the entire duration of the stream.
|
||||
# expire_on_commit=False keeps loaded ORM attrs usable.
|
||||
await session.commit()
|
||||
# Close the dependency session now so its connection returns to
|
||||
# the pool before streaming begins. Without this, Starlette's
|
||||
# BaseHTTPMiddleware cancels the scope on client disconnect and
|
||||
# the dependency generator's __aexit__ never runs, orphaning the
|
||||
# connection (the "Exception terminating connection" errors).
|
||||
await session.close()
|
||||
|
||||
return StreamingResponse(
|
||||
stream_new_chat(
|
||||
user_query=request.user_query,
|
||||
search_space_id=request.search_space_id,
|
||||
chat_id=request.chat_id,
|
||||
session=session,
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
|
|
@ -1277,6 +1329,7 @@ async def regenerate_response(
|
|||
# on searchspaces/documents for the entire duration of the stream.
|
||||
# expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
|
||||
await session.commit()
|
||||
await session.close()
|
||||
|
||||
# Create a wrapper generator that deletes messages only AFTER streaming succeeds
|
||||
# This prevents data loss if streaming fails (network error, LLM error, etc.)
|
||||
|
|
@ -1287,7 +1340,6 @@ async def regenerate_response(
|
|||
user_query=user_query_to_use,
|
||||
search_space_id=request.search_space_id,
|
||||
chat_id=thread_id,
|
||||
session=session,
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
|
|
@ -1298,29 +1350,35 @@ async def regenerate_response(
|
|||
current_user_display_name=user.display_name or "A team member",
|
||||
):
|
||||
yield chunk
|
||||
# If we get here, streaming completed successfully
|
||||
streaming_completed = True
|
||||
finally:
|
||||
# Only delete old messages if streaming completed successfully
|
||||
# This ensures we don't lose data on streaming failures
|
||||
if streaming_completed and messages_to_delete:
|
||||
# Only delete old messages if streaming completed successfully.
|
||||
# Uses a fresh session since stream_new_chat manages its own.
|
||||
if streaming_completed and message_ids_to_delete:
|
||||
try:
|
||||
for msg in messages_to_delete:
|
||||
await session.delete(msg)
|
||||
await session.commit()
|
||||
async with shielded_async_session() as cleanup_session:
|
||||
for msg_id in message_ids_to_delete:
|
||||
_res = await cleanup_session.execute(
|
||||
select(NewChatMessage).filter(
|
||||
NewChatMessage.id == msg_id
|
||||
)
|
||||
)
|
||||
_msg = _res.scalars().first()
|
||||
if _msg:
|
||||
await cleanup_session.delete(_msg)
|
||||
await cleanup_session.commit()
|
||||
|
||||
# Delete any public snapshots that contain the modified messages
|
||||
from app.services.public_chat_service import (
|
||||
delete_affected_snapshots,
|
||||
)
|
||||
from app.services.public_chat_service import (
|
||||
delete_affected_snapshots,
|
||||
)
|
||||
|
||||
await delete_affected_snapshots(
|
||||
session, thread_id, message_ids_to_delete
|
||||
)
|
||||
await delete_affected_snapshots(
|
||||
cleanup_session, thread_id, message_ids_to_delete
|
||||
)
|
||||
except Exception as cleanup_error:
|
||||
# Log but don't fail - the new messages are already streamed
|
||||
print(
|
||||
f"[regenerate] Warning: Failed to delete old messages: {cleanup_error}"
|
||||
_logger.warning(
|
||||
"[regenerate] Failed to delete old messages: %s",
|
||||
cleanup_error,
|
||||
)
|
||||
|
||||
# Return streaming response with checkpoint_id for rewinding
|
||||
|
|
@ -1394,13 +1452,13 @@ async def resume_chat(
|
|||
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
||||
# on searchspaces/documents for the entire duration of the stream.
|
||||
await session.commit()
|
||||
await session.close()
|
||||
|
||||
return StreamingResponse(
|
||||
stream_resume_chat(
|
||||
chat_id=thread_id,
|
||||
search_space_id=request.search_space_id,
|
||||
decisions=decisions,
|
||||
session=session,
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
thread_visibility=thread.visibility,
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
import pypandoc
|
||||
import typst
|
||||
|
|
@ -46,7 +46,7 @@ router = APIRouter()
|
|||
MAX_REPORT_LIST_LIMIT = 500
|
||||
|
||||
|
||||
class ExportFormat(str, Enum):
|
||||
class ExportFormat(StrEnum):
|
||||
PDF = "pdf"
|
||||
DOCX = "docx"
|
||||
|
||||
|
|
|
|||
105
surfsense_backend/app/routes/sandbox_routes.py
Normal file
105
surfsense_backend/app/routes/sandbox_routes.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
"""Routes for downloading files from Daytona sandbox environments."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import NewChatThread, Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
MIME_TYPES: dict[str, str] = {
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".gif": "image/gif",
|
||||
".webp": "image/webp",
|
||||
".svg": "image/svg+xml",
|
||||
".pdf": "application/pdf",
|
||||
".csv": "text/csv",
|
||||
".json": "application/json",
|
||||
".txt": "text/plain",
|
||||
".html": "text/html",
|
||||
".md": "text/markdown",
|
||||
".py": "text/x-python",
|
||||
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
".zip": "application/zip",
|
||||
}
|
||||
|
||||
|
||||
def _guess_media_type(filename: str) -> str:
|
||||
ext = ("." + filename.rsplit(".", 1)[-1].lower()) if "." in filename else ""
|
||||
return MIME_TYPES.get(ext, "application/octet-stream")
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/sandbox/download")
|
||||
async def download_sandbox_file(
|
||||
thread_id: int,
|
||||
path: str = Query(..., description="Absolute path of the file inside the sandbox"),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Download a file from the Daytona sandbox associated with a chat thread."""
|
||||
|
||||
from app.agents.new_chat.sandbox import get_or_create_sandbox, is_sandbox_enabled
|
||||
|
||||
if not is_sandbox_enabled():
|
||||
raise HTTPException(status_code=404, detail="Sandbox is not enabled")
|
||||
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to access files in this thread",
|
||||
)
|
||||
|
||||
from app.agents.new_chat.sandbox import get_local_sandbox_file
|
||||
|
||||
# Prefer locally-persisted copy (sandbox may already be deleted)
|
||||
local_content = get_local_sandbox_file(thread_id, path)
|
||||
if local_content is not None:
|
||||
filename = path.rsplit("/", 1)[-1] if "/" in path else path
|
||||
media_type = _guess_media_type(filename)
|
||||
return Response(
|
||||
content=local_content,
|
||||
media_type=media_type,
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
# Fall back to live sandbox download
|
||||
try:
|
||||
sandbox = await get_or_create_sandbox(thread_id)
|
||||
raw_sandbox = sandbox._sandbox
|
||||
content: bytes = await asyncio.to_thread(raw_sandbox.fs.download_file, path)
|
||||
except Exception as exc:
|
||||
logger.warning("Sandbox file download failed for %s: %s", path, exc)
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Could not download file: {exc}"
|
||||
) from exc
|
||||
|
||||
filename = path.rsplit("/", 1)[-1] if "/" in path else path
|
||||
media_type = _guess_media_type(filename)
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=media_type,
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
|
@ -2735,7 +2735,10 @@ async def create_mcp_connector(
|
|||
f"for user {user.id} in search space {search_space_id}"
|
||||
)
|
||||
|
||||
# Convert to read schema
|
||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
||||
|
||||
invalidate_mcp_tools_cache(search_space_id)
|
||||
|
||||
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
|
||||
return MCPConnectorRead.from_connector(connector_read)
|
||||
|
||||
|
|
@ -2910,6 +2913,10 @@ async def update_mcp_connector(
|
|||
|
||||
logger.info(f"Updated MCP connector {connector_id}")
|
||||
|
||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
||||
|
||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
||||
|
||||
connector_read = SearchSourceConnectorRead.model_validate(connector)
|
||||
return MCPConnectorRead.from_connector(connector_read)
|
||||
|
||||
|
|
@ -2960,9 +2967,14 @@ async def delete_mcp_connector(
|
|||
"You don't have permission to delete this connector",
|
||||
)
|
||||
|
||||
search_space_id = connector.search_space_id
|
||||
await session.delete(connector)
|
||||
await session.commit()
|
||||
|
||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
||||
|
||||
invalidate_mcp_tools_cache(search_space_id)
|
||||
|
||||
logger.info(f"Deleted MCP connector {connector_id}")
|
||||
|
||||
except HTTPException:
|
||||
|
|
|
|||
|
|
@ -87,6 +87,18 @@ class CommentListResponse(BaseModel):
|
|||
total_count: int
|
||||
|
||||
|
||||
class CommentBatchRequest(BaseModel):
|
||||
"""Request for batch-fetching comments for multiple messages."""
|
||||
|
||||
message_ids: list[int] = Field(..., min_length=1, max_length=200)
|
||||
|
||||
|
||||
class CommentBatchResponse(BaseModel):
|
||||
"""Batch response keyed by message_id."""
|
||||
|
||||
comments_by_message: dict[int, CommentListResponse]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mention Schemas
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
"""Podcast schemas for API responses."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PodcastStatusEnum(str, Enum):
|
||||
class PodcastStatusEnum(StrEnum):
|
||||
PENDING = "pending"
|
||||
GENERATING = "generating"
|
||||
READY = "ready"
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ class SearchSourceConnectorBase(BaseModel):
|
|||
is_indexable: bool
|
||||
last_indexed_at: datetime | None = None
|
||||
config: dict[str, Any]
|
||||
enable_summary: bool = False
|
||||
periodic_indexing_enabled: bool = False
|
||||
indexing_frequency_minutes: int | None = None
|
||||
next_scheduled_at: datetime | None = None
|
||||
|
|
@ -65,6 +66,7 @@ class SearchSourceConnectorUpdate(BaseModel):
|
|||
is_indexable: bool | None = None
|
||||
last_indexed_at: datetime | None = None
|
||||
config: dict[str, Any] | None = None
|
||||
enable_summary: bool | None = None
|
||||
periodic_indexing_enabled: bool | None = None
|
||||
indexing_frequency_minutes: int | None = None
|
||||
next_scheduled_at: datetime | None = None
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from app.db import (
|
|||
)
|
||||
from app.schemas.chat_comments import (
|
||||
AuthorResponse,
|
||||
CommentBatchResponse,
|
||||
CommentListResponse,
|
||||
CommentReplyResponse,
|
||||
CommentResponse,
|
||||
|
|
@ -264,6 +265,146 @@ async def get_comments_for_message(
|
|||
)
|
||||
|
||||
|
||||
async def get_comments_for_messages_batch(
|
||||
session: AsyncSession,
|
||||
message_ids: list[int],
|
||||
user: User,
|
||||
) -> CommentBatchResponse:
|
||||
"""
|
||||
Batch-fetch comments for multiple messages in a single DB round-trip.
|
||||
|
||||
Validates that all messages exist and belong to search spaces the user
|
||||
can read comments in, then loads all comments with eager-loaded authors
|
||||
and replies.
|
||||
"""
|
||||
if not message_ids:
|
||||
return CommentBatchResponse(comments_by_message={})
|
||||
|
||||
unique_ids = list(set(message_ids))
|
||||
|
||||
result = await session.execute(
|
||||
select(NewChatMessage)
|
||||
.options(selectinload(NewChatMessage.thread))
|
||||
.filter(NewChatMessage.id.in_(unique_ids))
|
||||
)
|
||||
messages = result.scalars().all()
|
||||
msg_map = {m.id: m for m in messages}
|
||||
|
||||
search_space_ids = {m.thread.search_space_id for m in messages}
|
||||
permissions_cache: dict[int, set] = {}
|
||||
for ss_id in search_space_ids:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
ss_id,
|
||||
Permission.COMMENTS_READ.value,
|
||||
"You don't have permission to read comments in this search space",
|
||||
)
|
||||
permissions_cache[ss_id] = await get_user_permissions(session, user.id, ss_id)
|
||||
|
||||
result = await session.execute(
|
||||
select(ChatComment)
|
||||
.options(
|
||||
selectinload(ChatComment.author),
|
||||
selectinload(ChatComment.replies).selectinload(ChatComment.author),
|
||||
)
|
||||
.filter(
|
||||
ChatComment.message_id.in_(unique_ids),
|
||||
ChatComment.parent_id.is_(None),
|
||||
)
|
||||
.order_by(ChatComment.created_at)
|
||||
)
|
||||
top_level_comments = result.scalars().all()
|
||||
|
||||
all_mentioned_uuids: set[UUID] = set()
|
||||
for comment in top_level_comments:
|
||||
all_mentioned_uuids.update(parse_mentions(comment.content))
|
||||
for reply in comment.replies:
|
||||
all_mentioned_uuids.update(parse_mentions(reply.content))
|
||||
|
||||
user_names = await get_user_names_for_mentions(session, all_mentioned_uuids)
|
||||
|
||||
comments_by_msg: dict[int, list[ChatComment]] = {mid: [] for mid in unique_ids}
|
||||
for comment in top_level_comments:
|
||||
comments_by_msg.setdefault(comment.message_id, []).append(comment)
|
||||
|
||||
comments_by_message: dict[int, CommentListResponse] = {}
|
||||
for mid in unique_ids:
|
||||
msg = msg_map.get(mid)
|
||||
if msg is None:
|
||||
comments_by_message[mid] = CommentListResponse(comments=[], total_count=0)
|
||||
continue
|
||||
|
||||
ss_id = msg.thread.search_space_id
|
||||
user_perms = permissions_cache.get(ss_id, set())
|
||||
can_delete_any = has_permission(user_perms, Permission.COMMENTS_DELETE.value)
|
||||
|
||||
comment_responses = []
|
||||
for comment in comments_by_msg.get(mid, []):
|
||||
author = None
|
||||
if comment.author:
|
||||
author = AuthorResponse(
|
||||
id=comment.author.id,
|
||||
display_name=comment.author.display_name,
|
||||
avatar_url=comment.author.avatar_url,
|
||||
email=comment.author.email,
|
||||
)
|
||||
|
||||
replies = []
|
||||
for reply in sorted(comment.replies, key=lambda r: r.created_at):
|
||||
reply_author = None
|
||||
if reply.author:
|
||||
reply_author = AuthorResponse(
|
||||
id=reply.author.id,
|
||||
display_name=reply.author.display_name,
|
||||
avatar_url=reply.author.avatar_url,
|
||||
email=reply.author.email,
|
||||
)
|
||||
is_reply_author = (
|
||||
reply.author_id == user.id if reply.author_id else False
|
||||
)
|
||||
replies.append(
|
||||
CommentReplyResponse(
|
||||
id=reply.id,
|
||||
content=reply.content,
|
||||
content_rendered=render_mentions(reply.content, user_names),
|
||||
author=reply_author,
|
||||
created_at=reply.created_at,
|
||||
updated_at=reply.updated_at,
|
||||
is_edited=reply.updated_at > reply.created_at,
|
||||
can_edit=is_reply_author,
|
||||
can_delete=is_reply_author or can_delete_any,
|
||||
)
|
||||
)
|
||||
|
||||
is_comment_author = (
|
||||
comment.author_id == user.id if comment.author_id else False
|
||||
)
|
||||
comment_responses.append(
|
||||
CommentResponse(
|
||||
id=comment.id,
|
||||
message_id=comment.message_id,
|
||||
content=comment.content,
|
||||
content_rendered=render_mentions(comment.content, user_names),
|
||||
author=author,
|
||||
created_at=comment.created_at,
|
||||
updated_at=comment.updated_at,
|
||||
is_edited=comment.updated_at > comment.created_at,
|
||||
can_edit=is_comment_author,
|
||||
can_delete=is_comment_author or can_delete_any,
|
||||
reply_count=len(replies),
|
||||
replies=replies,
|
||||
)
|
||||
)
|
||||
|
||||
comments_by_message[mid] = CommentListResponse(
|
||||
comments=comment_responses,
|
||||
total_count=len(comment_responses),
|
||||
)
|
||||
|
||||
return CommentBatchResponse(comments_by_message=comments_by_message)
|
||||
|
||||
|
||||
async def create_comment(
|
||||
session: AsyncSession,
|
||||
message_id: int,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
|
@ -15,9 +16,11 @@ from app.db import (
|
|||
Document,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
async_session_maker,
|
||||
)
|
||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.retriever.documents_hybrid_search import DocumentHybridSearchRetriever
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
|
||||
class ConnectorService:
|
||||
|
|
@ -221,6 +224,7 @@ class ConnectorService:
|
|||
top_k: int = 20,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
query_embedding: list[float] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Perform combined search using both chunk-based and document-based hybrid search,
|
||||
|
|
@ -246,34 +250,60 @@ class ConnectorService:
|
|||
Returns:
|
||||
List of combined and deduplicated document results
|
||||
"""
|
||||
from app.config import config
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# RRF constant
|
||||
k = 60
|
||||
|
||||
# Get more results from each retriever for better fusion
|
||||
retriever_top_k = top_k * 2
|
||||
|
||||
# IMPORTANT:
|
||||
# These retrievers share the same AsyncSession. AsyncSession does not permit
|
||||
# concurrent awaits that require DB IO on the same session/connection.
|
||||
# Running these in parallel can raise:
|
||||
# "This session is provisioning a new connection; concurrent operations are not permitted"
|
||||
#
|
||||
# So we run them sequentially.
|
||||
chunk_results = await self.chunk_retriever.hybrid_search(
|
||||
query_text=query_text,
|
||||
top_k=retriever_top_k,
|
||||
search_space_id=search_space_id,
|
||||
document_type=document_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
# Reuse caller-provided embedding or compute once for both retrievers.
|
||||
if query_embedding is None:
|
||||
t_embed = time.perf_counter()
|
||||
query_embedding = config.embedding_model_instance.embed(query_text)
|
||||
perf.info(
|
||||
"[connector_svc] _combined_rrf embedding in %.3fs type=%s",
|
||||
time.perf_counter() - t_embed,
|
||||
document_type,
|
||||
)
|
||||
|
||||
search_kwargs = {
|
||||
"query_text": query_text,
|
||||
"top_k": retriever_top_k,
|
||||
"search_space_id": search_space_id,
|
||||
"document_type": document_type,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"query_embedding": query_embedding,
|
||||
}
|
||||
|
||||
# Run chunk and document retrievers in parallel using separate DB sessions
|
||||
# so they don't contend on a shared AsyncSession connection.
|
||||
async def _run_chunk_search() -> list[dict[str, Any]]:
|
||||
async with async_session_maker() as session:
|
||||
retriever = ChucksHybridSearchRetriever(session)
|
||||
return await retriever.hybrid_search(**search_kwargs)
|
||||
|
||||
async def _run_doc_search() -> list[dict[str, Any]]:
|
||||
async with async_session_maker() as session:
|
||||
retriever = DocumentHybridSearchRetriever(session)
|
||||
return await retriever.hybrid_search(**search_kwargs)
|
||||
|
||||
t_parallel = time.perf_counter()
|
||||
chunk_results, doc_results = await asyncio.gather(
|
||||
_run_chunk_search(), _run_doc_search()
|
||||
)
|
||||
doc_results = await self.document_retriever.hybrid_search(
|
||||
query_text=query_text,
|
||||
top_k=retriever_top_k,
|
||||
search_space_id=search_space_id,
|
||||
document_type=document_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
perf.info(
|
||||
"[connector_svc] _combined_rrf parallel retrievers in %.3fs "
|
||||
"chunk_results=%d doc_results=%d type=%s",
|
||||
time.perf_counter() - t_parallel,
|
||||
len(chunk_results),
|
||||
len(doc_results),
|
||||
document_type,
|
||||
)
|
||||
|
||||
# Helper to extract document_id from our doc-grouped result
|
||||
|
|
@ -335,6 +365,13 @@ class ConnectorService:
|
|||
result["chunks"] = doc_data[did]["chunks"]
|
||||
combined_results.append(result)
|
||||
|
||||
perf.info(
|
||||
"[connector_svc] _combined_rrf_search TOTAL in %.3fs results=%d type=%s space=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(combined_results),
|
||||
document_type,
|
||||
search_space_id,
|
||||
)
|
||||
return combined_results
|
||||
|
||||
def _get_doc_url(self, metadata: dict[str, Any]) -> str:
|
||||
|
|
@ -1303,10 +1340,9 @@ class ConnectorService:
|
|||
|
||||
sources_list = self._build_chunk_sources_from_documents(
|
||||
github_docs,
|
||||
description_fn=lambda chunk, _doc_info, metadata: metadata.get(
|
||||
"description"
|
||||
)
|
||||
or chunk.get("content", ""),
|
||||
description_fn=lambda chunk, _doc_info, metadata: (
|
||||
metadata.get("description") or chunk.get("content", "")
|
||||
),
|
||||
url_fn=lambda _doc_info, metadata: metadata.get("url", "") or "",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@ from datetime import datetime
|
|||
from sqlalchemy import delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.linear_connector import LinearConnector
|
||||
from app.db import Chunk, Document
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
)
|
||||
|
|
@ -80,7 +80,7 @@ class LinearKBSyncService:
|
|||
state = formatted_issue.get("state", "Unknown")
|
||||
priority = issue_raw.get("priorityLabel", "Unknown")
|
||||
comment_count = len(formatted_issue.get("comments", []))
|
||||
description = formatted_issue.get("description", "")
|
||||
formatted_issue.get("description", "")
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
|
|
@ -100,18 +100,10 @@ class LinearKBSyncService:
|
|||
issue_content, user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
if description and len(description) > 1000:
|
||||
description = description[:997] + "..."
|
||||
summary_content = (
|
||||
f"Linear Issue {issue_identifier}: {issue_title}\n\n"
|
||||
f"Status: {state}\n\n"
|
||||
)
|
||||
if description:
|
||||
summary_content += f"Description: {description}\n\n"
|
||||
summary_content += f"Comments: {comment_count}"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
f"Linear Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
await self.db_session.execute(
|
||||
delete(Chunk).where(Chunk.document_id == document.id)
|
||||
|
|
|
|||
|
|
@ -12,16 +12,42 @@ synchronous ChatLiteLLM-like interface and async methods.
|
|||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.exceptions import ContextOverflowError
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from litellm import Router
|
||||
from litellm.exceptions import (
|
||||
BadRequestError as LiteLLMBadRequestError,
|
||||
ContextWindowExceededError,
|
||||
)
|
||||
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
litellm.json_logs = False
|
||||
litellm.store_audit_logs = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CONTEXT_OVERFLOW_PATTERNS = re.compile(
|
||||
r"(input tokens exceed|context.{0,20}(length|window|limit)|"
|
||||
r"maximum context length|token.{0,20}(limit|exceed)|"
|
||||
r"too many tokens|reduce the length)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool:
|
||||
"""Check if a BadRequestError is actually a context window overflow."""
|
||||
return bool(_CONTEXT_OVERFLOW_PATTERNS.search(str(exc)))
|
||||
|
||||
|
||||
# Special ID for Auto mode - uses router for load balancing
|
||||
AUTO_MODE_ID = 0
|
||||
|
||||
|
|
@ -133,26 +159,95 @@ class LLMRouterService:
|
|||
# Merge with provided settings
|
||||
final_settings = {**default_settings, **instance._router_settings}
|
||||
|
||||
# Build a "auto-large" fallback group with deployments whose context
|
||||
# window exceeds the smallest deployment. This lets the router
|
||||
# automatically fall back to a bigger-context model when gpt-4o (128K)
|
||||
# hits ContextWindowExceededError.
|
||||
full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)
|
||||
|
||||
try:
|
||||
instance._router = Router(
|
||||
model_list=model_list,
|
||||
routing_strategy=final_settings.get(
|
||||
router_kwargs: dict[str, Any] = {
|
||||
"model_list": full_model_list,
|
||||
"routing_strategy": final_settings.get(
|
||||
"routing_strategy", "usage-based-routing"
|
||||
),
|
||||
num_retries=final_settings.get("num_retries", 3),
|
||||
allowed_fails=final_settings.get("allowed_fails", 3),
|
||||
cooldown_time=final_settings.get("cooldown_time", 60),
|
||||
set_verbose=False, # Disable verbose logging in production
|
||||
)
|
||||
"num_retries": final_settings.get("num_retries", 3),
|
||||
"allowed_fails": final_settings.get("allowed_fails", 3),
|
||||
"cooldown_time": final_settings.get("cooldown_time", 60),
|
||||
"set_verbose": False,
|
||||
}
|
||||
if ctx_fallbacks:
|
||||
router_kwargs["context_window_fallbacks"] = ctx_fallbacks
|
||||
|
||||
instance._router = Router(**router_kwargs)
|
||||
instance._initialized = True
|
||||
logger.info(
|
||||
f"LLM Router initialized with {len(model_list)} deployments, "
|
||||
f"strategy: {final_settings.get('routing_strategy')}"
|
||||
"LLM Router initialized with %d deployments, "
|
||||
"strategy: %s, context_window_fallbacks: %s",
|
||||
len(model_list),
|
||||
final_settings.get("routing_strategy"),
|
||||
ctx_fallbacks or "none",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize LLM Router: {e}")
|
||||
instance._router = None
|
||||
|
||||
@classmethod
|
||||
def _build_context_fallback_groups(
|
||||
cls, model_list: list[dict]
|
||||
) -> tuple[list[dict], list[dict[str, list[str]]] | None]:
|
||||
"""Create an ``auto-large`` model group for context-window fallbacks.
|
||||
|
||||
Uses ``litellm.get_model_info`` to discover the context window of each
|
||||
deployment. Deployments whose ``max_input_tokens`` exceeds the smallest
|
||||
window are duplicated into an ``auto-large`` group. The returned
|
||||
fallback config tells the Router: on ``ContextWindowExceededError`` for
|
||||
``auto``, retry with ``auto-large``.
|
||||
|
||||
Returns:
|
||||
(full_model_list, context_window_fallbacks) — ``full_model_list``
|
||||
contains the original entries plus any ``auto-large`` duplicates.
|
||||
``context_window_fallbacks`` is ``None`` when every deployment has
|
||||
the same context size (no useful fallback).
|
||||
"""
|
||||
from litellm import get_model_info
|
||||
|
||||
ctx_map: dict[str, int] = {}
|
||||
for dep in model_list:
|
||||
params = dep.get("litellm_params", {})
|
||||
base_model = params.get("base_model") or params.get("model", "")
|
||||
try:
|
||||
info = get_model_info(base_model)
|
||||
ctx = info.get("max_input_tokens")
|
||||
if isinstance(ctx, int) and ctx > 0:
|
||||
ctx_map[base_model] = ctx
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not ctx_map:
|
||||
return model_list, None
|
||||
|
||||
min_ctx = min(ctx_map.values())
|
||||
|
||||
large_deployments: list[dict] = []
|
||||
for dep in model_list:
|
||||
params = dep.get("litellm_params", {})
|
||||
base_model = params.get("base_model") or params.get("model", "")
|
||||
if ctx_map.get(base_model, 0) > min_ctx:
|
||||
dup = {**dep, "model_name": "auto-large"}
|
||||
large_deployments.append(dup)
|
||||
|
||||
if not large_deployments:
|
||||
return model_list, None
|
||||
|
||||
logger.info(
|
||||
"Context-window fallback: %d large-context deployments "
|
||||
"(min_ctx=%d) added to 'auto-large' group",
|
||||
len(large_deployments),
|
||||
min_ctx,
|
||||
)
|
||||
return model_list + large_deployments, [{"auto": ["auto-large"]}]
|
||||
|
||||
@classmethod
|
||||
def _config_to_deployment(cls, config: dict) -> dict | None:
|
||||
"""
|
||||
|
|
@ -228,12 +323,62 @@ class LLMRouterService:
|
|||
return len(instance._model_list)
|
||||
|
||||
|
||||
_cached_context_profile: dict | None = None
|
||||
_cached_context_profile_computed: bool = False
|
||||
|
||||
# Cached singleton instances keyed by (streaming,) to avoid re-creating on every call
|
||||
_router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {}
|
||||
|
||||
|
||||
def _get_cached_context_profile(router: Router) -> dict | None:
|
||||
"""Compute and cache the min context profile across all router deployments.
|
||||
|
||||
Called once on first ChatLiteLLMRouter creation; subsequent calls return
|
||||
the cached value. This avoids calling litellm.get_model_info() for every
|
||||
deployment on every request.
|
||||
"""
|
||||
global _cached_context_profile, _cached_context_profile_computed
|
||||
if _cached_context_profile_computed:
|
||||
return _cached_context_profile
|
||||
|
||||
from litellm import get_model_info
|
||||
|
||||
min_ctx: int | None = None
|
||||
for deployment in router.model_list:
|
||||
params = deployment.get("litellm_params", {})
|
||||
base_model = params.get("base_model") or params.get("model", "")
|
||||
try:
|
||||
info = get_model_info(base_model)
|
||||
ctx = info.get("max_input_tokens")
|
||||
if isinstance(ctx, int) and ctx > 0 and (min_ctx is None or ctx < min_ctx):
|
||||
min_ctx = ctx
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if min_ctx is not None:
|
||||
logger.info("ChatLiteLLMRouter profile: max_input_tokens=%d", min_ctx)
|
||||
_cached_context_profile = {"max_input_tokens": min_ctx}
|
||||
else:
|
||||
_cached_context_profile = None
|
||||
|
||||
_cached_context_profile_computed = True
|
||||
return _cached_context_profile
|
||||
|
||||
|
||||
class ChatLiteLLMRouter(BaseChatModel):
|
||||
"""
|
||||
A LangChain-compatible chat model that uses LiteLLM Router for load balancing.
|
||||
|
||||
This wraps the LiteLLM Router to provide the same interface as ChatLiteLLM,
|
||||
making it a drop-in replacement for auto-mode routing.
|
||||
|
||||
Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
|
||||
window across all router deployments so that deepagents
|
||||
SummarizationMiddleware can use fraction-based triggers.
|
||||
|
||||
**Singleton-ish**: Use ``get_auto_mode_llm()`` or call ``ChatLiteLLMRouter()``
|
||||
directly — instances without bound tools are cached per streaming flag to
|
||||
avoid per-request re-initialization overhead and memory growth.
|
||||
"""
|
||||
|
||||
# Use model_config for Pydantic v2 compatibility
|
||||
|
|
@ -255,17 +400,8 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
tool_choice: str | dict | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the ChatLiteLLMRouter.
|
||||
|
||||
Args:
|
||||
router: LiteLLM Router instance. If None, uses the global singleton.
|
||||
bound_tools: Pre-bound tools for tool calling
|
||||
tool_choice: Tool choice configuration
|
||||
"""
|
||||
try:
|
||||
super().__init__(**kwargs)
|
||||
# Store router and tools as private attributes
|
||||
resolved_router = router or LLMRouterService.get_router()
|
||||
object.__setattr__(self, "_router", resolved_router)
|
||||
object.__setattr__(self, "_bound_tools", bound_tools)
|
||||
|
|
@ -274,8 +410,16 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
raise ValueError(
|
||||
"LLM Router not initialized. Call LLMRouterService.initialize() first."
|
||||
)
|
||||
logger.info(
|
||||
f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
|
||||
|
||||
computed_profile = _get_cached_context_profile(self._router)
|
||||
if computed_profile is not None:
|
||||
object.__setattr__(self, "profile", computed_profile)
|
||||
|
||||
logger.debug(
|
||||
"ChatLiteLLMRouter ready (models=%d, streaming=%s, has_tools=%s)",
|
||||
LLMRouterService.get_model_count(),
|
||||
self.streaming,
|
||||
bound_tools is not None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
|
||||
|
|
@ -349,6 +493,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
if not self._router:
|
||||
raise ValueError("Router not initialized")
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
msg_count = len(messages)
|
||||
|
||||
# Convert LangChain messages to OpenAI format
|
||||
formatted_messages = self._convert_messages(messages)
|
||||
|
||||
|
|
@ -359,12 +507,36 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
if self._tool_choice is not None:
|
||||
call_kwargs["tool_choice"] = self._tool_choice
|
||||
|
||||
# Call router completion
|
||||
response = self._router.completion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
**call_kwargs,
|
||||
try:
|
||||
response = self._router.completion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
**call_kwargs,
|
||||
)
|
||||
except ContextWindowExceededError as e:
|
||||
perf.warning(
|
||||
"[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
except LiteLLMBadRequestError as e:
|
||||
if _is_context_overflow_error(e):
|
||||
perf.warning(
|
||||
"[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
raise
|
||||
|
||||
elapsed = time.perf_counter() - t0
|
||||
perf.info(
|
||||
"[llm_router] _generate completed msgs=%d tools=%d in %.3fs",
|
||||
msg_count,
|
||||
len(self._bound_tools) if self._bound_tools else 0,
|
||||
elapsed,
|
||||
)
|
||||
|
||||
# Convert response to ChatResult with potential tool calls
|
||||
|
|
@ -386,6 +558,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
if not self._router:
|
||||
raise ValueError("Router not initialized")
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
msg_count = len(messages)
|
||||
|
||||
# Convert LangChain messages to OpenAI format
|
||||
formatted_messages = self._convert_messages(messages)
|
||||
|
||||
|
|
@ -396,12 +572,36 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
if self._tool_choice is not None:
|
||||
call_kwargs["tool_choice"] = self._tool_choice
|
||||
|
||||
# Call router async completion
|
||||
response = await self._router.acompletion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
**call_kwargs,
|
||||
try:
|
||||
response = await self._router.acompletion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
**call_kwargs,
|
||||
)
|
||||
except ContextWindowExceededError as e:
|
||||
perf.warning(
|
||||
"[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
except LiteLLMBadRequestError as e:
|
||||
if _is_context_overflow_error(e):
|
||||
perf.warning(
|
||||
"[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
raise
|
||||
|
||||
elapsed = time.perf_counter() - t0
|
||||
perf.info(
|
||||
"[llm_router] _agenerate completed msgs=%d tools=%d in %.3fs",
|
||||
msg_count,
|
||||
len(self._bound_tools) if self._bound_tools else 0,
|
||||
elapsed,
|
||||
)
|
||||
|
||||
# Convert response to ChatResult with potential tool calls
|
||||
|
|
@ -432,14 +632,20 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
if self._tool_choice is not None:
|
||||
call_kwargs["tool_choice"] = self._tool_choice
|
||||
|
||||
# Call router completion with streaming
|
||||
response = self._router.completion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
**call_kwargs,
|
||||
)
|
||||
try:
|
||||
response = self._router.completion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
**call_kwargs,
|
||||
)
|
||||
except ContextWindowExceededError as e:
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
except LiteLLMBadRequestError as e:
|
||||
if _is_context_overflow_error(e):
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
raise
|
||||
|
||||
# Yield chunks
|
||||
for chunk in response:
|
||||
|
|
@ -462,6 +668,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
if not self._router:
|
||||
raise ValueError("Router not initialized")
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
msg_count = len(messages)
|
||||
|
||||
formatted_messages = self._convert_messages(messages)
|
||||
|
||||
# Add tools if bound
|
||||
|
|
@ -471,23 +681,61 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
if self._tool_choice is not None:
|
||||
call_kwargs["tool_choice"] = self._tool_choice
|
||||
|
||||
# Call router async completion with streaming
|
||||
response = await self._router.acompletion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
**call_kwargs,
|
||||
try:
|
||||
response = await self._router.acompletion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
**call_kwargs,
|
||||
)
|
||||
except ContextWindowExceededError as e:
|
||||
perf.warning(
|
||||
"[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
except LiteLLMBadRequestError as e:
|
||||
if _is_context_overflow_error(e):
|
||||
perf.warning(
|
||||
"[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
raise
|
||||
|
||||
t_first_chunk = time.perf_counter()
|
||||
perf.info(
|
||||
"[llm_router] _astream connection established msgs=%d in %.3fs",
|
||||
msg_count,
|
||||
t_first_chunk - t0,
|
||||
)
|
||||
|
||||
# Yield chunks asynchronously
|
||||
chunk_count = 0
|
||||
first_chunk_logged = False
|
||||
async for chunk in response:
|
||||
if hasattr(chunk, "choices") and chunk.choices:
|
||||
delta = chunk.choices[0].delta
|
||||
chunk_msg = self._convert_delta_to_chunk(delta)
|
||||
if chunk_msg:
|
||||
chunk_count += 1
|
||||
if not first_chunk_logged:
|
||||
perf.info(
|
||||
"[llm_router] _astream first chunk in %.3fs (total %.3fs from start)",
|
||||
time.perf_counter() - t_first_chunk,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
first_chunk_logged = True
|
||||
yield ChatGenerationChunk(message=chunk_msg)
|
||||
|
||||
perf.info(
|
||||
"[llm_router] _astream completed chunks=%d total=%.3fs",
|
||||
chunk_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
|
||||
def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]:
|
||||
"""Convert LangChain messages to OpenAI format."""
|
||||
from langchain_core.messages import (
|
||||
|
|
@ -602,19 +850,28 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
return None
|
||||
|
||||
|
||||
def get_auto_mode_llm() -> ChatLiteLLMRouter | None:
|
||||
"""
|
||||
Get a ChatLiteLLMRouter instance for auto mode.
|
||||
def get_auto_mode_llm(
|
||||
*,
|
||||
streaming: bool = True,
|
||||
) -> ChatLiteLLMRouter | None:
|
||||
"""Return a cached ChatLiteLLMRouter for auto mode.
|
||||
|
||||
Returns:
|
||||
ChatLiteLLMRouter instance or None if router not initialized
|
||||
Base (no tools) instances are cached per ``streaming`` flag so we
|
||||
avoid re-constructing them on every request. ``bind_tools()`` still
|
||||
returns a fresh instance because bound tools differ per agent.
|
||||
"""
|
||||
if not LLMRouterService.is_initialized():
|
||||
logger.warning("LLM Router not initialized for auto mode")
|
||||
return None
|
||||
|
||||
cached = _router_instance_cache.get(streaming)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
try:
|
||||
return ChatLiteLLMRouter()
|
||||
instance = ChatLiteLLMRouter(streaming=streaming)
|
||||
_router_instance_cache[streaming] = instance
|
||||
return instance
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -12,12 +12,20 @@ from app.services.llm_router_service import (
|
|||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
# Configure litellm to automatically drop unsupported parameters
|
||||
litellm.drop_params = True
|
||||
|
||||
# Memory controls: prevent unbounded internal accumulation
|
||||
litellm.telemetry = False
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm.failure_callback = []
|
||||
litellm.input_callback = []
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -221,7 +229,7 @@ async def get_search_space_llm_instance(
|
|||
logger.debug(
|
||||
f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}"
|
||||
)
|
||||
return ChatLiteLLMRouter(disable_streaming=disable_streaming)
|
||||
return get_auto_mode_llm(streaming=not disable_streaming)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -4,11 +4,11 @@ from datetime import datetime
|
|||
from sqlalchemy import delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import Chunk, Document
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
)
|
||||
|
|
@ -127,10 +127,8 @@ class NotionKBSyncService:
|
|||
logger.debug(f"Generated summary length: {len(summary_content)} chars")
|
||||
else:
|
||||
logger.warning("No LLM configured - using fallback summary")
|
||||
summary_content = f"Notion Page: {document.document_metadata.get('page_title')}\n\n{full_content[:500]}..."
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Notion Page: {document.document_metadata.get('page_title')}\n\n{full_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
logger.debug(f"Deleting old chunks for document {document_id}")
|
||||
await self.db_session.execute(
|
||||
|
|
|
|||
53
surfsense_backend/app/services/task_dispatcher.py
Normal file
53
surfsense_backend/app/services/task_dispatcher.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""Task dispatcher abstraction for background document processing.
|
||||
|
||||
Decouples the upload endpoint from Celery so tests can swap in a
|
||||
synchronous (inline) implementation that needs only PostgreSQL.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class TaskDispatcher(Protocol):
|
||||
async def dispatch_file_processing(
|
||||
self,
|
||||
*,
|
||||
document_id: int,
|
||||
temp_path: str,
|
||||
filename: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
should_summarize: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class CeleryTaskDispatcher:
|
||||
"""Production dispatcher — fires Celery tasks via Redis broker."""
|
||||
|
||||
async def dispatch_file_processing(
|
||||
self,
|
||||
*,
|
||||
document_id: int,
|
||||
temp_path: str,
|
||||
filename: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
should_summarize: bool = False,
|
||||
) -> None:
|
||||
from app.tasks.celery_tasks.document_tasks import (
|
||||
process_file_upload_with_document_task,
|
||||
)
|
||||
|
||||
process_file_upload_with_document_task.delay(
|
||||
document_id=document_id,
|
||||
temp_path=temp_path,
|
||||
filename=filename,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
should_summarize=should_summarize,
|
||||
)
|
||||
|
||||
|
||||
async def get_task_dispatcher() -> TaskDispatcher:
|
||||
return CeleryTaskDispatcher()
|
||||
|
|
@ -1 +1,28 @@
|
|||
"""Celery tasks package."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.config import config
|
||||
|
||||
_celery_engine = None
|
||||
_celery_session_maker = None
|
||||
|
||||
|
||||
def get_celery_session_maker() -> async_sessionmaker:
|
||||
"""Return a shared async session maker for Celery tasks.
|
||||
|
||||
A single NullPool engine is created per worker process and reused
|
||||
across all task invocations to avoid leaking engine objects.
|
||||
"""
|
||||
global _celery_engine, _celery_session_maker
|
||||
if _celery_session_maker is None:
|
||||
_celery_engine = create_async_engine(
|
||||
config.DATABASE_URL,
|
||||
poolclass=NullPool,
|
||||
echo=False,
|
||||
)
|
||||
_celery_session_maker = async_sessionmaker(
|
||||
_celery_engine, expire_on_commit=False
|
||||
)
|
||||
return _celery_session_maker
|
||||
|
|
|
|||
|
|
@ -3,11 +3,8 @@
|
|||
import logging
|
||||
import traceback
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -42,20 +39,6 @@ def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> N
|
|||
)
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""
|
||||
Create a new async session maker for Celery tasks.
|
||||
This is necessary because Celery tasks run in a new event loop,
|
||||
and the default session maker is bound to the main app's event loop.
|
||||
"""
|
||||
engine = create_async_engine(
|
||||
config.DATABASE_URL,
|
||||
poolclass=NullPool, # Don't use connection pooling for Celery tasks
|
||||
echo=False,
|
||||
)
|
||||
return async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
@celery_app.task(name="index_slack_messages", bind=True)
|
||||
def index_slack_messages_task(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -2,35 +2,20 @@
|
|||
|
||||
import logging
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import Document
|
||||
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
generate_document_summary,
|
||||
)
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""Create async session maker for Celery tasks."""
|
||||
engine = create_async_engine(
|
||||
config.DATABASE_URL,
|
||||
poolclass=NullPool,
|
||||
echo=False,
|
||||
)
|
||||
return async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
@celery_app.task(name="reindex_document", bind=True)
|
||||
def reindex_document_task(self, document_id: int, user_id: str):
|
||||
"""
|
||||
|
|
@ -54,7 +39,6 @@ def reindex_document_task(self, document_id: int, user_id: str):
|
|||
async def _reindex_document(document_id: int, user_id: str):
|
||||
"""Async function to reindex a document."""
|
||||
async with get_celery_session_maker()() as session:
|
||||
# First, get the document to get search_space_id for logging
|
||||
result = await session.execute(
|
||||
select(Document)
|
||||
.options(selectinload(Document.chunks))
|
||||
|
|
@ -66,10 +50,8 @@ async def _reindex_document(document_id: int, user_id: str):
|
|||
logger.error(f"Document {document_id} not found")
|
||||
return
|
||||
|
||||
# Initialize task logger
|
||||
task_logger = TaskLoggingService(session, document.search_space_id)
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="document_reindex",
|
||||
source="editor",
|
||||
|
|
@ -83,10 +65,7 @@ async def _reindex_document(document_id: int, user_id: str):
|
|||
)
|
||||
|
||||
try:
|
||||
# Read markdown directly from source_markdown
|
||||
markdown_content = document.source_markdown
|
||||
|
||||
if not markdown_content:
|
||||
if not document.source_markdown:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Document {document_id} has no source_markdown to reindex",
|
||||
|
|
@ -97,51 +76,17 @@ async def _reindex_document(document_id: int, user_id: str):
|
|||
|
||||
logger.info(f"Reindexing document {document_id} ({document.title})")
|
||||
|
||||
# 1. Delete old chunks explicitly
|
||||
from app.db import Chunk
|
||||
|
||||
await session.execute(delete(Chunk).where(Chunk.document_id == document_id))
|
||||
await session.flush() # Ensure old chunks are deleted
|
||||
|
||||
# 2. Create new chunks from source_markdown
|
||||
new_chunks = await create_document_chunks(markdown_content)
|
||||
|
||||
# 3. Add new chunks to session
|
||||
for chunk in new_chunks:
|
||||
chunk.document_id = document_id
|
||||
session.add(chunk)
|
||||
|
||||
logger.info(f"Created {len(new_chunks)} chunks for document {document_id}")
|
||||
|
||||
# 4. Regenerate summary
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, document.search_space_id
|
||||
)
|
||||
|
||||
document_metadata = {
|
||||
"title": document.title,
|
||||
"document_type": document.document_type.value,
|
||||
}
|
||||
adapter = UploadDocumentAdapter(session)
|
||||
await adapter.reindex(document=document, llm=user_llm)
|
||||
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
markdown_content, user_llm, document_metadata
|
||||
)
|
||||
|
||||
# 5. Update document
|
||||
document.content = summary_content
|
||||
document.embedding = summary_embedding
|
||||
document.content_needs_reindexing = False
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Log success
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully reindexed document: {document.title}",
|
||||
{
|
||||
"chunks_created": len(new_chunks),
|
||||
"document_id": document_id,
|
||||
},
|
||||
{"document_id": document_id},
|
||||
)
|
||||
|
||||
logger.info(f"Successfully reindexed document {document_id}")
|
||||
|
|
|
|||
|
|
@ -5,13 +5,11 @@ import logging
|
|||
import os
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.services.notification_service import NotificationService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.document_processors import (
|
||||
add_extension_received_document,
|
||||
add_youtube_video_document,
|
||||
|
|
@ -91,20 +89,6 @@ async def _run_heartbeat_loop(notification_id: int):
|
|||
pass # Normal cancellation when task completes
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""
|
||||
Create a new async session maker for Celery tasks.
|
||||
This is necessary because Celery tasks run in a new event loop,
|
||||
and the default session maker is bound to the main app's event loop.
|
||||
"""
|
||||
engine = create_async_engine(
|
||||
config.DATABASE_URL,
|
||||
poolclass=NullPool, # Don't use connection pooling for Celery tasks
|
||||
echo=False,
|
||||
)
|
||||
return async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
@celery_app.task(name="process_extension_document", bind=True)
|
||||
def process_extension_document_task(
|
||||
self, individual_document_dict, search_space_id: int, user_id: str
|
||||
|
|
@ -626,6 +610,7 @@ def process_file_upload_with_document_task(
|
|||
filename: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
should_summarize: bool = False,
|
||||
):
|
||||
"""
|
||||
Celery task to process uploaded file with existing pending document.
|
||||
|
|
@ -640,6 +625,7 @@ def process_file_upload_with_document_task(
|
|||
filename: Original filename
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
should_summarize: Whether to generate an LLM summary
|
||||
"""
|
||||
import traceback
|
||||
|
||||
|
|
@ -674,7 +660,12 @@ def process_file_upload_with_document_task(
|
|||
try:
|
||||
loop.run_until_complete(
|
||||
_process_file_with_document(
|
||||
document_id, temp_path, filename, search_space_id, user_id
|
||||
document_id,
|
||||
temp_path,
|
||||
filename,
|
||||
search_space_id,
|
||||
user_id,
|
||||
should_summarize=should_summarize,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
|
|
@ -710,6 +701,7 @@ async def _process_file_with_document(
|
|||
filename: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
should_summarize: bool = False,
|
||||
):
|
||||
"""
|
||||
Process file and update existing pending document status.
|
||||
|
|
@ -811,6 +803,7 @@ async def _process_file_with_document(
|
|||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
notification=notification,
|
||||
should_summarize=should_summarize,
|
||||
)
|
||||
|
||||
# Update notification on success
|
||||
|
|
|
|||
|
|
@ -5,14 +5,13 @@ import logging
|
|||
import sys
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||
from app.agents.podcaster.state import State as PodcasterState
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import Podcast, PodcastStatus
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -25,20 +24,6 @@ if sys.platform.startswith("win"):
|
|||
)
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""
|
||||
Create a new async session maker for Celery tasks.
|
||||
This is necessary because Celery tasks run in a new event loop,
|
||||
and the default session maker is bound to the main app's event loop.
|
||||
"""
|
||||
engine = create_async_engine(
|
||||
config.DATABASE_URL,
|
||||
poolclass=NullPool, # Don't use connection pooling for Celery tasks
|
||||
echo=False,
|
||||
)
|
||||
return async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Content-based podcast generation (for new-chat)
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -3,28 +3,16 @@
|
|||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.utils.indexing_locks import is_connector_indexing_locked
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""Create async session maker for Celery tasks."""
|
||||
engine = create_async_engine(
|
||||
config.DATABASE_URL,
|
||||
poolclass=NullPool,
|
||||
echo=False,
|
||||
)
|
||||
return async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
@celery_app.task(name="check_periodic_schedules")
|
||||
def check_periodic_schedules_task():
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -29,20 +29,17 @@ from datetime import UTC, datetime
|
|||
|
||||
import redis
|
||||
from sqlalchemy import and_, or_, text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import Document, DocumentStatus, Notification
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis client for checking heartbeats
|
||||
_redis_client: redis.Redis | None = None
|
||||
|
||||
# Error messages shown to users when tasks are interrupted
|
||||
STALE_SYNC_ERROR_MESSAGE = "Sync was interrupted unexpectedly. Please retry."
|
||||
STALE_PROCESSING_ERROR_MESSAGE = "Syncing was interrupted unexpectedly. Please retry."
|
||||
|
||||
|
|
@ -60,16 +57,6 @@ def _get_heartbeat_key(notification_id: int) -> str:
|
|||
return f"indexing:heartbeat:{notification_id}"
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""Create async session maker for Celery tasks."""
|
||||
engine = create_async_engine(
|
||||
config.DATABASE_URL,
|
||||
poolclass=NullPool,
|
||||
echo=False,
|
||||
)
|
||||
return async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
@celery_app.task(name="cleanup_stale_indexing_notifications")
|
||||
def cleanup_stale_indexing_notifications_task():
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -9,17 +9,23 @@ Supports loading LLM configurations from:
|
|||
- NewLLMConfig database table (positive IDs for user-created configs with prompt settings)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import gc
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import logging
|
||||
|
||||
import anyio
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
|
|
@ -30,7 +36,21 @@ from app.agents.new_chat.llm_config import (
|
|||
load_agent_config,
|
||||
load_llm_config_from_yaml,
|
||||
)
|
||||
from app.db import ChatVisibility, Document, Report, SurfsenseDocsDocument, async_session_maker
|
||||
from app.agents.new_chat.sandbox import (
|
||||
get_or_create_sandbox,
|
||||
is_sandbox_enabled,
|
||||
)
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
Document,
|
||||
NewChatMessage,
|
||||
NewChatThread,
|
||||
Report,
|
||||
SearchSourceConnectorType,
|
||||
SurfsenseDocsDocument,
|
||||
async_session_maker,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
|
|
@ -39,6 +59,11 @@ from app.services.chat_session_state_service import (
|
|||
from app.services.connector_service import ConnectorService
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
def format_mentioned_documents_as_context(documents: list[Document]) -> str:
|
||||
|
|
@ -187,6 +212,7 @@ class StreamResult:
|
|||
accumulated_text: str = ""
|
||||
is_interrupted: bool = False
|
||||
interrupt_value: dict[str, Any] | None = None
|
||||
sandbox_files: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
async def _stream_agent_events(
|
||||
|
|
@ -404,6 +430,21 @@ async def _stream_agent_events(
|
|||
status="in_progress",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
cmd = (
|
||||
tool_input.get("command", "")
|
||||
if isinstance(tool_input, dict)
|
||||
else str(tool_input)
|
||||
)
|
||||
display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "")
|
||||
last_active_step_title = "Running command"
|
||||
last_active_step_items = [f"$ {display_cmd}"]
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Running command",
|
||||
status="in_progress",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
else:
|
||||
last_active_step_title = f"Using {tool_name.replace('_', ' ')}"
|
||||
last_active_step_items = []
|
||||
|
|
@ -620,6 +661,32 @@ async def _stream_agent_events(
|
|||
status="completed",
|
||||
items=completed_items,
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
raw_text = (
|
||||
tool_output.get("result", "")
|
||||
if isinstance(tool_output, dict)
|
||||
else str(tool_output)
|
||||
)
|
||||
m = re.match(r"^Exit code:\s*(\d+)", raw_text)
|
||||
exit_code_val = int(m.group(1)) if m else None
|
||||
if exit_code_val is not None and exit_code_val == 0:
|
||||
completed_items = [
|
||||
*last_active_step_items,
|
||||
"Completed successfully",
|
||||
]
|
||||
elif exit_code_val is not None:
|
||||
completed_items = [
|
||||
*last_active_step_items,
|
||||
f"Exit code: {exit_code_val}",
|
||||
]
|
||||
else:
|
||||
completed_items = [*last_active_step_items, "Finished"]
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Running command",
|
||||
status="completed",
|
||||
items=completed_items,
|
||||
)
|
||||
elif tool_name == "ls":
|
||||
if isinstance(tool_output, dict):
|
||||
ls_output = tool_output.get("result", "")
|
||||
|
|
@ -813,6 +880,36 @@ async def _stream_agent_events(
|
|||
if isinstance(tool_output, dict)
|
||||
else {"result": tool_output},
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
raw_text = (
|
||||
tool_output.get("result", "")
|
||||
if isinstance(tool_output, dict)
|
||||
else str(tool_output)
|
||||
)
|
||||
exit_code: int | None = None
|
||||
output_text = raw_text
|
||||
m = re.match(r"^Exit code:\s*(\d+)", raw_text)
|
||||
if m:
|
||||
exit_code = int(m.group(1))
|
||||
om = re.search(r"\nOutput:\n([\s\S]*)", raw_text)
|
||||
output_text = om.group(1) if om else ""
|
||||
thread_id_str = config.get("configurable", {}).get("thread_id", "")
|
||||
|
||||
for sf_match in re.finditer(
|
||||
r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE
|
||||
):
|
||||
fpath = sf_match.group(1).strip()
|
||||
if fpath and fpath not in result.sandbox_files:
|
||||
result.sandbox_files.append(fpath)
|
||||
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
{
|
||||
"exit_code": exit_code,
|
||||
"output": output_text,
|
||||
"thread_id": thread_id_str,
|
||||
},
|
||||
)
|
||||
else:
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
|
|
@ -881,11 +978,42 @@ async def _stream_agent_events(
|
|||
yield streaming_service.format_interrupt_request(result.interrupt_value)
|
||||
|
||||
|
||||
def _try_persist_and_delete_sandbox(
|
||||
thread_id: int,
|
||||
sandbox_files: list[str],
|
||||
) -> None:
|
||||
"""Fire-and-forget: persist sandbox files locally then delete the sandbox."""
|
||||
from app.agents.new_chat.sandbox import (
|
||||
is_sandbox_enabled,
|
||||
persist_and_delete_sandbox,
|
||||
)
|
||||
|
||||
if not is_sandbox_enabled():
|
||||
return
|
||||
|
||||
async def _run() -> None:
|
||||
try:
|
||||
await persist_and_delete_sandbox(thread_id, sandbox_files)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"persist_and_delete_sandbox failed for thread %s",
|
||||
thread_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = loop.create_task(_run())
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
async def stream_new_chat(
|
||||
user_query: str,
|
||||
search_space_id: int,
|
||||
chat_id: int,
|
||||
session: AsyncSession,
|
||||
user_id: str | None = None,
|
||||
llm_config_id: int = -1,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
|
|
@ -901,11 +1029,13 @@ async def stream_new_chat(
|
|||
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
|
||||
The chat_id is used as LangGraph's thread_id for memory/checkpointing.
|
||||
|
||||
The function creates and manages its own database session to guarantee proper
|
||||
cleanup even when Starlette's middleware cancels the task on client disconnect.
|
||||
|
||||
Args:
|
||||
user_query: The user's query
|
||||
search_space_id: The search space ID
|
||||
chat_id: The chat ID (used as LangGraph thread_id for memory)
|
||||
session: The database session
|
||||
user_id: The current user's UUID string (for memory tools and session state)
|
||||
llm_config_id: The LLM configuration ID (default: -1 for first global config)
|
||||
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
|
||||
|
|
@ -917,7 +1047,11 @@ async def stream_new_chat(
|
|||
str: SSE formatted response strings
|
||||
"""
|
||||
streaming_service = VercelStreamingService()
|
||||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
log_system_snapshot("stream_new_chat_START")
|
||||
|
||||
session = async_session_maker()
|
||||
try:
|
||||
# Mark AI as responding to this user for live collaboration
|
||||
if user_id:
|
||||
|
|
@ -925,6 +1059,7 @@ async def stream_new_chat(
|
|||
# Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
|
||||
agent_config: AgentConfig | None = None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
if llm_config_id >= 0:
|
||||
# Positive ID: Load from NewLLMConfig database table
|
||||
agent_config = await load_agent_config(
|
||||
|
|
@ -955,6 +1090,11 @@ async def stream_new_chat(
|
|||
llm = create_chat_litellm_from_config(llm_config)
|
||||
# Create AgentConfig from YAML for consistency (uses defaults for prompt settings)
|
||||
agent_config = AgentConfig.from_yaml_config(llm_config)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
|
||||
time.perf_counter() - _t0,
|
||||
llm_config_id,
|
||||
)
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
|
|
@ -962,22 +1102,45 @@ async def stream_new_chat(
|
|||
return
|
||||
|
||||
# Create connector service
|
||||
_t0 = time.perf_counter()
|
||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||
|
||||
# Get Firecrawl API key from webcrawler connector if configured
|
||||
from app.db import SearchSourceConnectorType
|
||||
|
||||
firecrawl_api_key = None
|
||||
webcrawler_connector = await connector_service.get_connector_by_type(
|
||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
|
||||
)
|
||||
if webcrawler_connector and webcrawler_connector.config:
|
||||
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
# Get the PostgreSQL checkpointer for persistent conversation memory
|
||||
_t0 = time.perf_counter()
|
||||
checkpointer = await get_checkpointer()
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
sandbox_backend = None
|
||||
_t0 = time.perf_counter()
|
||||
if is_sandbox_enabled():
|
||||
try:
|
||||
sandbox_backend = await get_or_create_sandbox(chat_id)
|
||||
except Exception as sandbox_err:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Sandbox creation failed, continuing without execute tool: %s",
|
||||
sandbox_err,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Sandbox provisioning in %.3fs (enabled=%s)",
|
||||
time.perf_counter() - _t0,
|
||||
sandbox_backend is not None,
|
||||
)
|
||||
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
_t0 = time.perf_counter()
|
||||
agent = await create_surfsense_deep_agent(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
|
|
@ -989,20 +1152,22 @@ async def stream_new_chat(
|
|||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
sandbox_backend=sandbox_backend,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
# Build input with message history
|
||||
langchain_messages = []
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
# Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
|
||||
if needs_history_bootstrap:
|
||||
langchain_messages = await bootstrap_history_from_db(
|
||||
session, chat_id, thread_visibility=visibility
|
||||
)
|
||||
|
||||
# Clear the flag so we don't bootstrap again on next message
|
||||
from app.db import NewChatThread
|
||||
|
||||
thread_result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||
)
|
||||
|
|
@ -1014,11 +1179,9 @@ async def stream_new_chat(
|
|||
# Fetch mentioned documents if any (with chunks for proper citations)
|
||||
mentioned_documents: list[Document] = []
|
||||
if mentioned_document_ids:
|
||||
from sqlalchemy.orm import selectinload as doc_selectinload
|
||||
|
||||
result = await session.execute(
|
||||
select(Document)
|
||||
.options(doc_selectinload(Document.chunks))
|
||||
.options(selectinload(Document.chunks))
|
||||
.filter(
|
||||
Document.id.in_(mentioned_document_ids),
|
||||
Document.search_space_id == search_space_id,
|
||||
|
|
@ -1029,8 +1192,6 @@ async def stream_new_chat(
|
|||
# Fetch mentioned SurfSense docs if any
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument] = []
|
||||
if mentioned_surfsense_doc_ids:
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
result = await session.execute(
|
||||
select(SurfsenseDocsDocument)
|
||||
.options(selectinload(SurfsenseDocsDocument.chunks))
|
||||
|
|
@ -1114,6 +1275,11 @@ async def stream_new_chat(
|
|||
"search_space_id": search_space_id,
|
||||
}
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] History bootstrap + doc/report queries in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
# All pre-streaming DB reads are done. Commit to release the
|
||||
# transaction and its ACCESS SHARE locks so we don't block DDL
|
||||
# (e.g. migrations) for the entire duration of LLM streaming.
|
||||
|
|
@ -1121,6 +1287,18 @@ async def stream_new_chat(
|
|||
# short-lived transactions (or use isolated sessions).
|
||||
await session.commit()
|
||||
|
||||
# Detach heavy ORM objects (documents with chunks, reports, etc.)
|
||||
# from the session identity map now that we've extracted the data
|
||||
# we need. This prevents them from accumulating in memory for the
|
||||
# entire duration of LLM streaming (which can be several minutes).
|
||||
session.expunge_all()
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
|
||||
# Configure LangGraph with thread_id for memory
|
||||
# If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
|
||||
configurable = {"thread_id": str(chat_id)}
|
||||
|
|
@ -1182,7 +1360,14 @@ async def stream_new_chat(
|
|||
items=initial_items,
|
||||
)
|
||||
|
||||
stream_result = StreamResult()
|
||||
# These ORM objects (with eagerly-loaded chunks) can be very large.
|
||||
# They're only needed to build context strings already copied into
|
||||
# final_query / langchain_messages — release them before streaming.
|
||||
del mentioned_documents, mentioned_surfsense_docs, recent_reports
|
||||
del langchain_messages, final_query
|
||||
|
||||
_t_stream_start = time.perf_counter()
|
||||
_first_event_logged = False
|
||||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
|
|
@ -1194,8 +1379,24 @@ async def stream_new_chat(
|
|||
initial_step_title=initial_title,
|
||||
initial_step_items=initial_items,
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
|
||||
"%.3fs (total since request start) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
_first_event_logged = True
|
||||
yield sse
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
chat_id,
|
||||
)
|
||||
log_system_snapshot("stream_new_chat_END")
|
||||
|
||||
if stream_result.is_interrupted:
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
|
|
@ -1204,12 +1405,6 @@ async def stream_new_chat(
|
|||
|
||||
accumulated_text = stream_result.accumulated_text
|
||||
|
||||
# Generate LLM title for new chats after first response
|
||||
# Check if this is the first assistant response by counting existing assistant messages
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.db import NewChatMessage, NewChatThread
|
||||
|
||||
assistant_count_result = await session.execute(
|
||||
select(func.count(NewChatMessage.id)).filter(
|
||||
NewChatMessage.thread_id == chat_id,
|
||||
|
|
@ -1280,39 +1475,72 @@ async def stream_new_chat(
|
|||
yield streaming_service.format_done()
|
||||
|
||||
finally:
|
||||
# Clear AI responding state for live collaboration.
|
||||
# The original session may be broken (client disconnect / CancelledError
|
||||
# can corrupt the underlying DB connection), so we try a rollback first
|
||||
# and fall back to a fresh session if the original is unusable.
|
||||
try:
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
# Shield the ENTIRE async cleanup from anyio cancel-scope
|
||||
# cancellation. Starlette's BaseHTTPMiddleware uses anyio task
|
||||
# groups; on client disconnect, it cancels the scope with
|
||||
# level-triggered cancellation — every unshielded `await` inside
|
||||
# the cancelled scope raises CancelledError immediately. Without
|
||||
# this shield the very first `await` (session.rollback) would
|
||||
# raise CancelledError, `except Exception` wouldn't catch it
|
||||
# (CancelledError is a BaseException), and the rest of the
|
||||
# finally block — including session.close() — would never run.
|
||||
with anyio.CancelScope(shield=True):
|
||||
try:
|
||||
async with async_session_maker() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
try:
|
||||
async with shielded_async_session() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
|
||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
session.expunge_all()
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
await session.close()
|
||||
|
||||
# Break circular refs held by the agent graph, tools, and LLM
|
||||
# wrappers so the GC can reclaim them in a single pass.
|
||||
agent = llm = connector_service = sandbox_backend = None
|
||||
input_state = stream_result = None
|
||||
session = None
|
||||
|
||||
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
|
||||
if collected:
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] gc.collect() reclaimed %d objects (chat_id=%s)",
|
||||
collected,
|
||||
chat_id,
|
||||
)
|
||||
trim_native_heap()
|
||||
log_system_snapshot("stream_new_chat_END")
|
||||
|
||||
|
||||
async def stream_resume_chat(
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
decisions: list[dict],
|
||||
session: AsyncSession,
|
||||
user_id: str | None = None,
|
||||
llm_config_id: int = -1,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
streaming_service = VercelStreamingService()
|
||||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
|
||||
session = async_session_maker()
|
||||
try:
|
||||
if user_id:
|
||||
await set_ai_responding(session, chat_id, UUID(user_id))
|
||||
|
||||
agent_config: AgentConfig | None = None
|
||||
_t0 = time.perf_counter()
|
||||
if llm_config_id >= 0:
|
||||
agent_config = await load_agent_config(
|
||||
session=session,
|
||||
|
|
@ -1336,26 +1564,54 @@ async def stream_resume_chat(
|
|||
return
|
||||
llm = create_chat_litellm_from_config(llm_config)
|
||||
agent_config = AgentConfig.from_yaml_config(llm_config)
|
||||
_perf_log.info(
|
||||
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||
|
||||
from app.db import SearchSourceConnectorType
|
||||
|
||||
firecrawl_api_key = None
|
||||
webcrawler_connector = await connector_service.get_connector_by_type(
|
||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
|
||||
)
|
||||
if webcrawler_connector and webcrawler_connector.config:
|
||||
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
|
||||
_perf_log.info(
|
||||
"[stream_resume] Connector service + firecrawl key in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
checkpointer = await get_checkpointer()
|
||||
_perf_log.info(
|
||||
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
sandbox_backend = None
|
||||
_t0 = time.perf_counter()
|
||||
if is_sandbox_enabled():
|
||||
try:
|
||||
sandbox_backend = await get_or_create_sandbox(chat_id)
|
||||
except Exception as sandbox_err:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Sandbox creation failed, continuing without execute tool: %s",
|
||||
sandbox_err,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Sandbox provisioning in %.3fs (enabled=%s)",
|
||||
time.perf_counter() - _t0,
|
||||
sandbox_backend is not None,
|
||||
)
|
||||
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await create_surfsense_deep_agent(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
|
|
@ -1367,10 +1623,21 @@ async def stream_resume_chat(
|
|||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
sandbox_backend=sandbox_backend,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
# Release the transaction before streaming (same rationale as stream_new_chat).
|
||||
await session.commit()
|
||||
session.expunge_all()
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
|
|
@ -1382,7 +1649,8 @@ async def stream_resume_chat(
|
|||
yield streaming_service.format_message_start()
|
||||
yield streaming_service.format_start_step()
|
||||
|
||||
stream_result = StreamResult()
|
||||
_t_stream_start = time.perf_counter()
|
||||
_first_event_logged = False
|
||||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
|
|
@ -1391,7 +1659,20 @@ async def stream_resume_chat(
|
|||
result=stream_result,
|
||||
step_prefix="thinking-resume",
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
_first_event_logged = True
|
||||
yield sse
|
||||
_perf_log.info(
|
||||
"[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
chat_id,
|
||||
)
|
||||
if stream_result.is_interrupted:
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
|
|
@ -1414,14 +1695,37 @@ async def stream_resume_chat(
|
|||
yield streaming_service.format_done()
|
||||
|
||||
finally:
|
||||
try:
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
with anyio.CancelScope(shield=True):
|
||||
try:
|
||||
async with async_session_maker() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
try:
|
||||
async with shielded_async_session() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
|
||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
session.expunge_all()
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
await session.close()
|
||||
|
||||
agent = llm = connector_service = sandbox_backend = None
|
||||
stream_result = None
|
||||
session = None
|
||||
|
||||
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
|
||||
if collected:
|
||||
_perf_log.info(
|
||||
"[stream_resume] gc.collect() reclaimed %d objects (chat_id=%s)",
|
||||
collected,
|
||||
chat_id,
|
||||
)
|
||||
trim_native_heap()
|
||||
log_system_snapshot("stream_resume_chat_END")
|
||||
|
|
|
|||
|
|
@ -12,13 +12,13 @@ from collections.abc import Awaitable, Callable
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.airtable_history import AirtableHistoryConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -399,7 +399,7 @@ async def index_airtable_records(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"record_id": item["record_id"],
|
||||
"created_time": item["record"].get("CREATED_TIME()", ""),
|
||||
|
|
@ -415,11 +415,8 @@ async def index_airtable_records(
|
|||
document_metadata_for_summary,
|
||||
)
|
||||
else:
|
||||
# Fallback to simple summary if no LLM configured
|
||||
summary_content = f"Airtable Record: {item['record_id']}\n\n"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Airtable Record: {item['record_id']}\n\n{item['markdown_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["markdown_content"])
|
||||
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@ from datetime import datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.bookstack_connector import BookStackConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -403,7 +403,7 @@ async def index_bookstack_pages(
|
|||
"connector_id": connector_id,
|
||||
}
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
summary_metadata = {
|
||||
"page_name": item["page_name"],
|
||||
"page_id": item["page_id"],
|
||||
|
|
@ -418,17 +418,8 @@ async def index_bookstack_pages(
|
|||
item["full_content"], user_llm, summary_metadata
|
||||
)
|
||||
else:
|
||||
# Fallback to simple summary if no LLM configured
|
||||
summary_content = f"BookStack Page: {item['page_name']}\n\nBook ID: {item['book_id']}\n\n"
|
||||
if item["page_content"]:
|
||||
# Take first 1000 characters of content for summary
|
||||
content_preview = item["page_content"][:1000]
|
||||
if len(item["page_content"]) > 1000:
|
||||
content_preview += "..."
|
||||
summary_content += f"Content Preview: {content_preview}\n\n"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"BookStack Page: {item['page_name']}\n\nBook ID: {item['book_id']}\n\n{item['full_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
# Process chunks - using the full page content
|
||||
chunks = await create_document_chunks(item["full_content"])
|
||||
|
|
|
|||
|
|
@ -14,13 +14,13 @@ from datetime import datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.clickup_history import ClickUpHistoryConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -398,7 +398,7 @@ async def index_clickup_tasks(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"task_id": item["task_id"],
|
||||
"task_name": item["task_name"],
|
||||
|
|
@ -418,9 +418,7 @@ async def index_clickup_tasks(
|
|||
)
|
||||
else:
|
||||
summary_content = item["task_content"]
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
item["task_content"]
|
||||
)
|
||||
summary_embedding = embed_text(item["task_content"])
|
||||
|
||||
chunks = await create_document_chunks(item["task_content"])
|
||||
|
||||
|
|
|
|||
|
|
@ -14,13 +14,13 @@ from datetime import datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -378,7 +378,7 @@ async def index_confluence_pages(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata = {
|
||||
"page_title": item["page_title"],
|
||||
"page_id": item["page_id"],
|
||||
|
|
@ -394,18 +394,8 @@ async def index_confluence_pages(
|
|||
item["full_content"], user_llm, document_metadata
|
||||
)
|
||||
else:
|
||||
# Fallback to simple summary if no LLM configured
|
||||
summary_content = f"Confluence Page: {item['page_title']}\n\nSpace ID: {item['space_id']}\n\n"
|
||||
if item["page_content"]:
|
||||
# Take first 1000 characters of content for summary
|
||||
content_preview = item["page_content"][:1000]
|
||||
if len(item["page_content"]) > 1000:
|
||||
content_preview += "..."
|
||||
summary_content += f"Content Preview: {content_preview}\n\n"
|
||||
summary_content += f"Comments: {item['comment_count']}"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Confluence Page: {item['page_title']}\n\nSpace ID: {item['space_id']}\n\n{item['full_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
# Process chunks - using the full page content with comments
|
||||
chunks = await create_document_chunks(item["full_content"])
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnector
|
|||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
|
@ -669,9 +670,7 @@ async def index_discord_messages(
|
|||
|
||||
# Heavy processing (embeddings, chunks)
|
||||
chunks = await create_document_chunks(item["combined_document_string"])
|
||||
doc_embedding = config.embedding_model_instance.embed(
|
||||
item["combined_document_string"]
|
||||
)
|
||||
doc_embedding = embed_text(item["combined_document_string"])
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = f"{item['guild_name']}#{item['channel_name']}"
|
||||
|
|
|
|||
|
|
@ -16,13 +16,13 @@ from datetime import UTC, datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.github_connector import GitHubConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -367,7 +367,7 @@ async def index_github_repos(
|
|||
"estimated_tokens": digest.estimated_tokens,
|
||||
}
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
# Prepare content for summarization
|
||||
summary_content = digest.full_digest
|
||||
if len(summary_content) > MAX_DIGEST_CHARS:
|
||||
|
|
@ -381,15 +381,12 @@ async def index_github_repos(
|
|||
summary_content, user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
# Fallback to simple summary if no LLM configured
|
||||
summary_text = (
|
||||
f"# GitHub Repository: {repo_full_name}\n\n"
|
||||
f"## Summary\n{digest.summary}\n\n"
|
||||
f"## File Structure\n{digest.tree[:3000]}"
|
||||
)
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_text
|
||||
f"## File Structure\n{digest.tree}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_text)
|
||||
|
||||
# Chunk the full digest content for granular search
|
||||
try:
|
||||
|
|
@ -551,7 +548,7 @@ async def _simple_chunk_content(content: str, chunk_size: int = 4000) -> list:
|
|||
chunks.append(
|
||||
Chunk(
|
||||
content=chunk_text,
|
||||
embedding=config.embedding_model_instance.embed(chunk_text),
|
||||
embedding=embed_text(chunk_text),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from app.services.llm_service import get_user_long_context_llm
|
|||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -489,7 +490,7 @@ async def index_google_calendar_events(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"event_id": item["event_id"],
|
||||
"event_summary": item["event_summary"],
|
||||
|
|
@ -507,22 +508,8 @@ async def index_google_calendar_events(
|
|||
item["event_markdown"], user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = (
|
||||
f"Google Calendar Event: {item['event_summary']}\n\n"
|
||||
)
|
||||
summary_content += f"Calendar: {item['calendar_id']}\n"
|
||||
summary_content += f"Start: {item['start_time']}\n"
|
||||
summary_content += f"End: {item['end_time']}\n"
|
||||
if item["location"]:
|
||||
summary_content += f"Location: {item['location']}\n"
|
||||
if item["description"]:
|
||||
desc_preview = item["description"][:1000]
|
||||
if len(item["description"]) > 1000:
|
||||
desc_preview += "..."
|
||||
summary_content += f"Description: {desc_preview}\n"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Google Calendar Event: {item['event_summary']}\n\n{item['event_markdown']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["event_markdown"])
|
||||
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue