Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-05-01 20:03:30 +02:00)
Commit 753e5b56cb
282 changed files with 23802 additions and 9409 deletions

40 .github/workflows/desktop-release.yml (vendored)
@@ -5,6 +5,20 @@ on:
    tags:
      - 'v*'
      - 'beta-v*'
  workflow_dispatch:
    inputs:
      version:
        description: 'Version number (e.g. 0.0.15) — used for dry-run testing without a tag'
        required: true
        default: '0.0.0-test'
      publish:
        description: 'Publish to GitHub Releases'
        required: true
        type: choice
        options:
          - never
          - always
        default: 'never'

permissions:
  contents: write
@@ -25,24 +39,28 @@ jobs:

    steps:
      - name: Checkout
        uses: actions/checkout@v4
        uses: actions/checkout@v5

      - name: Extract version from tag
      - name: Extract version
        id: version
        shell: bash
        run: |
          TAG=${GITHUB_REF#refs/tags/}
          VERSION=${TAG#beta-}
          VERSION=${VERSION#v}
          if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
            VERSION="${{ inputs.version }}"
          else
            TAG=${GITHUB_REF#refs/tags/}
            VERSION=${TAG#beta-}
            VERSION=${VERSION#v}
          fi
          echo "VERSION=$VERSION" >> "$GITHUB_OUTPUT"

      - name: Setup pnpm
        uses: pnpm/action-setup@v4
        uses: pnpm/action-setup@v5

      - name: Setup Node.js
        uses: actions/setup-node@v4
        uses: actions/setup-node@v5
        with:
          node-version: 20
          node-version: 22
          cache: 'pnpm'
          cache-dependency-path: |
            surfsense_web/pnpm-lock.yaml
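The tag-normalization step above is the core of the new dual-trigger logic: on a tag push it strips an optional `beta-` prefix and then a `v` prefix from `GITHUB_REF`, while on `workflow_dispatch` the manually supplied `inputs.version` is used as-is. A minimal Python sketch of the same normalization, for illustration only (this helper does not exist in the workflow):

```python
def normalize_version(github_ref: str, dispatch_version: str | None = None) -> str:
    """Mirror the bash step: 'refs/tags/beta-v0.0.15' or 'refs/tags/v0.0.15' -> '0.0.15'."""
    if dispatch_version is not None:
        # workflow_dispatch run: the manually supplied version wins
        return dispatch_version
    tag = github_ref.removeprefix("refs/tags/")
    tag = tag.removeprefix("beta-")  # beta tags share the same version number
    return tag.removeprefix("v")


assert normalize_version("refs/tags/v0.0.15") == "0.0.15"
assert normalize_version("refs/tags/beta-v0.0.15") == "0.0.15"
assert normalize_version("refs/tags/anything", "0.0.0-test") == "0.0.0-test"
```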
@@ -60,6 +78,7 @@ jobs:
          NEXT_PUBLIC_ZERO_CACHE_URL: ${{ vars.NEXT_PUBLIC_ZERO_CACHE_URL }}
          NEXT_PUBLIC_DEPLOYMENT_MODE: ${{ vars.NEXT_PUBLIC_DEPLOYMENT_MODE }}
          NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE }}
          NEXT_PUBLIC_POSTHOG_KEY: ${{ secrets.NEXT_PUBLIC_POSTHOG_KEY }}

      - name: Install desktop dependencies
        run: pnpm install

@@ -70,9 +89,12 @@ jobs:
        working-directory: surfsense_desktop
        env:
          HOSTED_FRONTEND_URL: ${{ vars.HOSTED_FRONTEND_URL }}
          POSTHOG_KEY: ${{ secrets.POSTHOG_KEY }}
          POSTHOG_HOST: ${{ vars.POSTHOG_HOST }}

      - name: Package & Publish
        run: pnpm exec electron-builder ${{ matrix.platform }} --config electron-builder.yml --publish always -c.extraMetadata.version=${{ steps.version.outputs.VERSION }}
        shell: bash
        run: pnpm exec electron-builder ${{ matrix.platform }} --config electron-builder.yml --publish ${{ inputs.publish || 'always' }} -c.extraMetadata.version=${{ steps.version.outputs.VERSION }}
        working-directory: surfsense_desktop
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
62 README.es.md
@ -41,18 +41,14 @@ NotebookLM es una de las mejores y más útiles plataformas de IA que existen, p
|
|||
- **Sin Dependencia de Proveedores** - Configura cualquier modelo LLM, de imagen, TTS y STT.
|
||||
- **25+ Fuentes de Datos Externas** - Agrega tus fuentes desde Google Drive, OneDrive, Dropbox, Notion y muchos otros servicios externos.
|
||||
- **Soporte Multijugador en Tiempo Real** - Trabaja fácilmente con los miembros de tu equipo en un notebook compartido.
|
||||
- **Aplicación de Escritorio** - Obtén asistencia de IA en cualquier aplicación con Quick Assist, General Assist, Extreme Assist y sincronización de carpetas locales.
|
||||
|
||||
...y más por venir.
|
||||
|
||||
|
||||
|
||||
# Demo
|
||||
|
||||
https://github.com/user-attachments/assets/cc0c84d3-1f2f-4f7a-b519-2ecce22310b1
|
||||
|
||||
## Ejemplo de Agente de Video
|
||||
|
||||
|
||||
https://github.com/user-attachments/assets/012a7ffa-6f76-4f06-9dda-7632b470057a
|
||||
|
||||
|
||||
|
|
@ -68,42 +64,58 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
|||
|
||||
1. Ve a [surfsense.com](https://www.surfsense.com) e inicia sesión.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/b4df25fe-db5a-43c2-9462-b75cf7f1b707" alt="Login" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/LoginFlowGif.gif" alt="Login" /></p>
|
||||
|
||||
2. Conecta tus conectores y sincroniza. Activa la sincronización periódica para mantenerlos actualizados.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/0740f351-23fa-4909-9880-70aa1dcc1df7" alt="Conectores" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ConnectorFlowGif.gif" alt="Conectores" /></p>
|
||||
|
||||
3. Mientras se indexan los datos de los conectores, sube documentos.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/daf3dbae-ef86-4e86-82ea-fcbcad988761" alt="Subir Documentos" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/DocUploadGif.gif" alt="Subir Documentos" /></p>
|
||||
|
||||
4. Una vez que todo esté indexado, pregunta lo que quieras (Casos de uso):
|
||||
|
||||
- Aplicación de Escritorio — General Assist
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/general_assist.gif" alt="General Assist" /></p>
|
||||
|
||||
- Aplicación de Escritorio — Quick Assist
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>
|
||||
|
||||
- Aplicación de Escritorio — Extreme Assist
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/extreme_assist.gif" alt="Extreme Assist" /></p>
|
||||
|
||||
- Aplicación de Escritorio — Watch Local Folder
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/folder_watch.gif" alt="Watch Local Folder" /></p>
|
||||
|
||||
- Generación de videos
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/af85c0f3-6cfd-4757-9706-07fd5e32c857" alt="Generación de Videos" /></p>
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/video_gen_gif.gif" alt="Generación de Videos" /></p>
|
||||
|
||||
- Búsqueda básica y citaciones
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/81e797a1-e01a-4003-8e60-0a0b3a9789df" alt="Búsqueda y Citación" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BSNCGif.gif" alt="Búsqueda y Citación" /></p>
|
||||
|
||||
- QNA con mención de documentos
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/65c3bf06-1d46-4dd5-b169-4d934c9b6798" alt="QNA con Mención de Documentos" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/be958295-0a8c-4707-998c-9fe1f1c007be" alt="QNA con Mención de Documentos" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BQnaGif_compressed.gif" alt="QNA con Mención de Documentos" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BQnaGif_compressed.gif" alt="QNA con Mención de Documentos" /></p>
|
||||
|
||||
- Generación de informes y exportaciones (PDF, DOCX, HTML, LaTeX, EPUB, ODT, texto plano)
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/9836b7d6-57c9-4951-b61c-68202c9b6ace" alt="Generación de Informes" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ReportGenGif_compressed.gif" alt="Generación de Informes" /></p>
|
||||
|
||||
- Generación de podcasts
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/58c9b057-8848-4e81-aaba-d2c617985d8c" alt="Generación de Podcasts" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/PodcastGenGif.gif" alt="Generación de Podcasts" /></p>
|
||||
|
||||
- Generación de imágenes
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/25f94cb3-18f8-4854-afd9-27b7bfd079cb" alt="Generación de Imágenes" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ImageGenGif.gif" alt="Generación de Imágenes" /></p>
|
||||
|
||||
- Y más próximamente.
|
||||
|
||||
|
|
@ -130,6 +142,19 @@ El script de instalación configura [Watchtower](https://github.com/nicholas-fed
|
|||
|
||||
Para Docker Compose, instalación manual y otras opciones de despliegue, consulta la [documentación](https://www.surfsense.com/docs/).
|
||||
|
||||
### Aplicación de Escritorio
|
||||
|
||||
SurfSense también ofrece una aplicación de escritorio que lleva la asistencia de IA a cada aplicación en tu computadora. Descárgala desde la [última versión](https://github.com/MODSetter/SurfSense/releases/latest).
|
||||
|
||||
La aplicación de escritorio incluye estas potentes funciones:
|
||||
|
||||
- **General Assist** — Lanza SurfSense al instante desde cualquier aplicación con un atajo global.
|
||||
- **Quick Assist** — Selecciona texto en cualquier lugar, luego pide a la IA que lo explique, reescriba o actúe sobre él.
|
||||
- **Extreme Assist** — Obtén sugerencias de escritura en línea impulsadas por tu base de conocimiento mientras escribes en cualquier aplicación.
|
||||
- **Watch Local Folder** — Vigila una carpeta local y sincroniza automáticamente los cambios de archivos con tu base de conocimiento. **Pro tip:** Apúntalo a tu bóveda de Obsidian para mantener tus notas buscables en SurfSense.
|
||||
|
||||
Todas las funciones operan contra tu espacio de búsqueda elegido, por lo que tus respuestas siempre están basadas en tus propios datos.
|
||||
|
||||
### Cómo Colaborar en Tiempo Real (Beta)
|
||||
|
||||
1. Ve a la página de Gestión de Miembros y crea una invitación.
|
||||
|
|
@ -146,11 +171,11 @@ Para Docker Compose, instalación manual y otras opciones de despliegue, consult
|
|||
|
||||
4. Tu equipo ahora puede chatear en tiempo real.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/83803ac2-fbce-4d93-aae3-85eb85a3053a" alt="Chat en Tiempo Real" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_realtime/RealTimeChatGif.gif" alt="Chat en Tiempo Real" /></p>
|
||||
|
||||
5. Agrega comentarios para etiquetar a compañeros de equipo.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/3b04477d-8f42-4baa-be95-867c1eaeba87" alt="Comentarios en Tiempo Real" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_realtime/RealTimeCommentsFlow.gif" alt="Comentarios en Tiempo Real" /></p>
|
||||
|
||||
## SurfSense vs Google NotebookLM
|
||||
|
||||
|
|
@ -174,6 +199,7 @@ Para Docker Compose, instalación manual y otras opciones de despliegue, consult
|
|||
| **Generación de Videos** | Resúmenes en video cinemáticos vía Veo 3 (solo Ultra) | Disponible (NotebookLM es mejor aquí, mejorando activamente) |
|
||||
| **Generación de Presentaciones** | Diapositivas más atractivas pero no editables | Crea presentaciones editables basadas en diapositivas |
|
||||
| **Generación de Podcasts** | Resúmenes de audio con hosts e idiomas personalizables | Disponible con múltiples proveedores TTS (NotebookLM es mejor aquí, mejorando activamente) |
|
||||
| **Aplicación de Escritorio** | No | Aplicación nativa con General Assist, Quick Assist, Extreme Assist y sincronización de carpetas locales |
|
||||
| **Extensión de Navegador** | No | Extensión multi-navegador para guardar cualquier página web, incluyendo páginas protegidas por autenticación |
|
||||
|
||||
<details>
|
||||
|
|
|
|||
62 README.hi.md
@ -41,18 +41,14 @@ NotebookLM वहाँ उपलब्ध सबसे अच्छे और
|
|||
- **कोई विक्रेता लॉक-इन नहीं** - किसी भी LLM, इमेज, TTS और STT मॉडल को कॉन्फ़िगर करें।
|
||||
- **25+ बाहरी डेटा स्रोत** - Google Drive, OneDrive, Dropbox, Notion और कई अन्य बाहरी सेवाओं से अपने स्रोत जोड़ें।
|
||||
- **रीयल-टाइम मल्टीप्लेयर सपोर्ट** - एक साझा notebook में अपनी टीम के सदस्यों के साथ आसानी से काम करें।
|
||||
- **डेस्कटॉप ऐप** - Quick Assist, General Assist, Extreme Assist और लोकल फ़ोल्डर सिंक के साथ किसी भी एप्लिकेशन में AI सहायता प्राप्त करें।
|
||||
|
||||
...और भी बहुत कुछ आने वाला है।
|
||||
|
||||
|
||||
|
||||
# डेमो
|
||||
|
||||
https://github.com/user-attachments/assets/cc0c84d3-1f2f-4f7a-b519-2ecce22310b1
|
||||
|
||||
## वीडियो एजेंट नमूना
|
||||
|
||||
|
||||
https://github.com/user-attachments/assets/012a7ffa-6f76-4f06-9dda-7632b470057a
|
||||
|
||||
|
||||
|
|
@ -68,42 +64,58 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
|||
|
||||
1. [surfsense.com](https://www.surfsense.com) पर जाएं और लॉगिन करें।
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/b4df25fe-db5a-43c2-9462-b75cf7f1b707" alt="लॉगिन" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/LoginFlowGif.gif" alt="लॉगिन" /></p>
|
||||
|
||||
2. अपने कनेक्टर जोड़ें और सिंक करें। कनेक्टर्स को अपडेट रखने के लिए आवधिक सिंकिंग सक्षम करें।
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/0740f351-23fa-4909-9880-70aa1dcc1df7" alt="कनेक्टर्स" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ConnectorFlowGif.gif" alt="कनेक्टर्स" /></p>
|
||||
|
||||
3. जब तक कनेक्टर्स का डेटा इंडेक्स हो रहा है, दस्तावेज़ अपलोड करें।
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/daf3dbae-ef86-4e86-82ea-fcbcad988761" alt="दस्तावेज़ अपलोड करें" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/DocUploadGif.gif" alt="दस्तावेज़ अपलोड करें" /></p>
|
||||
|
||||
4. सब कुछ इंडेक्स हो जाने के बाद, कुछ भी पूछें (उपयोग के मामले):
|
||||
|
||||
- डेस्कटॉप ऐप — General Assist
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/general_assist.gif" alt="General Assist" /></p>
|
||||
|
||||
- डेस्कटॉप ऐप — Quick Assist
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>
|
||||
|
||||
- डेस्कटॉप ऐप — Extreme Assist
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/extreme_assist.gif" alt="Extreme Assist" /></p>
|
||||
|
||||
- डेस्कटॉप ऐप — Watch Local Folder
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/folder_watch.gif" alt="Watch Local Folder" /></p>
|
||||
|
||||
- वीडियो जनरेशन
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/af85c0f3-6cfd-4757-9706-07fd5e32c857" alt="वीडियो जनरेशन" /></p>
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/video_gen_gif.gif" alt="वीडियो जनरेशन" /></p>
|
||||
|
||||
- बेसिक सर्च और उद्धरण
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/81e797a1-e01a-4003-8e60-0a0b3a9789df" alt="सर्च और उद्धरण" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BSNCGif.gif" alt="सर्च और उद्धरण" /></p>
|
||||
|
||||
- दस्तावेज़ मेंशन QNA
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/65c3bf06-1d46-4dd5-b169-4d934c9b6798" alt="दस्तावेज़ मेंशन QNA" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/be958295-0a8c-4707-998c-9fe1f1c007be" alt="दस्तावेज़ मेंशन QNA" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BQnaGif_compressed.gif" alt="दस्तावेज़ मेंशन QNA" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BQnaGif_compressed.gif" alt="दस्तावेज़ मेंशन QNA" /></p>
|
||||
|
||||
- रिपोर्ट जनरेशन और एक्सपोर्ट (PDF, DOCX, HTML, LaTeX, EPUB, ODT, सादा टेक्स्ट)
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/9836b7d6-57c9-4951-b61c-68202c9b6ace" alt="रिपोर्ट जनरेशन" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ReportGenGif_compressed.gif" alt="रिपोर्ट जनरेशन" /></p>
|
||||
|
||||
- पॉडकास्ट जनरेशन
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/58c9b057-8848-4e81-aaba-d2c617985d8c" alt="पॉडकास्ट जनरेशन" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/PodcastGenGif.gif" alt="पॉडकास्ट जनरेशन" /></p>
|
||||
|
||||
- इमेज जनरेशन
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/25f94cb3-18f8-4854-afd9-27b7bfd079cb" alt="इमेज जनरेशन" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ImageGenGif.gif" alt="इमेज जनरेशन" /></p>
|
||||
|
||||
- और भी बहुत कुछ जल्द आ रहा है।
|
||||
|
||||
|
|
@ -130,6 +142,19 @@ irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/in
|
|||
|
||||
Docker Compose, मैनुअल इंस्टॉलेशन और अन्य डिप्लॉयमेंट विकल्पों के लिए, [डॉक्स](https://www.surfsense.com/docs/) देखें।
|
||||
|
||||
### डेस्कटॉप ऐप
|
||||
|
||||
SurfSense एक डेस्कटॉप ऐप भी प्रदान करता है जो आपके कंप्यूटर पर हर एप्लिकेशन में AI सहायता लाता है। इसे [नवीनतम रिलीज़](https://github.com/MODSetter/SurfSense/releases/latest) से डाउनलोड करें।
|
||||
|
||||
डेस्कटॉप ऐप में ये शक्तिशाली सुविधाएं शामिल हैं:
|
||||
|
||||
- **General Assist** — एक ग्लोबल शॉर्टकट से किसी भी एप्लिकेशन से तुरंत SurfSense लॉन्च करें।
|
||||
- **Quick Assist** — कहीं भी टेक्स्ट चुनें, फिर AI से समझाने, फिर से लिखने या उस पर कार्रवाई करने को कहें।
|
||||
- **Extreme Assist** — किसी भी ऐप में टाइप करते समय अपनी नॉलेज बेस से संचालित इनलाइन लेखन सुझाव प्राप्त करें।
|
||||
- **Watch Local Folder** — एक लोकल फ़ोल्डर को वॉच करें और फ़ाइल परिवर्तनों को स्वचालित रूप से अपनी नॉलेज बेस में सिंक करें। **Pro tip:** इसे अपने Obsidian vault पर पॉइंट करें ताकि आपके नोट्स SurfSense में सर्च करने योग्य रहें।
|
||||
|
||||
सभी सुविधाएं आपके चुने हुए सर्च स्पेस पर काम करती हैं, ताकि आपके उत्तर हमेशा आपके अपने डेटा पर आधारित हों।
|
||||
|
||||
### रीयल-टाइम सहयोग कैसे करें (बीटा)
|
||||
|
||||
1. सदस्य प्रबंधन पेज पर जाएं और एक आमंत्रण बनाएं।
|
||||
|
|
@ -146,11 +171,11 @@ Docker Compose, मैनुअल इंस्टॉलेशन और अन
|
|||
|
||||
4. आपकी टीम अब रीयल-टाइम में चैट कर सकती है।
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/83803ac2-fbce-4d93-aae3-85eb85a3053a" alt="रीयल-टाइम चैट" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_realtime/RealTimeChatGif.gif" alt="रीयल-टाइम चैट" /></p>
|
||||
|
||||
5. टीममेट्स को टैग करने के लिए कमेंट जोड़ें।
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/3b04477d-8f42-4baa-be95-867c1eaeba87" alt="रीयल-टाइम कमेंट्स" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_realtime/RealTimeCommentsFlow.gif" alt="रीयल-टाइम कमेंट्स" /></p>
|
||||
|
||||
## SurfSense vs Google NotebookLM
|
||||
|
||||
|
|
@ -174,6 +199,7 @@ Docker Compose, मैनुअल इंस्टॉलेशन और अन
|
|||
| **वीडियो जनरेशन** | Veo 3 के माध्यम से सिनेमैटिक वीडियो ओवरव्यू (केवल Ultra) | उपलब्ध (NotebookLM यहाँ बेहतर है, सक्रिय रूप से सुधार हो रहा है) |
|
||||
| **प्रेजेंटेशन जनरेशन** | बेहतर दिखने वाली स्लाइड्स लेकिन संपादन योग्य नहीं | संपादन योग्य, स्लाइड आधारित प्रेजेंटेशन बनाएं |
|
||||
| **पॉडकास्ट जनरेशन** | कस्टमाइज़ेबल होस्ट और भाषाओं के साथ ऑडियो ओवरव्यू | कई TTS प्रदाताओं के साथ उपलब्ध (NotebookLM यहाँ बेहतर है, सक्रिय रूप से सुधार हो रहा है) |
|
||||
| **डेस्कटॉप ऐप** | नहीं | General Assist, Quick Assist, Extreme Assist और लोकल फ़ोल्डर सिंक के साथ नेटिव ऐप |
|
||||
| **ब्राउज़र एक्सटेंशन** | नहीं | किसी भी वेबपेज को सहेजने के लिए क्रॉस-ब्राउज़र एक्सटेंशन, प्रमाणीकरण सुरक्षित पेज सहित |
|
||||
|
||||
<details>
|
||||
|
|
|
|||
63 README.md
@@ -41,19 +41,14 @@ NotebookLM is one of the best and most useful AI platforms out there, but once y
- **No Vendor Lock-in** - Configure any LLM, image, TTS, and STT models to use.
- **25+ External Data Sources** - Add your sources from Google Drive, OneDrive, Dropbox, Notion, and many other external services.
- **Real-Time Multiplayer Support** - Work easily with your team members in a shared notebook.
- **Desktop App** - Get AI assistance in any application with Quick Assist, General Assist, Extreme Assist, and local folder sync.

...and more to come.

# Demo

https://github.com/user-attachments/assets/cc0c84d3-1f2f-4f7a-b519-2ecce22310b1

## Video Agent Sample

https://github.com/user-attachments/assets/012a7ffa-6f76-4f06-9dda-7632b470057a

@@ -69,42 +64,58 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7

1. Go to [surfsense.com](https://www.surfsense.com) and login.

<p align="center"><img src="https://github.com/user-attachments/assets/b4df25fe-db5a-43c2-9462-b75cf7f1b707" alt="Login" /></p>
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/LoginFlowGif.gif" alt="Login" /></p>

2. Connect your connectors and sync. Enable periodic syncing to keep connectors synced.

<p align="center"><img src="https://github.com/user-attachments/assets/0740f351-23fa-4909-9880-70aa1dcc1df7" alt="Connectors" /></p>
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ConnectorFlowGif.gif" alt="Connectors" /></p>

3. Till connectors data index, upload Documents.

<p align="center"><img src="https://github.com/user-attachments/assets/daf3dbae-ef86-4e86-82ea-fcbcad988761" alt="Upload Documents" /></p>
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/DocUploadGif.gif" alt="Upload Documents" /></p>

4. Once everything is indexed, Ask Away (Use Cases):

- Desktop App — General Assist

<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/general_assist.gif" alt="General Assist" /></p>

- Desktop App — Quick Assist

<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>

- Desktop App — Extreme Assist

<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/extreme_assist.gif" alt="Extreme Assist" /></p>

- Desktop App — Watch Local Folder

<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/folder_watch.gif" alt="Watch Local Folder" /></p>

- Video Generation

<p align="center"><img src="https://github.com/user-attachments/assets/af85c0f3-6cfd-4757-9706-07fd5e32c857" alt="Search and Citation" /></p>
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/video_gen_gif.gif" alt="Video Generation" /></p>

- Basic search and citation

<p align="center"><img src="https://github.com/user-attachments/assets/81e797a1-e01a-4003-8e60-0a0b3a9789df" alt="Search and Citation" /></p>
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BSNCGif.gif" alt="Search and Citation" /></p>

- Document Mention QNA

<p align="center"><img src="https://github.com/user-attachments/assets/65c3bf06-1d46-4dd5-b169-4d934c9b6798" alt="Document Mention QNA" /></p>
<p align="center"><img src="https://github.com/user-attachments/assets/be958295-0a8c-4707-998c-9fe1f1c007be" alt="Document Mention QNA" /></p>
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BQnaGif_compressed.gif" alt="Document Mention QNA" /></p>
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BQnaGif_compressed.gif" alt="Document Mention QNA" /></p>

- Report Generations and Exports (PDF, DOCX, HTML, LaTeX, EPUB, ODT, Plain Text)

<p align="center"><img src="https://github.com/user-attachments/assets/9836b7d6-57c9-4951-b61c-68202c9b6ace" alt="Report Generation" /></p>
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ReportGenGif_compressed.gif" alt="Report Generation" /></p>

- Podcast Generations

<p align="center"><img src="https://github.com/user-attachments/assets/58c9b057-8848-4e81-aaba-d2c617985d8c" alt="Podcast Generation" /></p>
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/PodcastGenGif.gif" alt="Podcast Generation" /></p>

- Image Generations

<p align="center"><img src="https://github.com/user-attachments/assets/25f94cb3-18f8-4854-afd9-27b7bfd079cb" alt="Image Generation" /></p>
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ImageGenGif.gif" alt="Image Generation" /></p>

- And more coming soon.

@@ -131,6 +142,19 @@ The install script sets up [Watchtower](https://github.com/nicholas-fedor/watcht

For Docker Compose, manual installation, and other deployment options, see the [docs](https://www.surfsense.com/docs/).

### Desktop App

SurfSense also ships a desktop app that brings AI assistance to every application on your computer. Download it from the [latest release](https://github.com/MODSetter/SurfSense/releases/latest).

The desktop app includes these powerful features:

- **General Assist** — Launch SurfSense instantly from any application with a global shortcut.
- **Quick Assist** — Select text anywhere, then ask AI to explain, rewrite, or act on it.
- **Extreme Assist** — Get inline writing suggestions powered by your knowledge base as you type in any app.
- **Watch Local Folder** — Watch a local folder and automatically sync file changes to your knowledge base. **Pro tip:** Point it at your Obsidian vault to keep your notes searchable in SurfSense.

All features operate against your chosen search space, so your answers are always grounded in your own data.

### How to Realtime Collaborate (Beta)

1. Go to Manage Members page and create an invite.

@@ -147,11 +171,11 @@ For Docker Compose, manual installation, and other deployment options, see the [

4. Your team can now chat in realtime.

<p align="center"><img src="https://github.com/user-attachments/assets/83803ac2-fbce-4d93-aae3-85eb85a3053a" alt="Realtime Chat" /></p>
<p align="center"><img src="surfsense_web/public/homepage/hero_realtime/RealTimeChatGif.gif" alt="Realtime Chat" /></p>

5. Add comment to tag teammates.

<p align="center"><img src="https://github.com/user-attachments/assets/3b04477d-8f42-4baa-be95-867c1eaeba87" alt="Realtime Comments" /></p>
<p align="center"><img src="surfsense_web/public/homepage/hero_realtime/RealTimeCommentsFlow.gif" alt="Realtime Comments" /></p>

## SurfSense vs Google NotebookLM

@@ -175,6 +199,7 @@ For Docker Compose, manual installation, and other deployment options, see the [
| **Video Generation** | Cinematic Video Overviews via Veo 3 (Ultra only) | Available (NotebookLM is better here, actively improving) |
| **Presentation Generation** | Better looking slides but not editable | Create editable, slide-based presentations |
| **Podcast Generation** | Audio Overviews with customizable hosts and languages | Available with multiple TTS providers (NotebookLM is better here, actively improving) |
| **Desktop App** | No | Native app with General Assist, Quick Assist, Extreme Assist, and local folder sync |
| **Browser Extension** | No | Cross-browser extension to save any webpage, including auth-protected pages |

<details>
|
|||
|
|
@ -41,18 +41,14 @@ O NotebookLM é uma das melhores e mais úteis plataformas de IA disponíveis, m
|
|||
- **Sem Dependência de Fornecedor** - Configure qualquer modelo LLM, de imagem, TTS e STT.
|
||||
- **25+ Fontes de Dados Externas** - Adicione suas fontes do Google Drive, OneDrive, Dropbox, Notion e muitos outros serviços externos.
|
||||
- **Suporte Multiplayer em Tempo Real** - Trabalhe facilmente com os membros da sua equipe em um notebook compartilhado.
|
||||
- **Aplicativo Desktop** - Obtenha assistência de IA em qualquer aplicativo com Quick Assist, General Assist, Extreme Assist e sincronização de pastas locais.
|
||||
|
||||
...e mais por vir.
|
||||
|
||||
|
||||
|
||||
# Demo
|
||||
|
||||
https://github.com/user-attachments/assets/cc0c84d3-1f2f-4f7a-b519-2ecce22310b1
|
||||
|
||||
## Exemplo de Agente de Vídeo
|
||||
|
||||
|
||||
https://github.com/user-attachments/assets/012a7ffa-6f76-4f06-9dda-7632b470057a
|
||||
|
||||
|
||||
|
|
@ -68,42 +64,58 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
|||
|
||||
1. Acesse [surfsense.com](https://www.surfsense.com) e faça login.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/b4df25fe-db5a-43c2-9462-b75cf7f1b707" alt="Login" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/LoginFlowGif.gif" alt="Login" /></p>
|
||||
|
||||
2. Conecte seus conectores e sincronize. Ative a sincronização periódica para manter os conectores atualizados.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/0740f351-23fa-4909-9880-70aa1dcc1df7" alt="Conectores" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ConnectorFlowGif.gif" alt="Conectores" /></p>
|
||||
|
||||
3. Enquanto os dados dos conectores são indexados, faça upload de documentos.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/daf3dbae-ef86-4e86-82ea-fcbcad988761" alt="Upload de Documentos" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/DocUploadGif.gif" alt="Upload de Documentos" /></p>
|
||||
|
||||
4. Quando tudo estiver indexado, pergunte o que quiser (Casos de uso):
|
||||
|
||||
- Aplicativo Desktop — General Assist
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/general_assist.gif" alt="General Assist" /></p>
|
||||
|
||||
- Aplicativo Desktop — Quick Assist
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>
|
||||
|
||||
- Aplicativo Desktop — Extreme Assist
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/extreme_assist.gif" alt="Extreme Assist" /></p>
|
||||
|
||||
- Aplicativo Desktop — Watch Local Folder
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/folder_watch.gif" alt="Watch Local Folder" /></p>
|
||||
|
||||
- Geração de vídeos
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/af85c0f3-6cfd-4757-9706-07fd5e32c857" alt="Geração de Vídeos" /></p>
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/video_gen_gif.gif" alt="Geração de Vídeos" /></p>
|
||||
|
||||
- Busca básica e citações
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/81e797a1-e01a-4003-8e60-0a0b3a9789df" alt="Busca e Citação" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BSNCGif.gif" alt="Busca e Citação" /></p>
|
||||
|
||||
- QNA com menção de documentos
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/65c3bf06-1d46-4dd5-b169-4d934c9b6798" alt="QNA com Menção de Documentos" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/be958295-0a8c-4707-998c-9fe1f1c007be" alt="QNA com Menção de Documentos" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BQnaGif_compressed.gif" alt="QNA com Menção de Documentos" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BQnaGif_compressed.gif" alt="QNA com Menção de Documentos" /></p>
|
||||
|
||||
- Geração de relatórios e exportações (PDF, DOCX, HTML, LaTeX, EPUB, ODT, texto simples)
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/9836b7d6-57c9-4951-b61c-68202c9b6ace" alt="Geração de Relatórios" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ReportGenGif_compressed.gif" alt="Geração de Relatórios" /></p>
|
||||
|
||||
- Geração de podcasts
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/58c9b057-8848-4e81-aaba-d2c617985d8c" alt="Geração de Podcasts" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/PodcastGenGif.gif" alt="Geração de Podcasts" /></p>
|
||||
|
||||
- Geração de imagens
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/25f94cb3-18f8-4854-afd9-27b7bfd079cb" alt="Geração de Imagens" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ImageGenGif.gif" alt="Geração de Imagens" /></p>
|
||||
|
||||
- E mais em breve.
|
||||
|
||||
|
|
@ -130,6 +142,19 @@ O script de instalação configura o [Watchtower](https://github.com/nicholas-fe
|
|||
|
||||
Para Docker Compose, instalação manual e outras opções de implantação, consulte a [documentação](https://www.surfsense.com/docs/).
|
||||
|
||||
### Aplicativo Desktop
|
||||
|
||||
O SurfSense também oferece um aplicativo desktop que traz assistência de IA para cada aplicativo no seu computador. Baixe-o na [última versão](https://github.com/MODSetter/SurfSense/releases/latest).
|
||||
|
||||
O aplicativo desktop inclui estes recursos poderosos:
|
||||
|
||||
- **General Assist** — Abra o SurfSense instantaneamente de qualquer aplicativo com um atalho global.
|
||||
- **Quick Assist** — Selecione texto em qualquer lugar, depois peça à IA para explicar, reescrever ou agir sobre ele.
|
||||
- **Extreme Assist** — Receba sugestões de escrita em linha alimentadas pela sua base de conhecimento enquanto digita em qualquer aplicativo.
|
||||
- **Watch Local Folder** — Monitore uma pasta local e sincronize automaticamente as alterações de arquivos com sua base de conhecimento. **Pro tip:** Aponte para seu cofre do Obsidian para manter suas notas pesquisáveis no SurfSense.
|
||||
|
||||
Todos os recursos operam no espaço de busca escolhido, para que suas respostas sejam sempre baseadas nos seus próprios dados.
|
||||
|
||||
### Como Colaborar em Tempo Real (Beta)
|
||||
|
||||
1. Acesse a página de Gerenciar Membros e crie um convite.
|
||||
|
|
@ -146,11 +171,11 @@ Para Docker Compose, instalação manual e outras opções de implantação, con
|
|||
|
||||
4. Sua equipe agora pode conversar em tempo real.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/83803ac2-fbce-4d93-aae3-85eb85a3053a" alt="Chat em Tempo Real" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_realtime/RealTimeChatGif.gif" alt="Chat em Tempo Real" /></p>
|
||||
|
||||
5. Adicione comentários para marcar colegas de equipe.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/3b04477d-8f42-4baa-be95-867c1eaeba87" alt="Comentários em Tempo Real" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_realtime/RealTimeCommentsFlow.gif" alt="Comentários em Tempo Real" /></p>
|
||||
|
||||
## SurfSense vs Google NotebookLM
|
||||
|
||||
|
|
@ -174,6 +199,7 @@ Para Docker Compose, instalação manual e outras opções de implantação, con
|
|||
| **Geração de Vídeos** | Visões gerais cinemáticas via Veo 3 (apenas Ultra) | Disponível (NotebookLM é melhor aqui, melhorando ativamente) |
|
||||
| **Geração de Apresentações** | Slides mais bonitos mas não editáveis | Cria apresentações editáveis baseadas em slides |
|
||||
| **Geração de Podcasts** | Visões gerais em áudio com hosts e idiomas personalizáveis | Disponível com múltiplos provedores TTS (NotebookLM é melhor aqui, melhorando ativamente) |
|
||||
| **Aplicativo Desktop** | Não | Aplicativo nativo com General Assist, Quick Assist, Extreme Assist e sincronização de pastas locais |
|
||||
| **Extensão de Navegador** | Não | Extensão multi-navegador para salvar qualquer página web, incluindo páginas protegidas por autenticação |
|
||||
|
||||
<details>
|
||||
|
|
|
|||
|
|
@ -41,18 +41,14 @@ NotebookLM 是目前最好、最实用的 AI 平台之一,但当你开始经
|
|||
- **无供应商锁定** - 配置任何 LLM、图像、TTS 和 STT 模型。
|
||||
- **25+ 外部数据源** - 从 Google Drive、OneDrive、Dropbox、Notion 和许多其他外部服务添加你的来源。
|
||||
- **实时多人协作支持** - 在共享笔记本中轻松与团队成员协作。
|
||||
- **桌面应用** - 通过 Quick Assist、General Assist、Extreme Assist 和本地文件夹同步在任何应用程序中获得 AI 助手。
|
||||
|
||||
...更多功能即将推出。
|
||||
|
||||
|
||||
|
||||
# 演示
|
||||
|
||||
https://github.com/user-attachments/assets/cc0c84d3-1f2f-4f7a-b519-2ecce22310b1
|
||||
|
||||
## 视频代理示例
|
||||
|
||||
|
||||
https://github.com/user-attachments/assets/012a7ffa-6f76-4f06-9dda-7632b470057a
|
||||
|
||||
|
||||
|
|
@ -68,42 +64,58 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
|||
|
||||
1. 访问 [surfsense.com](https://www.surfsense.com) 并登录。
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/b4df25fe-db5a-43c2-9462-b75cf7f1b707" alt="登录" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/LoginFlowGif.gif" alt="登录" /></p>
|
||||
|
||||
2. 连接您的连接器并同步。启用定期同步以保持连接器数据更新。
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/0740f351-23fa-4909-9880-70aa1dcc1df7" alt="连接器" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ConnectorFlowGif.gif" alt="连接器" /></p>
|
||||
|
||||
3. 在连接器数据索引期间,上传文档。
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/daf3dbae-ef86-4e86-82ea-fcbcad988761" alt="上传文档" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/DocUploadGif.gif" alt="上传文档" /></p>
|
||||
|
||||
4. 一切索引完成后,尽管提问(使用场景):
|
||||
|
||||
- 桌面应用 — General Assist
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/general_assist.gif" alt="General Assist" /></p>
|
||||
|
||||
- 桌面应用 — Quick Assist
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>
|
||||
|
||||
- 桌面应用 — Extreme Assist
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/extreme_assist.gif" alt="Extreme Assist" /></p>
|
||||
|
||||
- 桌面应用 — Watch Local Folder
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/folder_watch.gif" alt="Watch Local Folder" /></p>
|
||||
|
||||
- 视频生成
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/af85c0f3-6cfd-4757-9706-07fd5e32c857" alt="视频生成" /></p>
|
||||
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/video_gen_gif.gif" alt="视频生成" /></p>
|
||||
|
||||
- 基本搜索和引用
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/81e797a1-e01a-4003-8e60-0a0b3a9789df" alt="搜索和引用" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BSNCGif.gif" alt="搜索和引用" /></p>
|
||||
|
||||
- 文档提及问答
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/65c3bf06-1d46-4dd5-b169-4d934c9b6798" alt="文档提及问答" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/be958295-0a8c-4707-998c-9fe1f1c007be" alt="文档提及问答" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BQnaGif_compressed.gif" alt="文档提及问答" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/BQnaGif_compressed.gif" alt="文档提及问答" /></p>
|
||||
|
||||
- 报告生成和导出(PDF、DOCX、HTML、LaTeX、EPUB、ODT、纯文本)
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/9836b7d6-57c9-4951-b61c-68202c9b6ace" alt="报告生成" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ReportGenGif_compressed.gif" alt="报告生成" /></p>
|
||||
|
||||
- 播客生成
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/58c9b057-8848-4e81-aaba-d2c617985d8c" alt="播客生成" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/PodcastGenGif.gif" alt="播客生成" /></p>
|
||||
|
||||
- 图像生成
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/25f94cb3-18f8-4854-afd9-27b7bfd079cb" alt="图像生成" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/ImageGenGif.gif" alt="图像生成" /></p>
|
||||
|
||||
- 更多功能即将推出。
|
||||
|
||||
|
|
@ -130,6 +142,19 @@ irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/in
|
|||
|
||||
如需 Docker Compose、手动安装及其他部署方式,请查看[文档](https://www.surfsense.com/docs/)。
|
||||
|
||||
### 桌面应用
|
||||
|
||||
SurfSense 还提供桌面应用,将 AI 助手带到您计算机上的每个应用程序中。从[最新版本](https://github.com/MODSetter/SurfSense/releases/latest)下载。
|
||||
|
||||
桌面应用包含以下强大功能:
|
||||
|
||||
- **General Assist** — 通过全局快捷键从任何应用程序即时启动 SurfSense。
|
||||
- **Quick Assist** — 在任何位置选中文本,然后让 AI 解释、改写或对其执行操作。
|
||||
- **Extreme Assist** — 在任何应用中输入时,获得基于您知识库的内联写作建议。
|
||||
- **Watch Local Folder** — 监视本地文件夹,自动将文件更改同步到您的知识库。**Pro tip:** 将其指向您的 Obsidian vault,让笔记在 SurfSense 中随时可搜索。
|
||||
|
||||
所有功能均基于您选择的搜索空间运行,确保回答始终以您自己的数据为依据。
|
||||
|
||||
### 如何实时协作(Beta)
|
||||
|
||||
1. 前往成员管理页面并创建邀请。
|
||||
|
|
@ -146,11 +171,11 @@ irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/in
|
|||
|
||||
4. 您的团队现在可以实时聊天。
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/83803ac2-fbce-4d93-aae3-85eb85a3053a" alt="实时聊天" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_realtime/RealTimeChatGif.gif" alt="实时聊天" /></p>
|
||||
|
||||
5. 添加评论以标记队友。
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/3b04477d-8f42-4baa-be95-867c1eaeba87" alt="实时评论" /></p>
|
||||
<p align="center"><img src="surfsense_web/public/homepage/hero_realtime/RealTimeCommentsFlow.gif" alt="实时评论" /></p>
|
||||
|
||||
## SurfSense vs Google NotebookLM
|
||||
|
||||
|
|
@ -174,6 +199,7 @@ irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/in
|
|||
| **视频生成** | 通过 Veo 3 的电影级视频概览(仅 Ultra) | 可用(NotebookLM 在此方面更好,正在积极改进) |
|
||||
| **演示文稿生成** | 更美观的幻灯片但不可编辑 | 创建可编辑的幻灯片式演示文稿 |
|
||||
| **播客生成** | 可自定义主持人和语言的音频概览 | 可用,支持多种 TTS 提供商(NotebookLM 在此方面更好,正在积极改进) |
|
||||
| **桌面应用** | 否 | 原生应用,包含 General Assist、Quick Assist、Extreme Assist 和本地文件夹同步 |
|
||||
| **浏览器扩展** | 否 | 跨浏览器扩展,保存任何网页,包括需要身份验证的页面 |
|
||||
|
||||
<details>
|
||||
|
|
|
|||
|
|
@@ -282,6 +282,9 @@ STT_SERVICE=local/base

# LlamaCloud (if ETL_SERVICE=LLAMACLOUD)
# LLAMA_CLOUD_API_KEY=
# Optional: Azure Document Intelligence accelerator (used with LLAMACLOUD)
# AZURE_DI_ENDPOINT=https://your-resource.cognitiveservices.azure.com/
# AZURE_DI_KEY=

# ------------------------------------------------------------------------------
# Observability (optional)
@@ -24,7 +24,7 @@ SurfSense 现已支持以下国产 LLM:

1. 登录 SurfSense Dashboard
2. 进入 **Settings** → **API Keys** (或 **LLM Configurations**)
3. 点击 **Add LLM Model**
3. 点击 **Add Model**
4. 从 **Provider** 下拉菜单中选择你的国产 LLM 提供商
5. 填写必填字段(见下方各提供商详细配置)
6. 点击 **Save**
6 package-lock.json (generated, new file)
@@ -0,0 +1,6 @@
{
  "name": "SurfSense",
  "lockfileVersion": 3,
  "requires": true,
  "packages": {}
}
5 package.json (new file)
@@ -0,0 +1,5 @@
{
  "name": "surfsense",
  "private": true,
  "packageManager": "pnpm@10.24.0"
}
@@ -193,6 +193,9 @@ FIRECRAWL_API_KEY=fcr-01J0000000000000000000000
ETL_SERVICE=UNSTRUCTURED or LLAMACLOUD or DOCLING
UNSTRUCTURED_API_KEY=Tpu3P0U8iy
LLAMA_CLOUD_API_KEY=llx-nnn
# Optional: Azure Document Intelligence accelerator (used when ETL_SERVICE=LLAMACLOUD)
# AZURE_DI_ENDPOINT=https://your-resource.cognitiveservices.azure.com/
# AZURE_DI_KEY=your-key

# OPTIONAL: Add these for LangSmith Observability
LANGSMITH_TRACING=true
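The Azure Document Intelligence variables added here are optional and only meaningful when `ETL_SERVICE=LLAMACLOUD`. A hedged sketch of how such optional settings are typically gated at startup; the helper below is hypothetical and not taken from the SurfSense codebase:

```python
import os


def azure_di_settings() -> tuple[str, str] | None:
    """Return (endpoint, key) when the optional accelerator is fully configured."""
    if os.getenv("ETL_SERVICE") != "LLAMACLOUD":
        return None
    endpoint = os.getenv("AZURE_DI_ENDPOINT")
    key = os.getenv("AZURE_DI_KEY")
    if endpoint and key:
        return endpoint, key
    return None  # fall back to plain LlamaCloud parsing when either value is missing
```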
@@ -0,0 +1,149 @@
"""Add LOCAL_FOLDER_FILE document type, folder metadata, and document_versions table

Revision ID: 118
Revises: 117
"""

from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

revision: str = "118"
down_revision: str | None = "117"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

PUBLICATION_NAME = "zero_publication"


def upgrade() -> None:
    conn = op.get_bind()

    # Add LOCAL_FOLDER_FILE to documenttype enum
    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM pg_type t
                JOIN pg_enum e ON t.oid = e.enumtypid
                WHERE t.typname = 'documenttype' AND e.enumlabel = 'LOCAL_FOLDER_FILE'
            ) THEN
                ALTER TYPE documenttype ADD VALUE 'LOCAL_FOLDER_FILE';
            END IF;
        END
        $$;
        """
    )

    # Add JSONB metadata column to folders table
    col_exists = conn.execute(
        sa.text(
            "SELECT 1 FROM information_schema.columns "
            "WHERE table_name = 'folders' AND column_name = 'metadata'"
        )
    ).fetchone()
    if not col_exists:
        op.add_column(
            "folders",
            sa.Column("metadata", sa.dialects.postgresql.JSONB, nullable=True),
        )

    # Create document_versions table
    table_exists = conn.execute(
        sa.text(
            "SELECT 1 FROM information_schema.tables WHERE table_name = 'document_versions'"
        )
    ).fetchone()
    if not table_exists:
        op.create_table(
            "document_versions",
            sa.Column("id", sa.Integer(), nullable=False, autoincrement=True),
            sa.Column("document_id", sa.Integer(), nullable=False),
            sa.Column("version_number", sa.Integer(), nullable=False),
            sa.Column("source_markdown", sa.Text(), nullable=True),
            sa.Column("content_hash", sa.String(), nullable=False),
            sa.Column("title", sa.String(), nullable=True),
            sa.Column(
                "created_at",
                sa.TIMESTAMP(timezone=True),
                server_default=sa.text("now()"),
                nullable=False,
            ),
            sa.ForeignKeyConstraint(
                ["document_id"],
                ["documents.id"],
                ondelete="CASCADE",
            ),
            sa.PrimaryKeyConstraint("id"),
            sa.UniqueConstraint(
                "document_id",
                "version_number",
                name="uq_document_version",
            ),
        )

    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_document_versions_document_id "
        "ON document_versions (document_id)"
    )
    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_document_versions_created_at "
        "ON document_versions (created_at)"
    )

    # Add document_versions to Zero publication
    pub_exists = conn.execute(
        sa.text("SELECT 1 FROM pg_publication WHERE pubname = :name"),
        {"name": PUBLICATION_NAME},
    ).fetchone()
    if pub_exists:
        already_in_pub = conn.execute(
            sa.text(
                "SELECT 1 FROM pg_publication_tables "
                "WHERE pubname = :name AND tablename = 'document_versions'"
            ),
            {"name": PUBLICATION_NAME},
        ).fetchone()
        if not already_in_pub:
            op.execute(
                f"ALTER PUBLICATION {PUBLICATION_NAME} ADD TABLE document_versions"
            )


def downgrade() -> None:
    conn = op.get_bind()

    # Remove from publication
    pub_exists = conn.execute(
        sa.text("SELECT 1 FROM pg_publication WHERE pubname = :name"),
        {"name": PUBLICATION_NAME},
    ).fetchone()
    if pub_exists:
        already_in_pub = conn.execute(
            sa.text(
                "SELECT 1 FROM pg_publication_tables "
                "WHERE pubname = :name AND tablename = 'document_versions'"
            ),
            {"name": PUBLICATION_NAME},
        ).fetchone()
        if already_in_pub:
            op.execute(
                f"ALTER PUBLICATION {PUBLICATION_NAME} DROP TABLE document_versions"
            )

    op.execute("DROP INDEX IF EXISTS ix_document_versions_created_at")
    op.execute("DROP INDEX IF EXISTS ix_document_versions_document_id")
    op.execute("DROP TABLE IF EXISTS document_versions")

    # Drop metadata column from folders
    col_exists = conn.execute(
        sa.text(
            "SELECT 1 FROM information_schema.columns "
            "WHERE table_name = 'folders' AND column_name = 'metadata'"
        )
    ).fetchone()
    if col_exists:
        op.drop_column("folders", "metadata")
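Because `document_versions` carries a `(document_id, version_number)` unique constraint, the latest revision of a document reduces to a single ordered lookup. A minimal SQLAlchemy Core sketch of such a query; the connection handling is assumed and this helper is not part of the migration or the repository:

```python
import sqlalchemy as sa

# Lightweight table clause matching the columns created by migration 118.
document_versions = sa.table(
    "document_versions",
    sa.column("id"),
    sa.column("document_id"),
    sa.column("version_number"),
    sa.column("content_hash"),
    sa.column("created_at"),
)


def latest_version(conn, document_id: int):
    """Fetch the newest document_versions row for one document."""
    stmt = (
        sa.select(document_versions)
        .where(document_versions.c.document_id == document_id)
        .order_by(document_versions.c.version_number.desc())
        .limit(1)
    )
    return conn.execute(stmt).fetchone()
```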
@@ -0,0 +1,39 @@
"""119_add_vision_llm_id_to_search_spaces

Revision ID: 119
Revises: 118

Adds vision_llm_id column to search_spaces for vision/screenshot analysis
LLM role assignment. Defaults to 0 (Auto mode), same convention as
agent_llm_id and document_summary_llm_id.
"""

from __future__ import annotations

from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

revision: str = "119"
down_revision: str | None = "118"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    conn = op.get_bind()
    existing_columns = [
        col["name"] for col in sa.inspect(conn).get_columns("searchspaces")
    ]

    if "vision_llm_id" not in existing_columns:
        op.add_column(
            "searchspaces",
            sa.Column("vision_llm_id", sa.Integer(), nullable=True, server_default="0"),
        )


def downgrade() -> None:
    op.drop_column("searchspaces", "vision_llm_id")
@@ -0,0 +1,199 @@
"""Add vision LLM configs table and rename preference column

Revision ID: 120
Revises: 119

Changes:
1. Create visionprovider enum type
2. Create vision_llm_configs table
3. Rename vision_llm_id -> vision_llm_config_id on searchspaces
4. Add vision config permissions to existing system roles
"""

from __future__ import annotations

from collections.abc import Sequence

import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import ENUM as PG_ENUM, UUID

from alembic import op

revision: str = "120"
down_revision: str | None = "119"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

VISION_PROVIDER_VALUES = (
    "OPENAI",
    "ANTHROPIC",
    "GOOGLE",
    "AZURE_OPENAI",
    "VERTEX_AI",
    "BEDROCK",
    "XAI",
    "OPENROUTER",
    "OLLAMA",
    "GROQ",
    "TOGETHER_AI",
    "FIREWORKS_AI",
    "DEEPSEEK",
    "MISTRAL",
    "CUSTOM",
)


def upgrade() -> None:
    connection = op.get_bind()

    # 1. Create visionprovider enum
    connection.execute(
        sa.text(
            """
            DO $$
            BEGIN
                IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'visionprovider') THEN
                    CREATE TYPE visionprovider AS ENUM (
                        'OPENAI', 'ANTHROPIC', 'GOOGLE', 'AZURE_OPENAI', 'VERTEX_AI',
                        'BEDROCK', 'XAI', 'OPENROUTER', 'OLLAMA', 'GROQ',
                        'TOGETHER_AI', 'FIREWORKS_AI', 'DEEPSEEK', 'MISTRAL', 'CUSTOM'
                    );
                END IF;
            END
            $$;
            """
        )
    )

    # 2. Create vision_llm_configs table
    result = connection.execute(
        sa.text(
            "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'vision_llm_configs')"
        )
    )
    if not result.scalar():
        op.create_table(
            "vision_llm_configs",
            sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
            sa.Column("name", sa.String(100), nullable=False),
            sa.Column("description", sa.String(500), nullable=True),
            sa.Column(
                "provider",
                PG_ENUM(
                    *VISION_PROVIDER_VALUES, name="visionprovider", create_type=False
                ),
                nullable=False,
            ),
            sa.Column("custom_provider", sa.String(100), nullable=True),
            sa.Column("model_name", sa.String(100), nullable=False),
            sa.Column("api_key", sa.String(), nullable=False),
            sa.Column("api_base", sa.String(500), nullable=True),
            sa.Column("api_version", sa.String(50), nullable=True),
            sa.Column("litellm_params", sa.JSON(), nullable=True),
            sa.Column("search_space_id", sa.Integer(), nullable=False),
            sa.Column("user_id", UUID(as_uuid=True), nullable=False),
            sa.Column(
                "created_at",
                sa.TIMESTAMP(timezone=True),
                server_default=sa.text("now()"),
                nullable=False,
            ),
            sa.PrimaryKeyConstraint("id"),
            sa.ForeignKeyConstraint(
                ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
            ),
            sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        )

    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_vision_llm_configs_name "
        "ON vision_llm_configs (name)"
    )
    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_vision_llm_configs_search_space_id "
        "ON vision_llm_configs (search_space_id)"
    )

    # 3. Rename vision_llm_id -> vision_llm_config_id on searchspaces
    existing_columns = [
        col["name"] for col in sa.inspect(connection).get_columns("searchspaces")
    ]
    if (
        "vision_llm_id" in existing_columns
        and "vision_llm_config_id" not in existing_columns
    ):
        op.alter_column(
            "searchspaces", "vision_llm_id", new_column_name="vision_llm_config_id"
        )
    elif "vision_llm_config_id" not in existing_columns:
        op.add_column(
            "searchspaces",
            sa.Column(
                "vision_llm_config_id", sa.Integer(), nullable=True, server_default="0"
            ),
        )

    # 4. Add vision config permissions to existing system roles
    connection.execute(
        sa.text(
            """
            UPDATE search_space_roles
            SET permissions = array_cat(
                permissions,
                ARRAY['vision_configs:create', 'vision_configs:read']
            )
            WHERE is_system_role = true
              AND name = 'Editor'
              AND NOT ('vision_configs:create' = ANY(permissions))
            """
        )
    )
    connection.execute(
        sa.text(
            """
            UPDATE search_space_roles
            SET permissions = array_cat(
                permissions,
                ARRAY['vision_configs:read']
            )
            WHERE is_system_role = true
              AND name = 'Viewer'
              AND NOT ('vision_configs:read' = ANY(permissions))
            """
        )
    )


def downgrade() -> None:
    connection = op.get_bind()

    # Remove permissions
    connection.execute(
        sa.text(
            """
            UPDATE search_space_roles
            SET permissions = array_remove(
                array_remove(
                    array_remove(permissions, 'vision_configs:create'),
                    'vision_configs:read'
                ),
                'vision_configs:delete'
            )
            WHERE is_system_role = true
            """
        )
    )

    # Rename column back
    existing_columns = [
        col["name"] for col in sa.inspect(connection).get_columns("searchspaces")
    ]
    if "vision_llm_config_id" in existing_columns:
        op.alter_column(
            "searchspaces", "vision_llm_config_id", new_column_name="vision_llm_id"
        )

    # Drop table and enum
    op.execute("DROP INDEX IF EXISTS ix_vision_llm_configs_search_space_id")
    op.execute("DROP INDEX IF EXISTS ix_vision_llm_configs_name")
    op.execute("DROP TABLE IF EXISTS vision_llm_configs")
    op.execute("DROP TYPE IF EXISTS visionprovider")
@@ -17,10 +17,10 @@ depends_on: str | Sequence[str] | None = None
|
|||
|
||||
def upgrade() -> None:
|
||||
"""
|
||||
Add the new_llm_configs table that combines LLM model settings with prompt configuration.
|
||||
Add the new_llm_configs table that combines model settings with prompt configuration.
|
||||
|
||||
This table includes:
|
||||
- LLM model configuration (provider, model_name, api_key, etc.)
|
||||
- Model configuration (provider, model_name, api_key, etc.)
|
||||
- Configurable system instructions
|
||||
- Citation toggle
|
||||
"""
|
||||
|
|
@@ -41,7 +41,7 @@ def upgrade() -> None:
|
|||
name VARCHAR(100) NOT NULL,
|
||||
description VARCHAR(500),
|
||||
|
||||
-- LLM Model Configuration (same as llm_configs, excluding language)
|
||||
-- Model Configuration (same as llm_configs, excluding language)
|
||||
provider litellmprovider NOT NULL,
|
||||
custom_provider VARCHAR(100),
|
||||
model_name VARCHAR(100) NOT NULL,
|
||||
|
|
|
|||
11
surfsense_backend/app/agents/autocomplete/__init__.py
Normal file
|
|
@@ -0,0 +1,11 @@
|
|||
"""Agent-based vision autocomplete with scoped filesystem exploration."""
|
||||
|
||||
from app.agents.autocomplete.autocomplete_agent import (
|
||||
create_autocomplete_agent,
|
||||
stream_autocomplete_agent,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_autocomplete_agent",
|
||||
"stream_autocomplete_agent",
|
||||
]
|
||||
495
surfsense_backend/app/agents/autocomplete/autocomplete_agent.py
Normal file
|
|
@@ -0,0 +1,495 @@
|
|||
"""Vision autocomplete agent with scoped filesystem exploration.
|
||||
|
||||
Converts the stateless single-shot vision autocomplete into an agent that
|
||||
seeds a virtual filesystem from KB search results and lets the vision LLM
|
||||
explore documents via ``ls``, ``read_file``, ``glob``, ``grep``, etc.
|
||||
before generating the final completion.
|
||||
|
||||
Performance: KB search and agent graph compilation run in parallel so
|
||||
the only sequential latency is KB-search (or agent compile, whichever is
|
||||
slower) + the agent's LLM turns. There is no separate "query extraction"
|
||||
LLM call — the window title is used directly as the KB search query.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from deepagents.graph import BASE_AGENT_PROMPT
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from langchain.agents import create_agent
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
|
||||
from app.agents.new_chat.middleware.knowledge_search import (
|
||||
build_scoped_filesystem,
|
||||
search_knowledge_base,
|
||||
)
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KB_TOP_K = 10
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
AUTOCOMPLETE_SYSTEM_PROMPT = """You are a smart writing assistant that analyzes the user's screen to draft or complete text.
|
||||
|
||||
You will receive a screenshot of the user's screen. Your PRIMARY source of truth is the screenshot itself — the visual context determines what to write.
|
||||
|
||||
Your job:
|
||||
1. Analyze the ENTIRE screenshot to understand what the user is working on (email thread, chat conversation, document, code editor, form, etc.).
|
||||
2. Identify the text area where the user will type.
|
||||
3. Generate the text the user most likely wants to write based on the visual context.
|
||||
|
||||
You also have access to the user's knowledge base documents via filesystem tools. However:
|
||||
- ONLY consult the knowledge base if the screenshot clearly involves a topic where your KB documents are DIRECTLY relevant (e.g., the user is writing about a specific project/topic that matches a document title).
|
||||
- Do NOT explore documents just because they exist. Most autocomplete requests can be answered purely from the screenshot.
|
||||
- If you do read a document, only incorporate information that is 100% relevant to what the user is typing RIGHT NOW. Do not add extra details, background, or tangential information from the KB.
|
||||
- Keep your output SHORT — autocomplete should feel like a natural continuation, not an essay.
|
||||
|
||||
Key behavior:
|
||||
- If the text area is EMPTY, draft a concise response or message based on what you see on screen (e.g., reply to an email, respond to a chat message, continue a document).
|
||||
- If the text area already has text, continue it naturally — typically just a sentence or two.
|
||||
|
||||
Rules:
|
||||
- Be CONCISE. Prefer a single paragraph or a few sentences. Autocomplete is a quick assist, not a full draft.
|
||||
- Match the tone and formality of the surrounding context.
|
||||
- If the screen shows code, write code. If it shows a casual chat, be casual. If it shows a formal email, be formal.
|
||||
- Do NOT describe the screenshot or explain your reasoning.
|
||||
- Do NOT cite or reference documents explicitly — just let the knowledge inform your writing naturally.
|
||||
- If you cannot determine what to write, output an empty JSON array: []
|
||||
|
||||
## Output Format
|
||||
|
||||
You MUST provide exactly 3 different suggestion options. Each should be a distinct, plausible completion — vary the tone, detail level, or angle.
|
||||
|
||||
Return your suggestions as a JSON array of exactly 3 strings. Output ONLY the JSON array, nothing else — no markdown fences, no explanation, no commentary.
|
||||
|
||||
Example format:
|
||||
["First suggestion text here.", "Second suggestion — a different take.", "Third option with another approach."]
|
||||
|
||||
## Filesystem Tools `ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`
|
||||
|
||||
All file paths must start with a `/`.
|
||||
- ls: list files and directories at a given path.
|
||||
- read_file: read a file from the filesystem.
|
||||
- write_file: create a temporary file in the session (not persisted).
|
||||
- edit_file: edit a file in the session (not persisted for /documents/ files).
|
||||
- glob: find files matching a pattern (e.g., "**/*.xml").
|
||||
- grep: search for text within files.
|
||||
|
||||
## When to Use Filesystem Tools
|
||||
|
||||
BEFORE reaching for any tool, ask yourself: "Can I write a good completion purely from the screenshot?" If yes, just write it — do NOT explore the KB.
|
||||
|
||||
Only use tools when:
|
||||
- The user is clearly writing about a specific topic that likely has detailed information in their KB.
|
||||
- You need a specific fact, name, number, or reference that the screenshot doesn't provide.
|
||||
|
||||
When you do use tools, be surgical:
|
||||
- Check the `ls` output first. If no document title looks relevant, stop — do not read files just to see what's there.
|
||||
- If a title looks relevant, read only the `<chunk_index>` (first ~20 lines) and jump to matched chunks. Do not read entire documents.
|
||||
- Extract only the specific information you need and move on to generating the completion.
|
||||
|
||||
## Reading Documents Efficiently
|
||||
|
||||
Documents are formatted as XML. Each document contains:
|
||||
- `<document_metadata>` — title, type, URL, etc.
|
||||
- `<chunk_index>` — a table of every chunk with its **line range** and a
|
||||
`matched="true"` flag for chunks that matched the search query.
|
||||
- `<document_content>` — the actual chunks in original document order.
|
||||
|
||||
**Workflow**: read the first ~20 lines to see the `<chunk_index>`, identify
|
||||
chunks marked `matched="true"`, then use `read_file(path, offset=<start_line>,
|
||||
limit=<lines>)` to jump directly to those sections."""
|
||||
|
||||
APP_CONTEXT_BLOCK = """
|
||||
|
||||
The user is currently working in "{app_name}" (window: "{window_title}"). Use this to understand the type of application and adapt your tone and format accordingly."""
|
||||
|
||||
|
||||
def _build_autocomplete_system_prompt(app_name: str, window_title: str) -> str:
|
||||
prompt = AUTOCOMPLETE_SYSTEM_PROMPT
|
||||
if app_name:
|
||||
prompt += APP_CONTEXT_BLOCK.format(app_name=app_name, window_title=window_title)
|
||||
return prompt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-compute KB filesystem (runs in parallel with agent compilation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _KBResult:
|
||||
"""Container for pre-computed KB filesystem results."""
|
||||
|
||||
__slots__ = ("files", "ls_ai_msg", "ls_tool_msg")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
files: dict[str, Any] | None = None,
|
||||
ls_ai_msg: AIMessage | None = None,
|
||||
ls_tool_msg: ToolMessage | None = None,
|
||||
) -> None:
|
||||
self.files = files
|
||||
self.ls_ai_msg = ls_ai_msg
|
||||
self.ls_tool_msg = ls_tool_msg
|
||||
|
||||
@property
|
||||
def has_documents(self) -> bool:
|
||||
return bool(self.files)
|
||||
|
||||
|
||||
async def precompute_kb_filesystem(
|
||||
search_space_id: int,
|
||||
query: str,
|
||||
top_k: int = KB_TOP_K,
|
||||
) -> _KBResult:
|
||||
"""Search the KB and build the scoped filesystem outside the agent.
|
||||
|
||||
This is designed to be called via ``asyncio.gather`` alongside agent
|
||||
graph compilation so the two run concurrently.
|
||||
"""
|
||||
if not query:
|
||||
return _KBResult()
|
||||
|
||||
try:
|
||||
search_results = await search_knowledge_base(
|
||||
query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
if not search_results:
|
||||
return _KBResult()
|
||||
|
||||
new_files, _ = await build_scoped_filesystem(
|
||||
documents=search_results,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
|
||||
if not new_files:
|
||||
return _KBResult()
|
||||
|
||||
doc_paths = [
|
||||
p
|
||||
for p, v in new_files.items()
|
||||
if p.startswith("/documents/") and v is not None
|
||||
]
|
||||
tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}"
|
||||
ai_msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id}
|
||||
],
|
||||
)
|
||||
tool_msg = ToolMessage(
|
||||
content=str(doc_paths) if doc_paths else "No documents found.",
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
return _KBResult(files=new_files, ls_ai_msg=ai_msg, ls_tool_msg=tool_msg)
|
||||
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"KB pre-computation failed, proceeding without KB", exc_info=True
|
||||
)
|
||||
return _KBResult()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filesystem middleware — no save_document, no persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AutocompleteFilesystemMiddleware(SurfSenseFilesystemMiddleware):
|
||||
"""Filesystem middleware for autocomplete — read-only exploration only.
|
||||
|
||||
Strips ``save_document`` (permanent KB persistence) and passes
|
||||
``search_space_id=None`` so ``write_file`` / ``edit_file`` stay ephemeral.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(search_space_id=None, created_by_id=None)
|
||||
self.tools = [t for t in self.tools if t.name != "save_document"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _compile_agent(
|
||||
llm: BaseChatModel,
|
||||
app_name: str,
|
||||
window_title: str,
|
||||
) -> Any:
|
||||
"""Compile the agent graph (CPU-bound, runs in a thread)."""
|
||||
system_prompt = _build_autocomplete_system_prompt(app_name, window_title)
|
||||
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
||||
|
||||
middleware = [
|
||||
AutocompleteFilesystemMiddleware(),
|
||||
PatchToolCallsMiddleware(),
|
||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
||||
]
|
||||
|
||||
agent = await asyncio.to_thread(
|
||||
create_agent,
|
||||
llm,
|
||||
system_prompt=final_system_prompt,
|
||||
tools=[],
|
||||
middleware=middleware,
|
||||
)
|
||||
return agent.with_config({"recursion_limit": 200})
|
||||
|
||||
|
||||
async def create_autocomplete_agent(
|
||||
llm: BaseChatModel,
|
||||
*,
|
||||
search_space_id: int,
|
||||
kb_query: str,
|
||||
app_name: str = "",
|
||||
window_title: str = "",
|
||||
) -> tuple[Any, _KBResult]:
|
||||
"""Create the autocomplete agent and pre-compute KB in parallel.
|
||||
|
||||
Returns ``(agent, kb_result)`` so the caller can inject the pre-computed
|
||||
filesystem into the agent's initial state without any middleware delay.
|
||||
"""
|
||||
agent, kb = await asyncio.gather(
|
||||
_compile_agent(llm, app_name, window_title),
|
||||
precompute_kb_filesystem(search_space_id, kb_query),
|
||||
)
|
||||
return agent, kb
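A minimal sketch of how a caller might wire the factory and the streaming helper together, assuming the agent's input state uses a "files" key for the virtual filesystem and "messages" for the conversation; the handler name, message ordering, and state keys are illustrative assumptions, not the actual SurfSense route.

# Hypothetical caller sketch (names and state keys are assumptions).
async def run_autocomplete(llm, streaming_service, *, search_space_id, app_name, window_title, image_message):
    agent, kb = await create_autocomplete_agent(
        llm,
        search_space_id=search_space_id,
        kb_query=window_title,  # the window title doubles as the KB search query
        app_name=app_name,
        window_title=window_title,
    )
    input_data = {"messages": [image_message]}
    if kb.has_documents:
        # Seed the virtual filesystem and replay the synthetic ls exchange so the
        # agent already knows which documents exist before its first LLM turn.
        input_data["files"] = kb.files
        input_data["messages"] = [image_message, kb.ls_ai_msg, kb.ls_tool_msg]
    async for sse_chunk in stream_autocomplete_agent(agent, input_data, streaming_service):
        yield sse_chunk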
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JSON suggestion parsing (with fallback)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_suggestions(raw: str) -> list[str]:
|
||||
"""Extract a list of suggestion strings from the agent's output.
|
||||
|
||||
Tries, in order:
|
||||
1. Direct ``json.loads``
|
||||
2. Extract content between ```json ... ``` fences
|
||||
3. Find the first ``[`` … ``]`` span
|
||||
Falls back to wrapping the raw text as a single suggestion.
|
||||
"""
|
||||
text = raw.strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
for candidate in _json_candidates(text):
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
if isinstance(parsed, list) and all(isinstance(s, str) for s in parsed):
|
||||
return [s for s in parsed if s.strip()]
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
|
||||
return [text]
|
||||
|
||||
|
||||
def _json_candidates(text: str) -> list[str]:
|
||||
"""Yield candidate JSON strings from raw text."""
|
||||
candidates = [text]
|
||||
|
||||
fence = re.search(r"```(?:json)?\s*\n?(.*?)```", text, re.DOTALL)
|
||||
if fence:
|
||||
candidates.append(fence.group(1).strip())
|
||||
|
||||
bracket = re.search(r"\[.*]", text, re.DOTALL)
|
||||
if bracket:
|
||||
candidates.append(bracket.group(0))
|
||||
|
||||
return candidates
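A few illustrative inputs, consistent with the parsing logic above, show the fallback order in practice (direct JSON first, then fenced JSON, then the raw-text fallback):

# Illustrative behaviour of _parse_suggestions (inputs invented for the example).
assert _parse_suggestions('["a", "b", "c"]') == ["a", "b", "c"]            # direct JSON array
assert _parse_suggestions('```json\n["only one"]\n```') == ["only one"]    # fenced JSON
assert _parse_suggestions("Sure, here is a draft reply.") == ["Sure, here is a draft reply."]  # raw-text fallback
assert _parse_suggestions("") == []                                        # empty output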
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def stream_autocomplete_agent(
|
||||
agent: Any,
|
||||
input_data: dict[str, Any],
|
||||
streaming_service: VercelStreamingService,
|
||||
*,
|
||||
emit_message_start: bool = True,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream agent events as Vercel SSE, with thinking steps for tool calls.
|
||||
|
||||
When ``emit_message_start`` is False the caller has already sent the
|
||||
``message_start`` event (e.g. to show preparation steps before the agent
|
||||
runs).
|
||||
"""
|
||||
thread_id = uuid.uuid4().hex
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
|
||||
text_buffer: list[str] = []
|
||||
active_tool_depth = 0
|
||||
thinking_step_counter = 0
|
||||
tool_step_ids: dict[str, str] = {}
|
||||
step_titles: dict[str, str] = {}
|
||||
completed_step_ids: set[str] = set()
|
||||
last_active_step_id: str | None = None
|
||||
|
||||
def next_thinking_step_id() -> str:
|
||||
nonlocal thinking_step_counter
|
||||
thinking_step_counter += 1
|
||||
return f"autocomplete-step-{thinking_step_counter}"
|
||||
|
||||
def complete_current_step() -> str | None:
|
||||
nonlocal last_active_step_id
|
||||
if last_active_step_id and last_active_step_id not in completed_step_ids:
|
||||
completed_step_ids.add(last_active_step_id)
|
||||
title = step_titles.get(last_active_step_id, "Done")
|
||||
event = streaming_service.format_thinking_step(
|
||||
step_id=last_active_step_id,
|
||||
title=title,
|
||||
status="complete",
|
||||
)
|
||||
last_active_step_id = None
|
||||
return event
|
||||
return None
|
||||
|
||||
if emit_message_start:
|
||||
yield streaming_service.format_message_start()
|
||||
|
||||
gen_step_id = next_thinking_step_id()
|
||||
last_active_step_id = gen_step_id
|
||||
step_titles[gen_step_id] = "Generating suggestions"
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=gen_step_id,
|
||||
title="Generating suggestions",
|
||||
status="in_progress",
|
||||
)
|
||||
|
||||
try:
|
||||
async for event in agent.astream_events(
|
||||
input_data, config=config, version="v2"
|
||||
):
|
||||
event_type = event.get("event", "")
|
||||
if event_type == "on_chat_model_stream":
|
||||
if active_tool_depth > 0:
|
||||
continue
|
||||
if "surfsense:internal" in event.get("tags", []):
|
||||
continue
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk and hasattr(chunk, "content"):
|
||||
content = chunk.content
|
||||
if content and isinstance(content, str):
|
||||
text_buffer.append(content)
|
||||
|
||||
elif event_type == "on_chat_model_end":
|
||||
if active_tool_depth > 0:
|
||||
continue
|
||||
if "surfsense:internal" in event.get("tags", []):
|
||||
continue
|
||||
output = event.get("data", {}).get("output")
|
||||
if output and hasattr(output, "content"):
|
||||
if getattr(output, "tool_calls", None):
|
||||
continue
|
||||
content = output.content
|
||||
if content and isinstance(content, str) and not text_buffer:
|
||||
text_buffer.append(content)
|
||||
|
||||
elif event_type == "on_tool_start":
|
||||
active_tool_depth += 1
|
||||
tool_name = event.get("name", "unknown_tool")
|
||||
run_id = event.get("run_id", "")
|
||||
tool_input = event.get("data", {}).get("input", {})
|
||||
|
||||
step_event = complete_current_step()
|
||||
if step_event:
|
||||
yield step_event
|
||||
|
||||
tool_step_id = next_thinking_step_id()
|
||||
tool_step_ids[run_id] = tool_step_id
|
||||
last_active_step_id = tool_step_id
|
||||
|
||||
title, items = _describe_tool_call(tool_name, tool_input)
|
||||
step_titles[tool_step_id] = title
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title=title,
|
||||
status="in_progress",
|
||||
items=items,
|
||||
)
|
||||
|
||||
elif event_type == "on_tool_end":
|
||||
active_tool_depth = max(0, active_tool_depth - 1)
|
||||
run_id = event.get("run_id", "")
|
||||
step_id = tool_step_ids.pop(run_id, None)
|
||||
if step_id and step_id not in completed_step_ids:
|
||||
completed_step_ids.add(step_id)
|
||||
title = step_titles.get(step_id, "Done")
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=step_id,
|
||||
title=title,
|
||||
status="complete",
|
||||
)
|
||||
if last_active_step_id == step_id:
|
||||
last_active_step_id = None
|
||||
|
||||
step_event = complete_current_step()
|
||||
if step_event:
|
||||
yield step_event
|
||||
|
||||
raw_text = "".join(text_buffer)
|
||||
suggestions = _parse_suggestions(raw_text)
|
||||
|
||||
yield streaming_service.format_data("suggestions", {"options": suggestions})
|
||||
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Autocomplete agent streaming error: {e}", exc_info=True)
|
||||
yield streaming_service.format_error("Autocomplete failed. Please try again.")
|
||||
yield streaming_service.format_done()
|
||||
|
||||
|
||||
def _describe_tool_call(tool_name: str, tool_input: Any) -> tuple[str, list[str]]:
|
||||
"""Return a human-readable (title, items) for a tool call thinking step."""
|
||||
inp = tool_input if isinstance(tool_input, dict) else {}
|
||||
if tool_name == "ls":
|
||||
path = inp.get("path", "/")
|
||||
return "Listing files", [path]
|
||||
if tool_name == "read_file":
|
||||
fp = inp.get("file_path", "")
|
||||
display = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||
return "Reading file", [display]
|
||||
if tool_name == "write_file":
|
||||
fp = inp.get("file_path", "")
|
||||
display = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||
return "Writing file", [display]
|
||||
if tool_name == "edit_file":
|
||||
fp = inp.get("file_path", "")
|
||||
display = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||
return "Editing file", [display]
|
||||
if tool_name == "glob":
|
||||
pat = inp.get("pattern", "")
|
||||
base = inp.get("path", "/")
|
||||
return "Searching files", [f"{pat} in {base}"]
|
||||
if tool_name == "grep":
|
||||
pat = inp.get("pattern", "")
|
||||
path = inp.get("path", "")
|
||||
display_pat = pat[:60] + ("…" if len(pat) > 60 else "")
|
||||
return "Searching content", [
|
||||
f'"{display_pat}"' + (f" in {path}" if path else "")
|
||||
]
|
||||
return f"Using {tool_name}", []
|
||||
|
|
@@ -25,7 +25,12 @@ from app.agents.new_chat.checkpointer import (
|
|||
close_checkpointer,
|
||||
setup_checkpointer_tables,
|
||||
)
|
||||
from app.config import config, initialize_image_gen_router, initialize_llm_router
|
||||
from app.config import (
|
||||
config,
|
||||
initialize_image_gen_router,
|
||||
initialize_llm_router,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.routes import router as crud_router
|
||||
from app.routes.auth_routes import router as auth_router
|
||||
|
|
@@ -223,6 +228,7 @@ async def lifespan(app: FastAPI):
|
|||
await setup_checkpointer_tables()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
try:
|
||||
await asyncio.wait_for(seed_surfsense_docs(), timeout=120)
|
||||
except TimeoutError:
|
||||
|
|
|
|||
|
|
@@ -18,10 +18,15 @@ def init_worker(**kwargs):
|
|||
This ensures the Auto mode (LiteLLM Router) is available for background tasks
|
||||
like document summarization and image generation.
|
||||
"""
|
||||
from app.config import initialize_image_gen_router, initialize_llm_router
|
||||
from app.config import (
|
||||
initialize_image_gen_router,
|
||||
initialize_llm_router,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
||||
|
||||
# Get Celery configuration from environment
|
||||
|
|
|
|||
|
|
@@ -102,6 +102,44 @@ def load_global_image_gen_configs():
|
|||
return []
|
||||
|
||||
|
||||
def load_global_vision_llm_configs():
|
||||
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
if not global_config_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
return data.get("global_vision_llm_configs", [])
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load global vision LLM configs: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def load_vision_llm_router_settings():
|
||||
default_settings = {
|
||||
"routing_strategy": "usage-based-routing",
|
||||
"num_retries": 3,
|
||||
"allowed_fails": 3,
|
||||
"cooldown_time": 60,
|
||||
}
|
||||
|
||||
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
if not global_config_file.exists():
|
||||
return default_settings
|
||||
|
||||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
settings = data.get("vision_llm_router_settings", {})
|
||||
return {**default_settings, **settings}
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load vision LLM router settings: {e}")
|
||||
return default_settings
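Since the loader merges with `{**default_settings, **settings}`, a YAML block only needs to override the keys it cares about; a hypothetical partial block behaves like this:

# Hypothetical partial YAML:
#   vision_llm_router_settings:
#     num_retries: 5
# After the merge, every unspecified key keeps its default value:
defaults = {
    "routing_strategy": "usage-based-routing",
    "num_retries": 3,
    "allowed_fails": 3,
    "cooldown_time": 60,
}
merged = {**defaults, **{"num_retries": 5}}
assert merged["num_retries"] == 5
assert merged["routing_strategy"] == "usage-based-routing"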
|
||||
|
||||
|
||||
def load_image_gen_router_settings():
|
||||
"""
|
||||
Load router settings for image generation Auto mode from YAML file.
|
||||
|
|
@@ -182,6 +220,29 @@ def initialize_image_gen_router():
|
|||
print(f"Warning: Failed to initialize Image Generation Router: {e}")
|
||||
|
||||
|
||||
def initialize_vision_llm_router():
|
||||
vision_configs = load_global_vision_llm_configs()
|
||||
router_settings = load_vision_llm_router_settings()
|
||||
|
||||
if not vision_configs:
|
||||
print(
|
||||
"Info: No global vision LLM configs found, "
|
||||
"Vision LLM Auto mode will not be available"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
from app.services.vision_llm_router_service import VisionLLMRouterService
|
||||
|
||||
VisionLLMRouterService.initialize(vision_configs, router_settings)
|
||||
print(
|
||||
f"Info: Vision LLM Router initialized with {len(vision_configs)} models "
|
||||
f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize Vision LLM Router: {e}")
|
||||
|
||||
|
||||
class Config:
|
||||
# Check if ffmpeg is installed
|
||||
if not is_ffmpeg_installed():
|
||||
|
|
@@ -335,6 +396,12 @@ class Config:
|
|||
# Router settings for Image Generation Auto mode
|
||||
IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings()
|
||||
|
||||
# Global Vision LLM Configurations (optional)
|
||||
GLOBAL_VISION_LLM_CONFIGS = load_global_vision_llm_configs()
|
||||
|
||||
# Router settings for Vision LLM Auto mode
|
||||
VISION_LLM_ROUTER_SETTINGS = load_vision_llm_router_settings()
|
||||
|
||||
# Chonkie Configuration | Edit this to your needs
|
||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
||||
# Azure OpenAI credentials from environment variables
|
||||
|
|
@@ -394,8 +461,10 @@ class Config:
|
|||
UNSTRUCTURED_API_KEY = os.getenv("UNSTRUCTURED_API_KEY")
|
||||
|
||||
elif ETL_SERVICE == "LLAMACLOUD":
|
||||
# LlamaCloud API Key
|
||||
LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
|
||||
# Optional: Azure Document Intelligence accelerator for supported file types
|
||||
AZURE_DI_ENDPOINT = os.getenv("AZURE_DI_ENDPOINT")
|
||||
AZURE_DI_KEY = os.getenv("AZURE_DI_KEY")
|
||||
|
||||
# Residential Proxy Configuration (anonymous-proxies.net)
|
||||
# Used for web crawling and YouTube transcript fetching to avoid IP bans.
|
||||
|
|
|
|||
|
|
@@ -17,7 +17,7 @@
|
|||
# - Configure router_settings below to customize the load balancing behavior
|
||||
#
|
||||
# Structure matches NewLLMConfig:
|
||||
# - LLM model configuration (provider, model_name, api_key, etc.)
|
||||
# - Model configuration (provider, model_name, api_key, etc.)
|
||||
# - Prompt configuration (system_instructions, citations_enabled)
|
||||
|
||||
# Router Settings for Auto Mode
|
||||
|
|
@@ -263,6 +263,82 @@ global_image_generation_configs:
|
|||
# rpm: 30
|
||||
# litellm_params: {}
|
||||
|
||||
# =============================================================================
|
||||
# Vision LLM Configuration
|
||||
# =============================================================================
|
||||
# These configurations power the vision autocomplete feature (screenshot analysis).
|
||||
# Only vision-capable models should be used here (e.g. GPT-4o, Gemini Pro, Claude 3).
|
||||
# Supported providers: OpenAI, Anthropic, Google, Azure OpenAI, Vertex AI, Bedrock,
|
||||
# xAI, OpenRouter, Ollama, Groq, Together AI, Fireworks AI, DeepSeek, Mistral, Custom
|
||||
#
|
||||
# Auto mode (ID 0) uses LiteLLM Router for load balancing across all vision configs.
|
||||
|
||||
# Router Settings for Vision LLM Auto Mode
|
||||
vision_llm_router_settings:
|
||||
routing_strategy: "usage-based-routing"
|
||||
num_retries: 3
|
||||
allowed_fails: 3
|
||||
cooldown_time: 60
|
||||
|
||||
global_vision_llm_configs:
|
||||
# Example: OpenAI GPT-4o (recommended for vision)
|
||||
- id: -1
|
||||
name: "Global GPT-4o Vision"
|
||||
description: "OpenAI's GPT-4o with strong vision capabilities"
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-4o"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 500
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Google Gemini 2.0 Flash
|
||||
- id: -2
|
||||
name: "Global Gemini 2.0 Flash"
|
||||
description: "Google's fast vision model with large context"
|
||||
provider: "GOOGLE"
|
||||
model_name: "gemini-2.0-flash"
|
||||
api_key: "your-google-ai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 1000
|
||||
tpm: 200000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Anthropic Claude 3.5 Sonnet
|
||||
- id: -3
|
||||
name: "Global Claude 3.5 Sonnet Vision"
|
||||
description: "Anthropic's Claude 3.5 Sonnet with vision support"
|
||||
provider: "ANTHROPIC"
|
||||
model_name: "claude-3-5-sonnet-20241022"
|
||||
api_key: "sk-ant-your-anthropic-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 1000
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Azure OpenAI GPT-4o
|
||||
# - id: -4
|
||||
# name: "Global Azure GPT-4o Vision"
|
||||
# description: "Azure-hosted GPT-4o for vision analysis"
|
||||
# provider: "AZURE_OPENAI"
|
||||
# model_name: "azure/gpt-4o-deployment"
|
||||
# api_key: "your-azure-api-key-here"
|
||||
# api_base: "https://your-resource.openai.azure.com"
|
||||
# api_version: "2024-02-15-preview"
|
||||
# rpm: 500
|
||||
# tpm: 100000
|
||||
# litellm_params:
|
||||
# temperature: 0.3
|
||||
# max_tokens: 1000
|
||||
# base_model: "gpt-4o"
|
||||
|
||||
# Notes:
|
||||
# - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing
|
||||
# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB)
|
||||
|
|
@@ -283,3 +359,9 @@ global_image_generation_configs:
|
|||
# - The router uses litellm.aimage_generation() for async image generation
|
||||
# - Only RPM (requests per minute) is relevant for image generation rate limiting.
|
||||
# TPM (tokens per minute) does not apply since image APIs are billed/rate-limited per request, not per token.
|
||||
#
|
||||
# VISION LLM NOTES:
|
||||
# - Vision configs use the same ID scheme (negative for global, positive for user DB)
|
||||
# - Only use vision-capable models (GPT-4o, Gemini, Claude 3, etc.)
|
||||
# - Lower temperature (0.3) is recommended for accurate screenshot analysis
|
||||
# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions
|
||||
|
|
|
|||
23
surfsense_backend/app/config/vision_model_list_fallback.json
Normal file
|
|
@@ -0,0 +1,23 @@
|
|||
[
|
||||
{"value": "gpt-4o", "label": "GPT-4o", "provider": "OPENAI", "context_window": "128K"},
|
||||
{"value": "gpt-4o-mini", "label": "GPT-4o Mini", "provider": "OPENAI", "context_window": "128K"},
|
||||
{"value": "gpt-4-turbo", "label": "GPT-4 Turbo", "provider": "OPENAI", "context_window": "128K"},
|
||||
{"value": "claude-sonnet-4-20250514", "label": "Claude Sonnet 4", "provider": "ANTHROPIC", "context_window": "200K"},
|
||||
{"value": "claude-3-7-sonnet-20250219", "label": "Claude 3.7 Sonnet", "provider": "ANTHROPIC", "context_window": "200K"},
|
||||
{"value": "claude-3-5-sonnet-20241022", "label": "Claude 3.5 Sonnet", "provider": "ANTHROPIC", "context_window": "200K"},
|
||||
{"value": "claude-3-opus-20240229", "label": "Claude 3 Opus", "provider": "ANTHROPIC", "context_window": "200K"},
|
||||
{"value": "claude-3-haiku-20240307", "label": "Claude 3 Haiku", "provider": "ANTHROPIC", "context_window": "200K"},
|
||||
{"value": "gemini-2.5-flash", "label": "Gemini 2.5 Flash", "provider": "GOOGLE", "context_window": "1M"},
|
||||
{"value": "gemini-2.5-pro", "label": "Gemini 2.5 Pro", "provider": "GOOGLE", "context_window": "1M"},
|
||||
{"value": "gemini-2.0-flash", "label": "Gemini 2.0 Flash", "provider": "GOOGLE", "context_window": "1M"},
|
||||
{"value": "gemini-1.5-pro", "label": "Gemini 1.5 Pro", "provider": "GOOGLE", "context_window": "1M"},
|
||||
{"value": "gemini-1.5-flash", "label": "Gemini 1.5 Flash", "provider": "GOOGLE", "context_window": "1M"},
|
||||
{"value": "pixtral-large-latest", "label": "Pixtral Large", "provider": "MISTRAL", "context_window": "128K"},
|
||||
{"value": "pixtral-12b-2409", "label": "Pixtral 12B", "provider": "MISTRAL", "context_window": "128K"},
|
||||
{"value": "grok-2-vision-1212", "label": "Grok 2 Vision", "provider": "XAI", "context_window": "32K"},
|
||||
{"value": "llava", "label": "LLaVA", "provider": "OLLAMA"},
|
||||
{"value": "bakllava", "label": "BakLLaVA", "provider": "OLLAMA"},
|
||||
{"value": "llava-llama3", "label": "LLaVA Llama 3", "provider": "OLLAMA"},
|
||||
{"value": "llama-4-scout-17b-16e-instruct", "label": "Llama 4 Scout 17B", "provider": "GROQ", "context_window": "128K"},
|
||||
{"value": "meta-llama/Llama-4-Scout-17B-16E-Instruct", "label": "Llama 4 Scout 17B", "provider": "TOGETHER_AI", "context_window": "128K"}
|
||||
]
|
||||
|
|
@@ -225,6 +225,55 @@ class DropboxClient:
|
|||
|
||||
return all_items, None
|
||||
|
||||
async def get_latest_cursor(self, path: str = "") -> tuple[str | None, str | None]:
|
||||
"""Get a cursor representing the current state of a folder.
|
||||
|
||||
Uses /2/files/list_folder/get_latest_cursor so we can later call
|
||||
get_changes to receive only incremental updates.
|
||||
"""
|
||||
resp = await self._request(
|
||||
"/2/files/list_folder/get_latest_cursor",
|
||||
{"path": path, "recursive": False, "include_non_downloadable_files": True},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return None, f"Failed to get cursor: {resp.status_code} - {resp.text}"
|
||||
return resp.json().get("cursor"), None
|
||||
|
||||
async def get_changes(
|
||||
self, cursor: str
|
||||
) -> tuple[list[dict[str, Any]], str | None, str | None]:
|
||||
"""Fetch incremental changes since the given cursor.
|
||||
|
||||
Calls /2/files/list_folder/continue and handles pagination.
|
||||
Returns (entries, new_cursor, error).
|
||||
"""
|
||||
all_entries: list[dict[str, Any]] = []
|
||||
|
||||
resp = await self._request("/2/files/list_folder/continue", {"cursor": cursor})
|
||||
if resp.status_code == 401:
|
||||
return [], None, "Dropbox authentication expired (401)"
|
||||
if resp.status_code != 200:
|
||||
return [], None, f"Failed to get changes: {resp.status_code} - {resp.text}"
|
||||
|
||||
data = resp.json()
|
||||
all_entries.extend(data.get("entries", []))
|
||||
|
||||
while data.get("has_more"):
|
||||
cursor = data["cursor"]
|
||||
resp = await self._request(
|
||||
"/2/files/list_folder/continue", {"cursor": cursor}
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return (
|
||||
all_entries,
|
||||
data.get("cursor"),
|
||||
f"Pagination failed: {resp.status_code}",
|
||||
)
|
||||
data = resp.json()
|
||||
all_entries.extend(data.get("entries", []))
|
||||
|
||||
return all_entries, data.get("cursor"), None
|
||||
|
||||
async def get_metadata(self, path: str) -> tuple[dict[str, Any] | None, str | None]:
|
||||
resp = await self._request("/2/files/get_metadata", {"path": path})
|
||||
if resp.status_code != 200:
|
||||
|
|
|
|||
|
|
@@ -53,7 +53,8 @@ async def download_and_extract_content(
|
|||
file_name = file.get("name", "Unknown")
|
||||
file_id = file.get("id", "")
|
||||
|
||||
if should_skip_file(file):
|
||||
skip, _unsup_ext = should_skip_file(file)
|
||||
if skip:
|
||||
return None, {}, "Skipping non-indexable item"
|
||||
|
||||
logger.info(f"Downloading file for content extraction: {file_name}")
|
||||
|
|
@@ -87,9 +88,13 @@ async def download_and_extract_content(
|
|||
if error:
|
||||
return None, metadata, error
|
||||
|
||||
from app.connectors.onedrive.content_extractor import _parse_file_to_markdown
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
markdown = await _parse_file_to_markdown(temp_file_path, file_name)
|
||||
result = await EtlPipelineService().extract(
|
||||
EtlRequest(file_path=temp_file_path, filename=file_name)
|
||||
)
|
||||
markdown = result.markdown_content
|
||||
return markdown, metadata, None
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@@ -1,8 +1,8 @@
|
|||
"""File type handlers for Dropbox."""
|
||||
|
||||
PAPER_EXTENSION = ".paper"
|
||||
from app.etl_pipeline.file_classifier import should_skip_for_service
|
||||
|
||||
SKIP_EXTENSIONS: frozenset[str] = frozenset()
|
||||
PAPER_EXTENSION = ".paper"
|
||||
|
||||
MIME_TO_EXTENSION: dict[str, str] = {
|
||||
"application/pdf": ".pdf",
|
||||
|
|
@@ -42,17 +42,25 @@ def is_paper_file(item: dict) -> bool:
|
|||
return ext == PAPER_EXTENSION
|
||||
|
||||
|
||||
def should_skip_file(item: dict) -> bool:
|
||||
def should_skip_file(item: dict) -> tuple[bool, str | None]:
|
||||
"""Skip folders and truly non-indexable files.
|
||||
|
||||
Paper docs are non-downloadable but exportable, so they are NOT skipped.
|
||||
Returns (should_skip, unsupported_extension_or_None).
|
||||
"""
|
||||
if is_folder(item):
|
||||
return True
|
||||
return True, None
|
||||
if is_paper_file(item):
|
||||
return False
|
||||
return False, None
|
||||
if not item.get("is_downloadable", True):
|
||||
return True
|
||||
return True, None
|
||||
|
||||
from pathlib import PurePosixPath
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
name = item.get("name", "")
|
||||
ext = get_extension_from_name(name).lower()
|
||||
return ext in SKIP_EXTENSIONS
|
||||
if should_skip_for_service(name, app_config.ETL_SERVICE):
|
||||
ext = PurePosixPath(name).suffix.lower()
|
||||
return True, ext
|
||||
return False, None
|
||||
|
|
|
|||
|
|
@@ -64,8 +64,10 @@ async def get_files_in_folder(
|
|||
)
|
||||
continue
|
||||
files.extend(sub_files)
|
||||
elif not should_skip_file(item):
|
||||
files.append(item)
|
||||
else:
|
||||
skip, _unsup_ext = should_skip_file(item)
|
||||
if not skip:
|
||||
files.append(item)
|
||||
|
||||
return files, None
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,12 +1,9 @@
|
|||
"""Content extraction for Google Drive files."""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
|
@@ -20,6 +17,7 @@ from .file_types import (
|
|||
get_export_mime_type,
|
||||
get_extension_from_mime,
|
||||
is_google_workspace_file,
|
||||
should_skip_by_extension,
|
||||
should_skip_file,
|
||||
)
|
||||
|
||||
|
|
@@ -45,6 +43,11 @@ async def download_and_extract_content(
|
|||
if should_skip_file(mime_type):
|
||||
return None, {}, f"Skipping {mime_type}"
|
||||
|
||||
if not is_google_workspace_file(mime_type):
|
||||
ext_skip, _unsup_ext = should_skip_by_extension(file_name)
|
||||
if ext_skip:
|
||||
return None, {}, f"Skipping unsupported extension: {file_name}"
|
||||
|
||||
logger.info(f"Downloading file for content extraction: {file_name} ({mime_type})")
|
||||
|
||||
drive_metadata: dict[str, Any] = {
|
||||
|
|
@@ -97,7 +100,10 @@ async def download_and_extract_content(
|
|||
if error:
|
||||
return None, drive_metadata, error
|
||||
|
||||
markdown = await _parse_file_to_markdown(temp_file_path, file_name)
|
||||
etl_filename = (
|
||||
file_name + extension if is_google_workspace_file(mime_type) else file_name
|
||||
)
|
||||
markdown = await _parse_file_to_markdown(temp_file_path, etl_filename)
|
||||
return markdown, drive_metadata, None
|
||||
|
||||
except Exception as e:
|
||||
|
|
@@ -110,99 +116,14 @@ async def download_and_extract_content(
|
|||
|
||||
|
||||
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
||||
"""Parse a local file to markdown using the configured ETL service."""
|
||||
lower = filename.lower()
|
||||
"""Parse a local file to markdown using the unified ETL pipeline."""
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
if lower.endswith((".md", ".markdown", ".txt")):
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
|
||||
from litellm import atranscription
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
stt_service_type = (
|
||||
"local"
|
||||
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
|
||||
else "external"
|
||||
)
|
||||
if stt_service_type == "local":
|
||||
from app.services.stt_service import stt_service
|
||||
|
||||
t0 = time.monotonic()
|
||||
logger.info(
|
||||
f"[local-stt] START file={filename} thread={threading.current_thread().name}"
|
||||
)
|
||||
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
|
||||
logger.info(
|
||||
f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
|
||||
)
|
||||
text = result.get("text", "")
|
||||
else:
|
||||
with open(file_path, "rb") as audio_file:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": app_config.STT_SERVICE,
|
||||
"file": audio_file,
|
||||
"api_key": app_config.STT_SERVICE_API_KEY,
|
||||
}
|
||||
if app_config.STT_SERVICE_API_BASE:
|
||||
kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
|
||||
resp = await atranscription(**kwargs)
|
||||
text = resp.get("text", "")
|
||||
|
||||
if not text:
|
||||
raise ValueError("Transcription returned empty text")
|
||||
return f"# Transcription of {filename}\n\n{text}"
|
||||
|
||||
# Document files -- use configured ETL service
|
||||
from app.config import config as app_config
|
||||
|
||||
if app_config.ETL_SERVICE == "UNSTRUCTURED":
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
|
||||
from app.utils.document_converters import convert_document_to_markdown
|
||||
|
||||
loader = UnstructuredLoader(
|
||||
file_path,
|
||||
mode="elements",
|
||||
post_processors=[],
|
||||
languages=["eng"],
|
||||
include_orig_elements=False,
|
||||
include_metadata=False,
|
||||
strategy="auto",
|
||||
)
|
||||
docs = await loader.aload()
|
||||
return await convert_document_to_markdown(docs)
|
||||
|
||||
if app_config.ETL_SERVICE == "LLAMACLOUD":
|
||||
from app.tasks.document_processors.file_processors import (
|
||||
parse_with_llamacloud_retry,
|
||||
)
|
||||
|
||||
result = await parse_with_llamacloud_retry(
|
||||
file_path=file_path, estimated_pages=50
|
||||
)
|
||||
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
|
||||
if not markdown_documents:
|
||||
raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
|
||||
return markdown_documents[0].text
|
||||
|
||||
if app_config.ETL_SERVICE == "DOCLING":
|
||||
from docling.document_converter import DocumentConverter
|
||||
|
||||
converter = DocumentConverter()
|
||||
t0 = time.monotonic()
|
||||
logger.info(
|
||||
f"[docling] START file={filename} thread={threading.current_thread().name}"
|
||||
)
|
||||
result = await asyncio.to_thread(converter.convert, file_path)
|
||||
logger.info(
|
||||
f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
|
||||
)
|
||||
return result.document.export_to_markdown()
|
||||
|
||||
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
|
||||
result = await EtlPipelineService().extract(
|
||||
EtlRequest(file_path=file_path, filename=filename)
|
||||
)
|
||||
return result.markdown_content
|
||||
|
||||
|
||||
async def download_and_process_file(
|
||||
|
|
@@ -236,10 +157,14 @@ async def download_and_process_file(
|
|||
file_name = file.get("name", "Unknown")
|
||||
mime_type = file.get("mimeType", "")
|
||||
|
||||
# Skip folders and shortcuts
|
||||
if should_skip_file(mime_type):
|
||||
return None, f"Skipping {mime_type}", None
|
||||
|
||||
if not is_google_workspace_file(mime_type):
|
||||
ext_skip, _unsup_ext = should_skip_by_extension(file_name)
|
||||
if ext_skip:
|
||||
return None, f"Skipping unsupported extension: {file_name}", None
|
||||
|
||||
logger.info(f"Downloading file: {file_name} ({mime_type})")
|
||||
|
||||
temp_file_path = None
|
||||
|
|
@@ -310,10 +235,13 @@ async def download_and_process_file(
|
|||
"."
|
||||
)[-1]
|
||||
|
||||
etl_filename = (
|
||||
file_name + extension if is_google_workspace_file(mime_type) else file_name
|
||||
)
|
||||
logger.info(f"Processing {file_name} with Surfsense's file processor")
|
||||
await process_file_in_background(
|
||||
file_path=temp_file_path,
|
||||
filename=file_name,
|
||||
filename=etl_filename,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
|
|
|
|||
|
|
@@ -1,5 +1,7 @@
|
|||
"""File type handlers for Google Drive."""
|
||||
|
||||
from app.etl_pipeline.file_classifier import should_skip_for_service
|
||||
|
||||
GOOGLE_DOC = "application/vnd.google-apps.document"
|
||||
GOOGLE_SHEET = "application/vnd.google-apps.spreadsheet"
|
||||
GOOGLE_SLIDE = "application/vnd.google-apps.presentation"
|
||||
|
|
@@ -46,6 +48,21 @@ def should_skip_file(mime_type: str) -> bool:
|
|||
return mime_type in [GOOGLE_FOLDER, GOOGLE_SHORTCUT]
|
||||
|
||||
|
||||
def should_skip_by_extension(filename: str) -> tuple[bool, str | None]:
|
||||
"""Check if the file extension is not parseable by the configured ETL service.
|
||||
|
||||
Returns (should_skip, unsupported_extension_or_None).
|
||||
"""
|
||||
from pathlib import PurePosixPath
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
if should_skip_for_service(filename, app_config.ETL_SERVICE):
|
||||
ext = PurePosixPath(filename).suffix.lower()
|
||||
return True, ext
|
||||
return False, None
|
||||
|
||||
|
||||
def get_export_mime_type(mime_type: str) -> str | None:
|
||||
"""Get export MIME type for Google Workspace files."""
|
||||
return EXPORT_FORMATS.get(mime_type)
|
||||
|
|
|
|||
|
|
@@ -1,16 +1,9 @@
|
|||
"""Content extraction for OneDrive files.
|
||||
"""Content extraction for OneDrive files."""
|
||||
|
||||
Reuses the same ETL parsing logic as Google Drive since file parsing is
|
||||
extension-based, not provider-specific.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
|
@@ -31,7 +24,8 @@ async def download_and_extract_content(
|
|||
item_id = file.get("id")
|
||||
file_name = file.get("name", "Unknown")
|
||||
|
||||
if should_skip_file(file):
|
||||
skip, _unsup_ext = should_skip_file(file)
|
||||
if skip:
|
||||
return None, {}, "Skipping non-indexable item"
|
||||
|
||||
file_info = file.get("file", {})
|
||||
|
|
@@ -84,98 +78,11 @@ async def download_and_extract_content(
|
|||
|
||||
|
||||
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
||||
"""Parse a local file to markdown using the configured ETL service.
|
||||
"""Parse a local file to markdown using the unified ETL pipeline."""
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
Same logic as Google Drive -- file parsing is extension-based.
|
||||
"""
|
||||
lower = filename.lower()
|
||||
|
||||
if lower.endswith((".md", ".markdown", ".txt")):
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
|
||||
from litellm import atranscription
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
stt_service_type = (
|
||||
"local"
|
||||
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
|
||||
else "external"
|
||||
)
|
||||
if stt_service_type == "local":
|
||||
from app.services.stt_service import stt_service
|
||||
|
||||
t0 = time.monotonic()
|
||||
logger.info(
|
||||
f"[local-stt] START file={filename} thread={threading.current_thread().name}"
|
||||
)
|
||||
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
|
||||
logger.info(
|
||||
f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
|
||||
)
|
||||
text = result.get("text", "")
|
||||
else:
|
||||
with open(file_path, "rb") as audio_file:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": app_config.STT_SERVICE,
|
||||
"file": audio_file,
|
||||
"api_key": app_config.STT_SERVICE_API_KEY,
|
||||
}
|
||||
if app_config.STT_SERVICE_API_BASE:
|
||||
kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
|
||||
resp = await atranscription(**kwargs)
|
||||
text = resp.get("text", "")
|
||||
|
||||
if not text:
|
||||
raise ValueError("Transcription returned empty text")
|
||||
return f"# Transcription of {filename}\n\n{text}"
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
if app_config.ETL_SERVICE == "UNSTRUCTURED":
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
|
||||
from app.utils.document_converters import convert_document_to_markdown
|
||||
|
||||
loader = UnstructuredLoader(
|
||||
file_path,
|
||||
mode="elements",
|
||||
post_processors=[],
|
||||
languages=["eng"],
|
||||
include_orig_elements=False,
|
||||
include_metadata=False,
|
||||
strategy="auto",
|
||||
)
|
||||
docs = await loader.aload()
|
||||
return await convert_document_to_markdown(docs)
|
||||
|
||||
if app_config.ETL_SERVICE == "LLAMACLOUD":
|
||||
from app.tasks.document_processors.file_processors import (
|
||||
parse_with_llamacloud_retry,
|
||||
)
|
||||
|
||||
result = await parse_with_llamacloud_retry(
|
||||
file_path=file_path, estimated_pages=50
|
||||
)
|
||||
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
|
||||
if not markdown_documents:
|
||||
raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
|
||||
return markdown_documents[0].text
|
||||
|
||||
if app_config.ETL_SERVICE == "DOCLING":
|
||||
from docling.document_converter import DocumentConverter
|
||||
|
||||
converter = DocumentConverter()
|
||||
t0 = time.monotonic()
|
||||
logger.info(
|
||||
f"[docling] START file={filename} thread={threading.current_thread().name}"
|
||||
)
|
||||
result = await asyncio.to_thread(converter.convert, file_path)
|
||||
logger.info(
|
||||
f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
|
||||
)
|
||||
return result.document.export_to_markdown()
|
||||
|
||||
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
|
||||
result = await EtlPipelineService().extract(
|
||||
EtlRequest(file_path=file_path, filename=filename)
|
||||
)
|
||||
return result.markdown_content
|
||||
|
|
|
|||
|
|
@@ -1,5 +1,7 @@
|
|||
"""File type handlers for Microsoft OneDrive."""
|
||||
|
||||
from app.etl_pipeline.file_classifier import should_skip_for_service
|
||||
|
||||
ONEDRIVE_FOLDER_FACET = "folder"
|
||||
ONENOTE_MIME = "application/msonenote"
|
||||
|
||||
|
|
@@ -38,13 +40,28 @@ def is_folder(item: dict) -> bool:
|
|||
return ONEDRIVE_FOLDER_FACET in item
|
||||
|
||||
|
||||
def should_skip_file(item: dict) -> bool:
|
||||
"""Skip folders, OneNote files, remote items (shared links), and packages."""
|
||||
def should_skip_file(item: dict) -> tuple[bool, str | None]:
|
||||
"""Skip folders, OneNote files, remote items, packages, and unsupported extensions.
|
||||
|
||||
Returns (should_skip, unsupported_extension_or_None).
|
||||
The second element is only set when the skip is due to an unsupported extension.
|
||||
"""
|
||||
if is_folder(item):
|
||||
return True
|
||||
return True, None
|
||||
if "remoteItem" in item:
|
||||
return True
|
||||
return True, None
|
||||
if "package" in item:
|
||||
return True
|
||||
return True, None
|
||||
mime = item.get("file", {}).get("mimeType", "")
|
||||
return mime in SKIP_MIME_TYPES
|
||||
if mime in SKIP_MIME_TYPES:
|
||||
return True, None
|
||||
|
||||
from pathlib import PurePosixPath
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
name = item.get("name", "")
|
||||
if should_skip_for_service(name, app_config.ETL_SERVICE):
|
||||
ext = PurePosixPath(name).suffix.lower()
|
||||
return True, ext
|
||||
return False, None
|
||||
|
|
|
|||
|
|
@@ -71,8 +71,10 @@ async def get_files_in_folder(
|
|||
)
|
||||
continue
|
||||
files.extend(sub_files)
|
||||
elif not should_skip_file(item):
|
||||
files.append(item)
|
||||
else:
|
||||
skip, _unsup_ext = should_skip_file(item)
|
||||
if not skip:
|
||||
files.append(item)
|
||||
|
||||
return files, None
|
||||
|
||||
|
|
|
|||
|
|
@@ -64,6 +64,7 @@ class DocumentType(StrEnum):
|
|||
COMPOSIO_GOOGLE_DRIVE_CONNECTOR = "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"
|
||||
COMPOSIO_GMAIL_CONNECTOR = "COMPOSIO_GMAIL_CONNECTOR"
|
||||
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
|
||||
LOCAL_FOLDER_FILE = "LOCAL_FOLDER_FILE"
|
||||
|
||||
|
||||
# Native Google document types → their legacy Composio equivalents.
|
||||
|
|
@@ -259,6 +260,24 @@ class ImageGenProvider(StrEnum):
|
|||
NSCALE = "NSCALE"
|
||||
|
||||
|
||||
class VisionProvider(StrEnum):
|
||||
OPENAI = "OPENAI"
|
||||
ANTHROPIC = "ANTHROPIC"
|
||||
GOOGLE = "GOOGLE"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
BEDROCK = "BEDROCK"
|
||||
XAI = "XAI"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
OLLAMA = "OLLAMA"
|
||||
GROQ = "GROQ"
|
||||
TOGETHER_AI = "TOGETHER_AI"
|
||||
FIREWORKS_AI = "FIREWORKS_AI"
|
||||
DEEPSEEK = "DEEPSEEK"
|
||||
MISTRAL = "MISTRAL"
|
||||
CUSTOM = "CUSTOM"
|
||||
|
||||
|
||||
class LogLevel(StrEnum):
|
||||
DEBUG = "DEBUG"
|
||||
INFO = "INFO"
|
||||
|
|
@@ -376,6 +395,11 @@ class Permission(StrEnum):
|
|||
IMAGE_GENERATIONS_READ = "image_generations:read"
|
||||
IMAGE_GENERATIONS_DELETE = "image_generations:delete"
|
||||
|
||||
# Vision LLM Configs
|
||||
VISION_CONFIGS_CREATE = "vision_configs:create"
|
||||
VISION_CONFIGS_READ = "vision_configs:read"
|
||||
VISION_CONFIGS_DELETE = "vision_configs:delete"
|
||||
|
||||
# Connectors
|
||||
CONNECTORS_CREATE = "connectors:create"
|
||||
CONNECTORS_READ = "connectors:read"
|
||||
|
|
@@ -444,6 +468,9 @@ DEFAULT_ROLE_PERMISSIONS = {
|
|||
# Image Generations (create and read, no delete)
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
# Vision Configs (create and read, no delete)
|
||||
Permission.VISION_CONFIGS_CREATE.value,
|
||||
Permission.VISION_CONFIGS_READ.value,
|
||||
# Connectors (no delete)
|
||||
Permission.CONNECTORS_CREATE.value,
|
||||
Permission.CONNECTORS_READ.value,
|
||||
|
|
@@ -477,6 +504,8 @@ DEFAULT_ROLE_PERMISSIONS = {
|
|||
Permission.VIDEO_PRESENTATIONS_READ.value,
|
||||
# Image Generations (read only)
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
# Vision Configs (read only)
|
||||
Permission.VISION_CONFIGS_READ.value,
|
||||
# Connectors (read only)
|
||||
Permission.CONNECTORS_READ.value,
|
||||
# Logs (read only)
|
||||
|
|
@@ -955,6 +984,7 @@ class Folder(BaseModel, TimestampMixin):
|
|||
onupdate=lambda: datetime.now(UTC),
|
||||
index=True,
|
||||
)
|
||||
folder_metadata = Column("metadata", JSONB, nullable=True)
|
||||
|
||||
parent = relationship("Folder", remote_side="Folder.id", backref="children")
|
||||
search_space = relationship("SearchSpace", back_populates="folders")
|
||||
|
|
@@ -1039,6 +1069,26 @@ class Document(BaseModel, TimestampMixin):
|
|||
)
|
||||
|
||||
|
||||
class DocumentVersion(BaseModel, TimestampMixin):
|
||||
__tablename__ = "document_versions"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("document_id", "version_number", name="uq_document_version"),
|
||||
)
|
||||
|
||||
document_id = Column(
|
||||
Integer,
|
||||
ForeignKey("documents.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
version_number = Column(Integer, nullable=False)
|
||||
source_markdown = Column(Text, nullable=True)
|
||||
content_hash = Column(String, nullable=False)
|
||||
title = Column(String, nullable=True)
|
||||
|
||||
document = relationship("Document", backref="versions")
|
||||
|
||||
|
||||
class Chunk(BaseModel, TimestampMixin):
|
||||
__tablename__ = "chunks"
|
||||
|
||||
|
|
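The new DocumentVersion table keeps one immutable snapshot per document revision; the (document_id, version_number) unique constraint guarantees at most one row per version number. A minimal query sketch, assuming an AsyncSession named session and the models above (not part of the commit):

    from sqlalchemy import select

    async def latest_version(session, document_id: int):
        # Highest version_number wins; the unique constraint keeps numbering unambiguous.
        result = await session.execute(
            select(DocumentVersion)
            .where(DocumentVersion.document_id == document_id)
            .order_by(DocumentVersion.version_number.desc())
            .limit(1)
        )
        return result.scalar_one_or_none()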
@@ -1241,6 +1291,33 @@ class ImageGenerationConfig(BaseModel, TimestampMixin):
    user = relationship("User", back_populates="image_generation_configs")


class VisionLLMConfig(BaseModel, TimestampMixin):
    __tablename__ = "vision_llm_configs"

    name = Column(String(100), nullable=False, index=True)
    description = Column(String(500), nullable=True)

    provider = Column(SQLAlchemyEnum(VisionProvider), nullable=False)
    custom_provider = Column(String(100), nullable=True)
    model_name = Column(String(100), nullable=False)

    api_key = Column(String, nullable=False)
    api_base = Column(String(500), nullable=True)
    api_version = Column(String(50), nullable=True)

    litellm_params = Column(JSON, nullable=True, default={})

    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )
    search_space = relationship("SearchSpace", back_populates="vision_llm_configs")

    user_id = Column(
        UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
    )
    user = relationship("User", back_populates="vision_llm_configs")


class ImageGeneration(BaseModel, TimestampMixin):
    """
    Stores image generation requests and results using litellm.aimage_generation().
@@ -1329,6 +1406,9 @@ class SearchSpace(BaseModel, TimestampMixin):
    image_generation_config_id = Column(
        Integer, nullable=True, default=0
    )  # For image generation, defaults to Auto mode
    vision_llm_config_id = Column(
        Integer, nullable=True, default=0
    )  # For vision/screenshot analysis, defaults to Auto mode

    user_id = Column(
        UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
@@ -1407,6 +1487,12 @@ class SearchSpace(BaseModel, TimestampMixin):
        order_by="ImageGenerationConfig.id",
        cascade="all, delete-orphan",
    )
    vision_llm_configs = relationship(
        "VisionLLMConfig",
        back_populates="search_space",
        order_by="VisionLLMConfig.id",
        cascade="all, delete-orphan",
    )

    # RBAC relationships
    roles = relationship(
@@ -1936,6 +2022,12 @@ if config.AUTH_TYPE == "GOOGLE":
            passive_deletes=True,
        )

        vision_llm_configs = relationship(
            "VisionLLMConfig",
            back_populates="user",
            passive_deletes=True,
        )

        # User memories for personalized AI responses
        memories = relationship(
            "UserMemory",
@@ -2050,6 +2142,12 @@ else:
            passive_deletes=True,
        )

        vision_llm_configs = relationship(
            "VisionLLMConfig",
            back_populates="user",
            passive_deletes=True,
        )

        # User memories for personalized AI responses
        memories = relationship(
            "UserMemory",
0 surfsense_backend/app/etl_pipeline/__init__.py Normal file
39 surfsense_backend/app/etl_pipeline/constants.py Normal file
@@ -0,0 +1,39 @@
import ssl

import httpx

LLAMACLOUD_MAX_RETRIES = 5
LLAMACLOUD_BASE_DELAY = 10
LLAMACLOUD_MAX_DELAY = 120
LLAMACLOUD_RETRYABLE_EXCEPTIONS = (
    ssl.SSLError,
    httpx.ConnectError,
    httpx.ConnectTimeout,
    httpx.ReadError,
    httpx.ReadTimeout,
    httpx.WriteError,
    httpx.WriteTimeout,
    httpx.RemoteProtocolError,
    httpx.LocalProtocolError,
    ConnectionError,
    ConnectionResetError,
    TimeoutError,
    OSError,
)

UPLOAD_BYTES_PER_SECOND_SLOW = 100 * 1024
MIN_UPLOAD_TIMEOUT = 120
MAX_UPLOAD_TIMEOUT = 1800
BASE_JOB_TIMEOUT = 600
PER_PAGE_JOB_TIMEOUT = 60


def calculate_upload_timeout(file_size_bytes: int) -> float:
    estimated_time = (file_size_bytes / UPLOAD_BYTES_PER_SECOND_SLOW) * 1.5
    return max(MIN_UPLOAD_TIMEOUT, min(estimated_time, MAX_UPLOAD_TIMEOUT))


def calculate_job_timeout(estimated_pages: int, file_size_bytes: int) -> float:
    page_based_timeout = BASE_JOB_TIMEOUT + (estimated_pages * PER_PAGE_JOB_TIMEOUT)
    size_based_timeout = BASE_JOB_TIMEOUT + (file_size_bytes / (10 * 1024 * 1024)) * 60
    return max(page_based_timeout, size_based_timeout)
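To make the timeout scaling concrete (the file is hypothetical): a 50 MB upload gives calculate_upload_timeout(50 * 1024 * 1024) = (52,428,800 / 102,400) * 1.5 = 768 s, already inside the 120–1800 s clamp, while calculate_job_timeout(100, 50 * 1024 * 1024) = max(600 + 100 * 60, 600 + 5 * 60) = 6,600 s, so for long documents the per-page term dominates the size-based floor.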
21 surfsense_backend/app/etl_pipeline/etl_document.py Normal file
@@ -0,0 +1,21 @@
from pydantic import BaseModel, field_validator


class EtlRequest(BaseModel):
    file_path: str
    filename: str
    estimated_pages: int = 0

    @field_validator("filename")
    @classmethod
    def filename_must_not_be_empty(cls, v: str) -> str:
        if not v.strip():
            raise ValueError("filename must not be empty")
        return v


class EtlResult(BaseModel):
    markdown_content: str
    etl_service: str
    actual_pages: int = 0
    content_type: str
125 surfsense_backend/app/etl_pipeline/etl_pipeline_service.py Normal file
@@ -0,0 +1,125 @@
import logging

from app.config import config as app_config
from app.etl_pipeline.etl_document import EtlRequest, EtlResult
from app.etl_pipeline.exceptions import (
    EtlServiceUnavailableError,
    EtlUnsupportedFileError,
)
from app.etl_pipeline.file_classifier import FileCategory, classify_file
from app.etl_pipeline.parsers.audio import transcribe_audio
from app.etl_pipeline.parsers.direct_convert import convert_file_directly
from app.etl_pipeline.parsers.plaintext import read_plaintext


class EtlPipelineService:
    """Single pipeline for extracting markdown from files. All callers use this."""

    async def extract(self, request: EtlRequest) -> EtlResult:
        category = classify_file(request.filename)

        if category == FileCategory.UNSUPPORTED:
            raise EtlUnsupportedFileError(
                f"File type not supported for parsing: {request.filename}"
            )

        if category == FileCategory.PLAINTEXT:
            content = read_plaintext(request.file_path)
            return EtlResult(
                markdown_content=content,
                etl_service="PLAINTEXT",
                content_type="plaintext",
            )

        if category == FileCategory.DIRECT_CONVERT:
            content = convert_file_directly(request.file_path, request.filename)
            return EtlResult(
                markdown_content=content,
                etl_service="DIRECT_CONVERT",
                content_type="direct_convert",
            )

        if category == FileCategory.AUDIO:
            content = await transcribe_audio(request.file_path, request.filename)
            return EtlResult(
                markdown_content=content,
                etl_service="AUDIO",
                content_type="audio",
            )

        return await self._extract_document(request)

    async def _extract_document(self, request: EtlRequest) -> EtlResult:
        from pathlib import PurePosixPath

        from app.utils.file_extensions import get_document_extensions_for_service

        etl_service = app_config.ETL_SERVICE
        if not etl_service:
            raise EtlServiceUnavailableError(
                "No ETL_SERVICE configured. "
                "Set ETL_SERVICE to UNSTRUCTURED, LLAMACLOUD, or DOCLING in your .env"
            )

        ext = PurePosixPath(request.filename).suffix.lower()
        supported = get_document_extensions_for_service(etl_service)
        if ext not in supported:
            raise EtlUnsupportedFileError(
                f"File type {ext} is not supported by {etl_service}"
            )

        if etl_service == "DOCLING":
            from app.etl_pipeline.parsers.docling import parse_with_docling

            content = await parse_with_docling(request.file_path, request.filename)
        elif etl_service == "UNSTRUCTURED":
            from app.etl_pipeline.parsers.unstructured import parse_with_unstructured

            content = await parse_with_unstructured(request.file_path)
        elif etl_service == "LLAMACLOUD":
            content = await self._extract_with_llamacloud(request)
        else:
            raise EtlServiceUnavailableError(f"Unknown ETL_SERVICE: {etl_service}")

        return EtlResult(
            markdown_content=content,
            etl_service=etl_service,
            content_type="document",
        )

    async def _extract_with_llamacloud(self, request: EtlRequest) -> str:
        """Try Azure Document Intelligence first (when configured) then LlamaCloud.

        Azure DI is an internal accelerator: cheaper and faster for its supported
        file types. If it is not configured, or the file extension is not in
        Azure DI's supported set, LlamaCloud is used directly. If Azure DI
        fails for any reason, LlamaCloud is used as a fallback.
        """
        from pathlib import PurePosixPath

        from app.utils.file_extensions import AZURE_DI_DOCUMENT_EXTENSIONS

        ext = PurePosixPath(request.filename).suffix.lower()
        azure_configured = bool(
            getattr(app_config, "AZURE_DI_ENDPOINT", None)
            and getattr(app_config, "AZURE_DI_KEY", None)
        )

        if azure_configured and ext in AZURE_DI_DOCUMENT_EXTENSIONS:
            try:
                from app.etl_pipeline.parsers.azure_doc_intelligence import (
                    parse_with_azure_doc_intelligence,
                )

                return await parse_with_azure_doc_intelligence(request.file_path)
            except Exception:
                logging.warning(
                    "Azure Document Intelligence failed for %s, "
                    "falling back to LlamaCloud",
                    request.filename,
                    exc_info=True,
                )

        from app.etl_pipeline.parsers.llamacloud import parse_with_llamacloud

        return await parse_with_llamacloud(request.file_path, request.estimated_pages)
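A minimal caller sketch for the service above (the path and page count are made up; real callers live in the document-processing tasks):

    import asyncio

    from app.etl_pipeline.etl_document import EtlRequest
    from app.etl_pipeline.etl_pipeline_service import EtlPipelineService

    async def demo() -> None:
        request = EtlRequest(
            file_path="/tmp/report.pdf",  # hypothetical upload location
            filename="report.pdf",
            estimated_pages=12,
        )
        result = await EtlPipelineService().extract(request)
        # etl_service records which parser actually produced the markdown
        print(result.etl_service, len(result.markdown_content))

    asyncio.run(demo())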
10 surfsense_backend/app/etl_pipeline/exceptions.py Normal file
@@ -0,0 +1,10 @@
class EtlParseError(Exception):
    """Raised when an ETL parser fails to produce content."""


class EtlServiceUnavailableError(Exception):
    """Raised when the configured ETL_SERVICE is not recognised."""


class EtlUnsupportedFileError(Exception):
    """Raised when a file type cannot be parsed by any ETL pipeline."""
137 surfsense_backend/app/etl_pipeline/file_classifier.py Normal file
@@ -0,0 +1,137 @@
from enum import Enum
from pathlib import PurePosixPath

from app.utils.file_extensions import (
    DOCUMENT_EXTENSIONS,
    get_document_extensions_for_service,
)

PLAINTEXT_EXTENSIONS = frozenset(
    {
        ".md", ".markdown", ".txt", ".text", ".json", ".jsonl", ".yaml", ".yml",
        ".toml", ".ini", ".cfg", ".conf", ".xml", ".css", ".scss", ".less",
        ".sass", ".py", ".pyw", ".pyi", ".pyx", ".js", ".jsx", ".ts",
        ".tsx", ".mjs", ".cjs", ".java", ".kt", ".kts", ".scala", ".groovy",
        ".c", ".h", ".cpp", ".cxx", ".cc", ".hpp", ".hxx", ".cs",
        ".fs", ".fsx", ".go", ".rs", ".rb", ".php", ".pl", ".pm",
        ".lua", ".swift", ".m", ".mm", ".r", ".jl", ".sh", ".bash",
        ".zsh", ".fish", ".bat", ".cmd", ".ps1", ".sql", ".graphql", ".gql",
        ".env", ".gitignore", ".dockerignore", ".editorconfig", ".makefile", ".cmake",
        ".log", ".rst", ".tex", ".bib", ".org", ".adoc", ".asciidoc",
        ".vue", ".svelte", ".astro", ".tf", ".hcl", ".proto",
    }
)

AUDIO_EXTENSIONS = frozenset(
    {".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm"}
)

DIRECT_CONVERT_EXTENSIONS = frozenset({".csv", ".tsv", ".html", ".htm", ".xhtml"})


class FileCategory(Enum):
    PLAINTEXT = "plaintext"
    AUDIO = "audio"
    DIRECT_CONVERT = "direct_convert"
    UNSUPPORTED = "unsupported"
    DOCUMENT = "document"


def classify_file(filename: str) -> FileCategory:
    suffix = PurePosixPath(filename).suffix.lower()
    if suffix in PLAINTEXT_EXTENSIONS:
        return FileCategory.PLAINTEXT
    if suffix in AUDIO_EXTENSIONS:
        return FileCategory.AUDIO
    if suffix in DIRECT_CONVERT_EXTENSIONS:
        return FileCategory.DIRECT_CONVERT
    if suffix in DOCUMENT_EXTENSIONS:
        return FileCategory.DOCUMENT
    return FileCategory.UNSUPPORTED


def should_skip_for_service(filename: str, etl_service: str | None) -> bool:
    """Return True if *filename* cannot be processed by *etl_service*.

    Plaintext, audio, and direct-convert files are parser-agnostic and never
    skipped. Document files are checked against the per-parser extension set.
    """
    category = classify_file(filename)
    if category == FileCategory.UNSUPPORTED:
        return True
    if category == FileCategory.DOCUMENT:
        suffix = PurePosixPath(filename).suffix.lower()
        return suffix not in get_document_extensions_for_service(etl_service)
    return False
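A quick behavioural sketch of the classifier (filenames invented for illustration; .pdf is assumed to be in DOCUMENT_EXTENSIONS):

    from app.etl_pipeline.file_classifier import classify_file, should_skip_for_service

    classify_file("notes.md")       # FileCategory.PLAINTEXT
    classify_file("meeting.m4a")    # FileCategory.AUDIO
    classify_file("table.csv")      # FileCategory.DIRECT_CONVERT
    classify_file("scan.pdf")       # FileCategory.DOCUMENT
    should_skip_for_service("archive.zip", "DOCLING")  # True, assuming .zip is in no extension set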
0 surfsense_backend/app/etl_pipeline/parsers/__init__.py Normal file
34 surfsense_backend/app/etl_pipeline/parsers/audio.py Normal file
@@ -0,0 +1,34 @@
from litellm import atranscription

from app.config import config as app_config


async def transcribe_audio(file_path: str, filename: str) -> str:
    stt_service_type = (
        "local"
        if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
        else "external"
    )

    if stt_service_type == "local":
        from app.services.stt_service import stt_service

        result = stt_service.transcribe_file(file_path)
        text = result.get("text", "")
        if not text:
            raise ValueError("Transcription returned empty text")
    else:
        with open(file_path, "rb") as audio_file:
            kwargs: dict = {
                "model": app_config.STT_SERVICE,
                "file": audio_file,
                "api_key": app_config.STT_SERVICE_API_KEY,
            }
            if app_config.STT_SERVICE_API_BASE:
                kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
            response = await atranscription(**kwargs)
            text = response.get("text", "")
            if not text:
                raise ValueError("Transcription returned empty text")

    return f"# Transcription of {filename}\n\n{text}"
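The local/ prefix is what routes transcription in-process: a value like local/base (the model name here is only an illustrative assumption) uses the bundled stt_service, while any other value, for example an external provider's whisper-1 (again an assumed example), is sent through litellm.atranscription with the configured API key and optional API base.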
93 surfsense_backend/app/etl_pipeline/parsers/azure_doc_intelligence.py Normal file
@@ -0,0 +1,93 @@
import asyncio
import logging
import os
import random

from app.config import config as app_config

MAX_RETRIES = 5
BASE_DELAY = 10
MAX_DELAY = 120


async def parse_with_azure_doc_intelligence(file_path: str) -> str:
    from azure.ai.documentintelligence.aio import DocumentIntelligenceClient
    from azure.ai.documentintelligence.models import DocumentContentFormat
    from azure.core.credentials import AzureKeyCredential
    from azure.core.exceptions import (
        ClientAuthenticationError,
        HttpResponseError,
        ServiceRequestError,
        ServiceResponseError,
    )

    file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
    retryable_exceptions = (ServiceRequestError, ServiceResponseError)

    last_exception = None
    attempt_errors: list[str] = []

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            client = DocumentIntelligenceClient(
                endpoint=app_config.AZURE_DI_ENDPOINT,
                credential=AzureKeyCredential(app_config.AZURE_DI_KEY),
            )
            async with client:
                with open(file_path, "rb") as f:
                    poller = await client.begin_analyze_document(
                        "prebuilt-read",
                        body=f,
                        output_content_format=DocumentContentFormat.MARKDOWN,
                    )
                    result = await poller.result()

            if attempt > 1:
                logging.info(
                    f"Azure Document Intelligence succeeded on attempt {attempt} "
                    f"after {len(attempt_errors)} failures"
                )

            if not result.content:
                return ""

            return result.content

        except ClientAuthenticationError:
            raise
        except HttpResponseError as e:
            if e.status_code and 400 <= e.status_code < 500:
                raise
            last_exception = e
            error_type = type(e).__name__
            error_msg = str(e)[:200]
            attempt_errors.append(f"Attempt {attempt}: {error_type} - {error_msg}")
        except retryable_exceptions as e:
            last_exception = e
            error_type = type(e).__name__
            error_msg = str(e)[:200]
            attempt_errors.append(f"Attempt {attempt}: {error_type} - {error_msg}")

        if attempt < MAX_RETRIES:
            base_delay = min(BASE_DELAY * (2 ** (attempt - 1)), MAX_DELAY)
            jitter = base_delay * 0.25 * (2 * random.random() - 1)
            delay = base_delay + jitter

            logging.warning(
                f"Azure Document Intelligence failed "
                f"(attempt {attempt}/{MAX_RETRIES}): "
                f"{attempt_errors[-1]}. File: {file_size_mb:.1f}MB. "
                f"Retrying in {delay:.0f}s..."
            )
            await asyncio.sleep(delay)
        else:
            logging.error(
                f"Azure Document Intelligence failed after {MAX_RETRIES} "
                f"attempts. File size: {file_size_mb:.1f}MB. "
                f"Errors: {'; '.join(attempt_errors)}"
            )

    raise last_exception or RuntimeError(
        f"Azure Document Intelligence parsing failed after {MAX_RETRIES} retries. "
        f"File size: {file_size_mb:.1f}MB"
    )
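With BASE_DELAY = 10 and MAX_DELAY = 120, the waits between the five attempts grow roughly as 10 s, 20 s, 40 s and 80 s, each perturbed by up to ±25 % of jitter so concurrent workers do not retry in lockstep; the LlamaCloud parser below applies the same schedule through the LLAMACLOUD_* constants.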
3 surfsense_backend/app/etl_pipeline/parsers/direct_convert.py Normal file
@@ -0,0 +1,3 @@
from app.tasks.document_processors._direct_converters import convert_file_directly

__all__ = ["convert_file_directly"]
26 surfsense_backend/app/etl_pipeline/parsers/docling.py Normal file
@@ -0,0 +1,26 @@
import warnings
from logging import ERROR, getLogger


async def parse_with_docling(file_path: str, filename: str) -> str:
    from app.services.docling_service import create_docling_service

    docling_service = create_docling_service()

    pdfminer_logger = getLogger("pdfminer")
    original_level = pdfminer_logger.level

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning, module="pdfminer")
        warnings.filterwarnings(
            "ignore", message=".*Cannot set gray non-stroke color.*"
        )
        warnings.filterwarnings("ignore", message=".*invalid float value.*")
        pdfminer_logger.setLevel(ERROR)

        try:
            result = await docling_service.process_document(file_path, filename)
        finally:
            pdfminer_logger.setLevel(original_level)

    return result["content"]
123 surfsense_backend/app/etl_pipeline/parsers/llamacloud.py Normal file
@@ -0,0 +1,123 @@
import asyncio
import logging
import os
import random

import httpx

from app.config import config as app_config
from app.etl_pipeline.constants import (
    LLAMACLOUD_BASE_DELAY,
    LLAMACLOUD_MAX_DELAY,
    LLAMACLOUD_MAX_RETRIES,
    LLAMACLOUD_RETRYABLE_EXCEPTIONS,
    PER_PAGE_JOB_TIMEOUT,
    calculate_job_timeout,
    calculate_upload_timeout,
)


async def parse_with_llamacloud(file_path: str, estimated_pages: int) -> str:
    from llama_cloud_services import LlamaParse
    from llama_cloud_services.parse.utils import ResultType

    file_size_bytes = os.path.getsize(file_path)
    file_size_mb = file_size_bytes / (1024 * 1024)

    upload_timeout = calculate_upload_timeout(file_size_bytes)
    job_timeout = calculate_job_timeout(estimated_pages, file_size_bytes)

    custom_timeout = httpx.Timeout(
        connect=120.0,
        read=upload_timeout,
        write=upload_timeout,
        pool=120.0,
    )

    logging.info(
        f"LlamaCloud upload configured: file_size={file_size_mb:.1f}MB, "
        f"pages={estimated_pages}, upload_timeout={upload_timeout:.0f}s, "
        f"job_timeout={job_timeout:.0f}s"
    )

    last_exception = None
    attempt_errors: list[str] = []

    for attempt in range(1, LLAMACLOUD_MAX_RETRIES + 1):
        try:
            async with httpx.AsyncClient(timeout=custom_timeout) as custom_client:
                parser = LlamaParse(
                    api_key=app_config.LLAMA_CLOUD_API_KEY,
                    num_workers=1,
                    verbose=True,
                    language="en",
                    result_type=ResultType.MD,
                    max_timeout=int(max(2000, job_timeout + upload_timeout)),
                    job_timeout_in_seconds=job_timeout,
                    job_timeout_extra_time_per_page_in_seconds=PER_PAGE_JOB_TIMEOUT,
                    custom_client=custom_client,
                )
                result = await parser.aparse(file_path)

            if attempt > 1:
                logging.info(
                    f"LlamaCloud upload succeeded on attempt {attempt} after "
                    f"{len(attempt_errors)} failures"
                )

            if hasattr(result, "get_markdown_documents"):
                markdown_docs = result.get_markdown_documents(split_by_page=False)
                if markdown_docs and hasattr(markdown_docs[0], "text"):
                    return markdown_docs[0].text
                if hasattr(result, "pages") and result.pages:
                    return "\n\n".join(
                        p.md for p in result.pages if hasattr(p, "md") and p.md
                    )
                return str(result)

            if isinstance(result, list):
                if result and hasattr(result[0], "text"):
                    return result[0].text
                return "\n\n".join(
                    doc.page_content if hasattr(doc, "page_content") else str(doc)
                    for doc in result
                )

            return str(result)

        except LLAMACLOUD_RETRYABLE_EXCEPTIONS as e:
            last_exception = e
            error_type = type(e).__name__
            error_msg = str(e)[:200]
            attempt_errors.append(f"Attempt {attempt}: {error_type} - {error_msg}")

            if attempt < LLAMACLOUD_MAX_RETRIES:
                base_delay = min(
                    LLAMACLOUD_BASE_DELAY * (2 ** (attempt - 1)),
                    LLAMACLOUD_MAX_DELAY,
                )
                jitter = base_delay * 0.25 * (2 * random.random() - 1)
                delay = base_delay + jitter

                logging.warning(
                    f"LlamaCloud upload failed "
                    f"(attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}): "
                    f"{error_type}. File: {file_size_mb:.1f}MB. "
                    f"Retrying in {delay:.0f}s..."
                )
                await asyncio.sleep(delay)
            else:
                logging.error(
                    f"LlamaCloud upload failed after {LLAMACLOUD_MAX_RETRIES} "
                    f"attempts. File size: {file_size_mb:.1f}MB, "
                    f"Pages: {estimated_pages}. "
                    f"Errors: {'; '.join(attempt_errors)}"
                )

        except Exception:
            raise

    raise last_exception or RuntimeError(
        f"LlamaCloud parsing failed after {LLAMACLOUD_MAX_RETRIES} retries. "
        f"File size: {file_size_mb:.1f}MB"
    )
8 surfsense_backend/app/etl_pipeline/parsers/plaintext.py Normal file
@@ -0,0 +1,8 @@
def read_plaintext(file_path: str) -> str:
    with open(file_path, encoding="utf-8", errors="replace") as f:
        content = f.read()
    if "\x00" in content:
        raise ValueError(
            f"File contains null bytes — likely a binary file opened as text: {file_path}"
        )
    return content
14 surfsense_backend/app/etl_pipeline/parsers/unstructured.py Normal file
@@ -0,0 +1,14 @@
async def parse_with_unstructured(file_path: str) -> str:
    from langchain_unstructured import UnstructuredLoader

    loader = UnstructuredLoader(
        file_path,
        mode="elements",
        post_processors=[],
        languages=["eng"],
        include_orig_elements=False,
        include_metadata=False,
        strategy="auto",
    )
    docs = await loader.aload()
    return "\n\n".join(doc.page_content for doc in docs if doc.page_content)
@@ -59,7 +59,7 @@ class PipelineMessages:

    LLM_AUTH = "LLM authentication failed. Check your API key."
    LLM_PERMISSION = "LLM request denied. Check your account permissions."
-   LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
+   LLM_NOT_FOUND = "Model not found. Check your model configuration."
    LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
    LLM_UNPROCESSABLE = (
        "Document exceeds the LLM context window even after optimization."
@@ -67,7 +67,7 @@ class PipelineMessages:
    LLM_RESPONSE = "LLM returned an invalid response."
    LLM_AUTH = "LLM authentication failed. Check your API key."
    LLM_PERMISSION = "LLM request denied. Check your account permissions."
-   LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
+   LLM_NOT_FOUND = "Model not found. Check your model configuration."
    LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
    LLM_UNPROCESSABLE = (
        "Document exceeds the LLM context window even after optimization."
@@ -273,17 +273,18 @@ class IndexingPipelineService:
                    continue

                dup_check = await self.session.execute(
-                   select(Document.id).filter(
+                   select(Document.id, Document.title).filter(
                        Document.content_hash == content_hash,
                        Document.id != existing.id,
                    )
                )
-               if dup_check.scalars().first() is not None:
+               dup_row = dup_check.first()
+               if dup_row is not None:
                    if not DocumentStatus.is_state(
                        existing.status, DocumentStatus.READY
                    ):
                        existing.status = DocumentStatus.failed(
-                           "Duplicate content — already indexed by another document"
+                           f"Duplicate content: matches '{dup_row.title}'"
                        )
                        continue
@@ -3,6 +3,7 @@ from fastapi import APIRouter
from .airtable_add_connector_route import (
    router as airtable_add_connector_router,
)
+from .autocomplete_routes import router as autocomplete_router
from .chat_comments_routes import router as chat_comments_router
from .circleback_webhook_route import router as circleback_webhook_router
from .clickup_add_connector_route import router as clickup_add_connector_router
@@ -48,6 +49,7 @@ from .stripe_routes import router as stripe_router
from .surfsense_docs_routes import router as surfsense_docs_router
from .teams_add_connector_route import router as teams_add_connector_router
from .video_presentations_routes import router as video_presentations_router
+from .vision_llm_routes import router as vision_llm_router
from .youtube_routes import router as youtube_router

router = APIRouter()
@@ -67,6 +69,7 @@ router.include_router(
)  # Video presentation status and streaming
router.include_router(reports_router)  # Report CRUD and multi-format export
router.include_router(image_generation_router)  # Image generation via litellm
+router.include_router(vision_llm_router)  # Vision LLM configs for screenshot analysis
router.include_router(search_source_connectors_router)
router.include_router(google_calendar_add_connector_router)
router.include_router(google_gmail_add_connector_router)
@@ -84,7 +87,7 @@ router.include_router(confluence_add_connector_router)
router.include_router(clickup_add_connector_router)
router.include_router(dropbox_add_connector_router)
router.include_router(new_llm_config_router)  # LLM configs with prompt configuration
-router.include_router(model_list_router)  # Dynamic LLM model catalogue from OpenRouter
+router.include_router(model_list_router)  # Dynamic model catalogue from OpenRouter
router.include_router(logs_router)
router.include_router(circleback_webhook_router)  # Circleback meeting webhooks
router.include_router(surfsense_docs_router)  # Surfsense documentation for citations
@@ -95,3 +98,4 @@ router.include_router(incentive_tasks_router)  # Incentive tasks for earning fre
router.include_router(stripe_router)  # Stripe checkout for additional page packs
router.include_router(youtube_router)  # YouTube playlist resolution
router.include_router(prompts_router)
+router.include_router(autocomplete_router)  # Lightweight autocomplete with KB context
@ -1,7 +1,5 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
|
|
@ -26,7 +24,11 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import (
|
||||
OAuthStateManager,
|
||||
TokenEncryption,
|
||||
generate_pkce_pair,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -75,28 +77,6 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str:
|
|||
return f"Basic {b64}"
|
||||
|
||||
|
||||
def generate_pkce_pair() -> tuple[str, str]:
|
||||
"""
|
||||
Generate PKCE code verifier and code challenge.
|
||||
|
||||
Returns:
|
||||
Tuple of (code_verifier, code_challenge)
|
||||
"""
|
||||
# Generate code verifier (43-128 characters)
|
||||
code_verifier = (
|
||||
base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=")
|
||||
)
|
||||
|
||||
# Generate code challenge (SHA256 hash of verifier, base64url encoded)
|
||||
code_challenge = (
|
||||
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest())
|
||||
.decode("utf-8")
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
@router.get("/auth/airtable/connector/add")
|
||||
async def connect_airtable(space_id: int, user: User = Depends(current_active_user)):
|
||||
"""
|
||||
|
|
|
|||
45
surfsense_backend/app/routes/autocomplete_routes.py
Normal file
45
surfsense_backend/app/routes/autocomplete_routes.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, get_async_session
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.services.vision_autocomplete_service import stream_vision_autocomplete
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
router = APIRouter(prefix="/autocomplete", tags=["autocomplete"])
|
||||
|
||||
MAX_SCREENSHOT_SIZE = 20 * 1024 * 1024 # 20 MB base64 ceiling
|
||||
|
||||
|
||||
class VisionAutocompleteRequest(BaseModel):
|
||||
screenshot: str = Field(..., max_length=MAX_SCREENSHOT_SIZE)
|
||||
search_space_id: int
|
||||
app_name: str = ""
|
||||
window_title: str = ""
|
||||
|
||||
|
||||
@router.post("/vision/stream")
|
||||
async def vision_autocomplete_stream(
|
||||
body: VisionAutocompleteRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
await check_search_space_access(session, user, body.search_space_id)
|
||||
|
||||
return StreamingResponse(
|
||||
stream_vision_autocomplete(
|
||||
body.screenshot,
|
||||
body.search_space_id,
|
||||
session,
|
||||
app_name=body.app_name,
|
||||
window_title=body.window_title,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
**VercelStreamingService.get_response_headers(),
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
|
@ -2,6 +2,7 @@
|
|||
import asyncio
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Query, UploadFile
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
|
@ -10,6 +11,8 @@ from app.db import (
|
|||
Chunk,
|
||||
Document,
|
||||
DocumentType,
|
||||
DocumentVersion,
|
||||
Folder,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
|
|
@ -27,6 +30,7 @@ from app.schemas import (
|
|||
DocumentTitleSearchResponse,
|
||||
DocumentUpdate,
|
||||
DocumentWithChunksRead,
|
||||
FolderRead,
|
||||
PaginatedResponse,
|
||||
)
|
||||
from app.services.task_dispatcher import TaskDispatcher, get_task_dispatcher
|
||||
|
|
@ -957,6 +961,39 @@ async def get_document_by_chunk_id(
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/documents/watched-folders", response_model=list[FolderRead])
|
||||
async def get_watched_folders(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Return root folders that are marked as watched (metadata->>'watched' = 'true')."""
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
)
|
||||
|
||||
folders = (
|
||||
(
|
||||
await session.execute(
|
||||
select(Folder).where(
|
||||
Folder.search_space_id == search_space_id,
|
||||
Folder.parent_id.is_(None),
|
||||
Folder.folder_metadata.isnot(None),
|
||||
Folder.folder_metadata["watched"].astext == "true",
|
||||
)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
return folders
|
||||
|
||||
|
||||
@router.get(
|
||||
"/documents/{document_id}/chunks",
|
||||
response_model=PaginatedResponse[ChunkRead],
|
||||
|
|
@ -1212,3 +1249,297 @@ async def delete_document(
|
|||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete document: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# ====================================================================
|
||||
# Version History Endpoints
|
||||
# ====================================================================
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}/versions")
|
||||
async def list_document_versions(
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""List all versions for a document, ordered by version_number descending."""
|
||||
document = (
|
||||
await session.execute(select(Document).where(Document.id == document_id))
|
||||
).scalar_one_or_none()
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
await check_permission(
|
||||
session, user, document.search_space_id, Permission.DOCUMENTS_READ.value
|
||||
)
|
||||
|
||||
versions = (
|
||||
(
|
||||
await session.execute(
|
||||
select(DocumentVersion)
|
||||
.where(DocumentVersion.document_id == document_id)
|
||||
.order_by(DocumentVersion.version_number.desc())
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"version_number": v.version_number,
|
||||
"title": v.title,
|
||||
"content_hash": v.content_hash,
|
||||
"created_at": v.created_at.isoformat() if v.created_at else None,
|
||||
}
|
||||
for v in versions
|
||||
]
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}/versions/{version_number}")
|
||||
async def get_document_version(
|
||||
document_id: int,
|
||||
version_number: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get full version content including source_markdown."""
|
||||
document = (
|
||||
await session.execute(select(Document).where(Document.id == document_id))
|
||||
).scalar_one_or_none()
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
await check_permission(
|
||||
session, user, document.search_space_id, Permission.DOCUMENTS_READ.value
|
||||
)
|
||||
|
||||
version = (
|
||||
await session.execute(
|
||||
select(DocumentVersion).where(
|
||||
DocumentVersion.document_id == document_id,
|
||||
DocumentVersion.version_number == version_number,
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if not version:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
return {
|
||||
"version_number": version.version_number,
|
||||
"title": version.title,
|
||||
"content_hash": version.content_hash,
|
||||
"source_markdown": version.source_markdown,
|
||||
"created_at": version.created_at.isoformat() if version.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/documents/{document_id}/versions/{version_number}/restore")
|
||||
async def restore_document_version(
|
||||
document_id: int,
|
||||
version_number: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Restore a previous version: snapshot current state, then overwrite document content."""
|
||||
document = (
|
||||
await session.execute(select(Document).where(Document.id == document_id))
|
||||
).scalar_one_or_none()
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
await check_permission(
|
||||
session, user, document.search_space_id, Permission.DOCUMENTS_UPDATE.value
|
||||
)
|
||||
|
||||
version = (
|
||||
await session.execute(
|
||||
select(DocumentVersion).where(
|
||||
DocumentVersion.document_id == document_id,
|
||||
DocumentVersion.version_number == version_number,
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if not version:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
# Snapshot current state before restoring
|
||||
from app.utils.document_versioning import create_version_snapshot
|
||||
|
||||
await create_version_snapshot(session, document)
|
||||
|
||||
# Restore the version's content onto the document
|
||||
document.source_markdown = version.source_markdown
|
||||
document.title = version.title or document.title
|
||||
document.content_needs_reindexing = True
|
||||
await session.commit()
|
||||
|
||||
from app.tasks.celery_tasks.document_reindex_tasks import reindex_document_task
|
||||
|
||||
reindex_document_task.delay(document_id, str(user.id))
|
||||
|
||||
return {
|
||||
"message": f"Restored version {version_number}",
|
||||
"document_id": document_id,
|
||||
"restored_version": version_number,
|
||||
}
|
||||
|
||||
|
||||
# ===== Local folder indexing endpoints =====
|
||||
|
||||
|
||||
class FolderIndexRequest(PydanticBaseModel):
|
||||
folder_path: str
|
||||
folder_name: str
|
||||
search_space_id: int
|
||||
exclude_patterns: list[str] | None = None
|
||||
file_extensions: list[str] | None = None
|
||||
root_folder_id: int | None = None
|
||||
enable_summary: bool = False
|
||||
|
||||
|
||||
class FolderIndexFilesRequest(PydanticBaseModel):
|
||||
folder_path: str
|
||||
folder_name: str
|
||||
search_space_id: int
|
||||
target_file_paths: list[str]
|
||||
root_folder_id: int | None = None
|
||||
enable_summary: bool = False
|
||||
|
||||
|
||||
@router.post("/documents/folder-index")
|
||||
async def folder_index(
|
||||
request: FolderIndexRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Full-scan index of a local folder. Creates the root Folder row synchronously
|
||||
and dispatches the heavy indexing work to a Celery task.
|
||||
Returns the root_folder_id so the desktop can persist it.
|
||||
"""
|
||||
from app.config import config as app_config
|
||||
|
||||
if not app_config.is_self_hosted():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Local folder indexing is only available in self-hosted mode",
|
||||
)
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
request.search_space_id,
|
||||
Permission.DOCUMENTS_CREATE.value,
|
||||
"You don't have permission to create documents in this search space",
|
||||
)
|
||||
|
||||
watched_metadata = {
|
||||
"watched": True,
|
||||
"folder_path": request.folder_path,
|
||||
"exclude_patterns": request.exclude_patterns,
|
||||
"file_extensions": request.file_extensions,
|
||||
}
|
||||
|
||||
root_folder_id = request.root_folder_id
|
||||
if root_folder_id:
|
||||
existing = (
|
||||
await session.execute(select(Folder).where(Folder.id == root_folder_id))
|
||||
).scalar_one_or_none()
|
||||
if not existing:
|
||||
root_folder_id = None
|
||||
else:
|
||||
existing.folder_metadata = watched_metadata
|
||||
await session.commit()
|
||||
|
||||
if not root_folder_id:
|
||||
root_folder = Folder(
|
||||
name=request.folder_name,
|
||||
search_space_id=request.search_space_id,
|
||||
created_by_id=str(user.id),
|
||||
position="a0",
|
||||
folder_metadata=watched_metadata,
|
||||
)
|
||||
session.add(root_folder)
|
||||
await session.flush()
|
||||
root_folder_id = root_folder.id
|
||||
await session.commit()
|
||||
|
||||
from app.tasks.celery_tasks.document_tasks import index_local_folder_task
|
||||
|
||||
index_local_folder_task.delay(
|
||||
search_space_id=request.search_space_id,
|
||||
user_id=str(user.id),
|
||||
folder_path=request.folder_path,
|
||||
folder_name=request.folder_name,
|
||||
exclude_patterns=request.exclude_patterns,
|
||||
file_extensions=request.file_extensions,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=request.enable_summary,
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Folder indexing started",
|
||||
"status": "processing",
|
||||
"root_folder_id": root_folder_id,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/documents/folder-index-files")
|
||||
async def folder_index_files(
|
||||
request: FolderIndexFilesRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Index multiple files within a watched folder (batched chokidar trigger).
|
||||
Validates that all target_file_paths are under folder_path.
|
||||
Dispatches a single Celery task that processes them in parallel.
|
||||
"""
|
||||
from app.config import config as app_config
|
||||
|
||||
if not app_config.is_self_hosted():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Local folder indexing is only available in self-hosted mode",
|
||||
)
|
||||
|
||||
if not request.target_file_paths:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="target_file_paths must not be empty"
|
||||
)
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
request.search_space_id,
|
||||
Permission.DOCUMENTS_CREATE.value,
|
||||
"You don't have permission to create documents in this search space",
|
||||
)
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
for fp in request.target_file_paths:
|
||||
try:
|
||||
Path(fp).relative_to(request.folder_path)
|
||||
except ValueError as err:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"target_file_path {fp} must be inside folder_path",
|
||||
) from err
|
||||
|
||||
from app.tasks.celery_tasks.document_tasks import index_local_folder_task
|
||||
|
||||
index_local_folder_task.delay(
|
||||
search_space_id=request.search_space_id,
|
||||
user_id=str(user.id),
|
||||
folder_path=request.folder_path,
|
||||
folder_name=request.folder_name,
|
||||
target_file_paths=request.target_file_paths,
|
||||
root_folder_id=request.root_folder_id,
|
||||
enable_summary=request.enable_summary,
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Batch indexing started for {len(request.target_file_paths)} file(s)",
|
||||
"status": "processing",
|
||||
"file_count": len(request.target_file_paths),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -311,9 +311,11 @@ async def dropbox_callback(
|
|||
)
|
||||
|
||||
existing_cursor = db_connector.config.get("cursor")
|
||||
existing_folder_cursors = db_connector.config.get("folder_cursors")
|
||||
db_connector.config = {
|
||||
**connector_config,
|
||||
"cursor": existing_cursor,
|
||||
"folder_cursors": existing_folder_cursors,
|
||||
"auth_expired": False,
|
||||
}
|
||||
flag_modified(db_connector, "config")
|
||||
|
|
|
|||
|
|
@ -128,9 +128,20 @@ async def get_editor_content(
|
|||
chunk_contents = chunk_contents_result.scalars().all()
|
||||
|
||||
if not chunk_contents:
|
||||
doc_status = document.status or {}
|
||||
state = (
|
||||
doc_status.get("state", "ready")
|
||||
if isinstance(doc_status, dict)
|
||||
else "ready"
|
||||
)
|
||||
if state in ("pending", "processing"):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="This document is still being processed. Please wait a moment and try again.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="This document has no content and cannot be edited. Please re-upload to enable editing.",
|
||||
detail="This document has no viewable content yet. It may still be syncing. Try again in a few seconds, or re-upload if the issue persists.",
|
||||
)
|
||||
|
||||
markdown_content = "\n\n".join(chunk_contents)
|
||||
|
|
@ -138,7 +149,7 @@ async def get_editor_content(
|
|||
if not markdown_content.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="This document has empty content and cannot be edited.",
|
||||
detail="This document appears to be empty. Try re-uploading or editing it to add content.",
|
||||
)
|
||||
|
||||
document.source_markdown = markdown_content
|
||||
|
|
|
|||
|
|
@ -192,6 +192,33 @@ async def get_folder_breadcrumb(
|
|||
) from e
|
||||
|
||||
|
||||
@router.patch("/folders/{folder_id}/watched")
|
||||
async def stop_watching_folder(
|
||||
folder_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Clear the watched flag from a folder's metadata."""
|
||||
folder = await session.get(Folder, folder_id)
|
||||
if not folder:
|
||||
raise HTTPException(status_code=404, detail="Folder not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
folder.search_space_id,
|
||||
Permission.DOCUMENTS_UPDATE.value,
|
||||
"You don't have permission to update folders in this search space",
|
||||
)
|
||||
|
||||
if folder.folder_metadata and isinstance(folder.folder_metadata, dict):
|
||||
updated = {**folder.folder_metadata, "watched": False}
|
||||
folder.folder_metadata = updated
|
||||
await session.commit()
|
||||
|
||||
return {"message": "Folder watch status updated"}
|
||||
|
||||
|
||||
@router.put("/folders/{folder_id}", response_model=FolderRead)
|
||||
async def update_folder(
|
||||
folder_id: int,
|
||||
|
|
@ -340,7 +367,7 @@ async def delete_folder(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Delete a folder and cascade-delete subfolders. Documents are async-deleted via Celery."""
|
||||
"""Mark documents for deletion and dispatch Celery to delete docs first, then folders."""
|
||||
try:
|
||||
folder = await session.get(Folder, folder_id)
|
||||
if not folder:
|
||||
|
|
@ -372,30 +399,29 @@ async def delete_folder(
|
|||
)
|
||||
await session.commit()
|
||||
|
||||
await session.execute(Folder.__table__.delete().where(Folder.id == folder_id))
|
||||
await session.commit()
|
||||
try:
|
||||
from app.tasks.celery_tasks.document_tasks import (
|
||||
delete_folder_documents_task,
|
||||
)
|
||||
|
||||
if document_ids:
|
||||
try:
|
||||
from app.tasks.celery_tasks.document_tasks import (
|
||||
delete_folder_documents_task,
|
||||
)
|
||||
|
||||
delete_folder_documents_task.delay(document_ids)
|
||||
except Exception as err:
|
||||
delete_folder_documents_task.delay(
|
||||
document_ids, folder_subtree_ids=list(subtree_ids)
|
||||
)
|
||||
except Exception as err:
|
||||
if document_ids:
|
||||
await session.execute(
|
||||
Document.__table__.update()
|
||||
.where(Document.id.in_(document_ids))
|
||||
.values(status={"state": "ready"})
|
||||
)
|
||||
await session.commit()
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Folder deleted but document cleanup could not be queued. Documents have been restored.",
|
||||
) from err
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Could not queue folder deletion. Documents have been restored.",
|
||||
) from err
|
||||
|
||||
return {
|
||||
"message": "Folder deleted successfully",
|
||||
"message": "Folder deletion started",
|
||||
"documents_queued_for_deletion": len(document_ids),
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,11 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import (
|
||||
OAuthStateManager,
|
||||
TokenEncryption,
|
||||
generate_code_verifier,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -96,9 +100,14 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Generate secure state parameter with HMAC signature
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id, user.id, code_verifier=code_verifier
|
||||
)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
|
|
@ -146,8 +155,11 @@ async def reauth_calendar(
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
|
@ -225,6 +237,7 @@ async def calendar_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
code_verifier = data.get("code_verifier")
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.GOOGLE_CALENDAR_REDIRECT_URI:
|
||||
|
|
@ -233,6 +246,7 @@ async def calendar_callback(
|
|||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
flow.code_verifier = code_verifier
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
creds = flow.credentials
|
||||
|
|
|
|||
|
|
@ -41,7 +41,11 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import (
|
||||
OAuthStateManager,
|
||||
TokenEncryption,
|
||||
generate_code_verifier,
|
||||
)
|
||||
|
||||
# Relax token scope validation for Google OAuth
|
||||
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
|
||||
|
|
@ -127,14 +131,19 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Generate secure state parameter with HMAC signature
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id, user.id, code_verifier=code_verifier
|
||||
)
|
||||
|
||||
# Generate authorization URL
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline", # Get refresh token
|
||||
prompt="consent", # Force consent screen to get refresh token
|
||||
access_type="offline",
|
||||
prompt="consent",
|
||||
include_granted_scopes="true",
|
||||
state=state_encoded,
|
||||
)
|
||||
|
|
@ -193,8 +202,11 @@ async def reauth_drive(
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
|
@ -285,6 +297,7 @@ async def drive_callback(
|
|||
space_id = data["space_id"]
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
code_verifier = data.get("code_verifier")
|
||||
|
||||
logger.info(
|
||||
f"Processing Google Drive callback for user {user_id}, space {space_id}"
|
||||
|
|
@ -296,8 +309,9 @@ async def drive_callback(
|
|||
status_code=500, detail="GOOGLE_DRIVE_REDIRECT_URI not configured"
|
||||
)
|
||||
|
||||
# Exchange authorization code for tokens
|
||||
# Exchange authorization code for tokens (restore PKCE code_verifier from state)
|
||||
flow = get_google_flow()
|
||||
flow.code_verifier = code_verifier
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
creds = flow.credentials
|
||||
|
|
|
|||
|
|
@ -28,7 +28,11 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import (
|
||||
OAuthStateManager,
|
||||
TokenEncryption,
|
||||
generate_code_verifier,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -109,9 +113,14 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Generate secure state parameter with HMAC signature
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id, user.id, code_verifier=code_verifier
|
||||
)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
|
|
@ -164,8 +173,11 @@ async def reauth_gmail(
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
|
@ -256,6 +268,7 @@ async def gmail_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
code_verifier = data.get("code_verifier")
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.GOOGLE_GMAIL_REDIRECT_URI:
|
||||
|
|
@ -264,6 +277,7 @@ async def gmail_callback(
|
|||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
flow.code_verifier = code_verifier
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
creds = flow.credentials
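The state round trip itself can be pictured with a small sketch like the one below (payload layout, secret handling and expiry are assumptions; OAuthStateManager in app.utils.oauth_security is the real implementation):

# Hypothetical stand-in for generate_secure_state and its verification side.
import base64
import hashlib
import hmac
import json

SECRET = b"server-side-secret"  # placeholder

def sign_state(space_id: int, user_id: str, **extra) -> str:
    payload = json.dumps({"space_id": space_id, "user_id": user_id, **extra}).encode()
    sig = hmac.new(SECRET, payload, hashlib.sha256).hexdigest()
    return base64.urlsafe_b64encode(payload).decode() + "." + sig

def verify_state(state: str) -> dict:
    encoded, sig = state.rsplit(".", 1)
    payload = base64.urlsafe_b64decode(encoded.encode())
    if not hmac.compare_digest(sig, hmac.new(SECRET, payload, hashlib.sha256).hexdigest()):
        raise ValueError("state signature mismatch")
    return json.loads(payload)

state = sign_state(7, "user-uuid", connector_id=42, code_verifier="abc123")
assert verify_state(state)["code_verifier"] == "abc123"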
@ -1,5 +1,5 @@
|
|||
"""
|
||||
API route for fetching the available LLM models catalogue.
|
||||
API route for fetching the available models catalogue.
|
||||
|
||||
Serves a dynamically-updated list sourced from the OpenRouter public API,
|
||||
with a local JSON fallback when the API is unreachable.
|
||||
|
|
@ -30,7 +30,7 @@ async def list_available_models(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Return all available LLM models grouped by provider.
|
||||
Return all available models grouped by provider.
|
||||
|
||||
The list is sourced from the OpenRouter public API and cached for 1 hour.
|
||||
If the API is unreachable, a local fallback file is used instead.
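A rough sketch of that cache-plus-fallback pattern (the endpoint URL, one-hour TTL constant and fallback path below are assumptions for illustration):

# Illustrative only: fetch a model catalogue, cache it for an hour,
# fall back to a bundled JSON file when the network call fails.
import json
import time
from pathlib import Path

import httpx

_CACHE: dict = {"data": None, "ts": 0.0}
FALLBACK_FILE = Path("app/data/openrouter_models.json")  # assumed path

async def get_model_catalogue() -> list[dict]:
    if _CACHE["data"] is not None and time.time() - _CACHE["ts"] < 3600:
        return _CACHE["data"]
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get("https://openrouter.ai/api/v1/models")
            resp.raise_for_status()
            models = resp.json()["data"]
    except Exception:
        models = json.loads(FALLBACK_FILE.read_text())["data"]
    _CACHE.update(data=models, ts=time.time())
    return models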
@ -1,7 +1,7 @@
|
|||
"""
|
||||
API routes for NewLLMConfig CRUD operations.
|
||||
|
||||
NewLLMConfig combines LLM model settings with prompt configuration:
|
||||
NewLLMConfig combines model settings with prompt configuration:
|
||||
- LLM provider, model, API key, etc.
|
||||
- Configurable system instructions
|
||||
- Citation toggle
|
||||
|
|
|
|||
|
|
@ -55,23 +55,12 @@ from app.schemas import (
|
|||
)
|
||||
from app.services.composio_service import ComposioService, get_composio_service
|
||||
from app.services.notification_service import NotificationService
|
||||
from app.tasks.connector_indexers import (
|
||||
index_airtable_records,
|
||||
index_clickup_tasks,
|
||||
index_confluence_pages,
|
||||
index_crawled_urls,
|
||||
index_discord_messages,
|
||||
index_elasticsearch_documents,
|
||||
index_github_repos,
|
||||
index_google_calendar_events,
|
||||
index_google_gmail_messages,
|
||||
index_jira_issues,
|
||||
index_linear_issues,
|
||||
index_luma_events,
|
||||
index_notion_pages,
|
||||
index_slack_messages,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
|
||||
# NOTE: connector indexer functions are imported lazily inside each
|
||||
# ``run_*_indexing`` helper to break a circular import cycle:
|
||||
# connector_indexers.__init__ → airtable_indexer → airtable_history
|
||||
# → app.routes.__init__ → this file → connector_indexers (not ready yet)
|
||||
from app.utils.connector_naming import ensure_unique_connector_name
|
||||
from app.utils.indexing_locks import (
|
||||
acquire_connector_indexing_lock,
|
||||
|
|
@ -1378,6 +1367,8 @@ async def run_slack_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_slack_messages
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -1824,6 +1815,8 @@ async def run_notion_indexing_with_new_session(
|
|||
Create a new session and run the Notion indexing task.
|
||||
This prevents session leaks by creating a dedicated session for the background task.
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_notion_pages
|
||||
|
||||
async with async_session_maker() as session:
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
|
|
@ -1858,6 +1851,8 @@ async def run_notion_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_notion_pages
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -1910,6 +1905,8 @@ async def run_github_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_github_repos
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -1961,6 +1958,8 @@ async def run_linear_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_linear_issues
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2011,6 +2010,8 @@ async def run_discord_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_discord_messages
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2113,6 +2114,8 @@ async def run_jira_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_jira_issues
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2166,6 +2169,8 @@ async def run_confluence_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_confluence_pages
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2217,6 +2222,8 @@ async def run_clickup_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_clickup_tasks
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2268,6 +2275,8 @@ async def run_airtable_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_airtable_records
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2321,6 +2330,8 @@ async def run_google_calendar_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_google_calendar_events
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2370,6 +2381,7 @@ async def run_google_gmail_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_google_gmail_messages
|
||||
|
||||
# Create a wrapper function that calls index_google_gmail_messages with max_messages
|
||||
async def gmail_indexing_wrapper(
|
||||
|
|
@ -2465,6 +2477,8 @@ async def run_google_drive_indexing(
|
|||
stage="fetching",
|
||||
)
|
||||
|
||||
total_unsupported = 0
|
||||
|
||||
# Index each folder with indexing options
|
||||
for folder in items.folders:
|
||||
try:
|
||||
|
|
@ -2472,6 +2486,7 @@ async def run_google_drive_indexing(
|
|||
indexed_count,
|
||||
skipped_count,
|
||||
error_message,
|
||||
unsupported_count,
|
||||
) = await index_google_drive_files(
|
||||
session,
|
||||
connector_id,
|
||||
|
|
@ -2485,6 +2500,7 @@ async def run_google_drive_indexing(
|
|||
include_subfolders=indexing_options.include_subfolders,
|
||||
)
|
||||
total_skipped += skipped_count
|
||||
total_unsupported += unsupported_count
|
||||
if error_message:
|
||||
errors.append(f"Folder '{folder.name}': {error_message}")
|
||||
else:
|
||||
|
|
@ -2560,6 +2576,7 @@ async def run_google_drive_indexing(
|
|||
indexed_count=total_indexed,
|
||||
error_message=error_message,
|
||||
skipped_count=total_skipped,
|
||||
unsupported_count=total_unsupported,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -2630,7 +2647,12 @@ async def run_onedrive_indexing(
|
|||
stage="fetching",
|
||||
)
|
||||
|
||||
total_indexed, total_skipped, error_message = await index_onedrive_files(
|
||||
(
|
||||
total_indexed,
|
||||
total_skipped,
|
||||
error_message,
|
||||
total_unsupported,
|
||||
) = await index_onedrive_files(
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
|
|
@ -2671,6 +2693,7 @@ async def run_onedrive_indexing(
|
|||
indexed_count=total_indexed,
|
||||
error_message=error_message,
|
||||
skipped_count=total_skipped,
|
||||
unsupported_count=total_unsupported,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -2738,7 +2761,12 @@ async def run_dropbox_indexing(
|
|||
stage="fetching",
|
||||
)
|
||||
|
||||
total_indexed, total_skipped, error_message = await index_dropbox_files(
|
||||
(
|
||||
total_indexed,
|
||||
total_skipped,
|
||||
error_message,
|
||||
total_unsupported,
|
||||
) = await index_dropbox_files(
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
|
|
@ -2779,6 +2807,7 @@ async def run_dropbox_indexing(
|
|||
indexed_count=total_indexed,
|
||||
error_message=error_message,
|
||||
skipped_count=total_skipped,
|
||||
unsupported_count=total_unsupported,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -2836,6 +2865,8 @@ async def run_luma_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_luma_events
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2888,6 +2919,8 @@ async def run_elasticsearch_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_elasticsearch_documents
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2938,6 +2971,8 @@ async def run_web_page_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_crawled_urls
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from app.db import (
|
|||
SearchSpaceMembership,
|
||||
SearchSpaceRole,
|
||||
User,
|
||||
VisionLLMConfig,
|
||||
get_async_session,
|
||||
get_default_roles_config,
|
||||
)
|
||||
|
|
@ -483,6 +484,63 @@ async def _get_image_gen_config_by_id(
|
|||
return None
|
||||
|
||||
|
||||
async def _get_vision_llm_config_by_id(
|
||||
session: AsyncSession, config_id: int | None
|
||||
) -> dict | None:
|
||||
if config_id is None:
|
||||
return None
|
||||
|
||||
if config_id == 0:
|
||||
return {
|
||||
"id": 0,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes requests across available vision LLM providers",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
}
|
||||
return None
|
||||
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if db_config:
|
||||
return {
|
||||
"id": db_config.id,
|
||||
"name": db_config.name,
|
||||
"description": db_config.description,
|
||||
"provider": db_config.provider.value if db_config.provider else None,
|
||||
"custom_provider": db_config.custom_provider,
|
||||
"model_name": db_config.model_name,
|
||||
"api_base": db_config.api_base,
|
||||
"api_version": db_config.api_version,
|
||||
"litellm_params": db_config.litellm_params or {},
|
||||
"created_at": db_config.created_at.isoformat()
|
||||
if db_config.created_at
|
||||
else None,
|
||||
"search_space_id": db_config.search_space_id,
|
||||
}
|
||||
return None
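Hypothetical calls showing the ID convention this helper implements (0 = auto routing, negative = global YAML entry, positive = per-space DB row); the session and IDs are placeholders:

# Inside an async handler with an AsyncSession in scope.
auto_cfg = await _get_vision_llm_config_by_id(session, 0)     # {"id": 0, "provider": "AUTO", ...}
global_cfg = await _get_vision_llm_config_by_id(session, -1)   # looked up in config.GLOBAL_VISION_LLM_CONFIGS
db_cfg = await _get_vision_llm_config_by_id(session, 12)       # loaded from the VisionLLMConfig table
missing = await _get_vision_llm_config_by_id(session, None)    # None, meaning no preference is set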
|
||||
|
||||
|
||||
@router.get(
|
||||
"/search-spaces/{search_space_id}/llm-preferences",
|
||||
response_model=LLMPreferencesRead,
|
||||
|
|
@ -522,14 +580,19 @@ async def get_llm_preferences(
|
|||
image_generation_config = await _get_image_gen_config_by_id(
|
||||
session, search_space.image_generation_config_id
|
||||
)
|
||||
vision_llm_config = await _get_vision_llm_config_by_id(
|
||||
session, search_space.vision_llm_config_id
|
||||
)
|
||||
|
||||
return LLMPreferencesRead(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
document_summary_llm_id=search_space.document_summary_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
vision_llm_config_id=search_space.vision_llm_config_id,
|
||||
agent_llm=agent_llm,
|
||||
document_summary_llm=document_summary_llm,
|
||||
image_generation_config=image_generation_config,
|
||||
vision_llm_config=vision_llm_config,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -589,14 +652,19 @@ async def update_llm_preferences(
|
|||
image_generation_config = await _get_image_gen_config_by_id(
|
||||
session, search_space.image_generation_config_id
|
||||
)
|
||||
vision_llm_config = await _get_vision_llm_config_by_id(
|
||||
session, search_space.vision_llm_config_id
|
||||
)
|
||||
|
||||
return LLMPreferencesRead(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
document_summary_llm_id=search_space.document_summary_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
vision_llm_config_id=search_space.vision_llm_config_id,
|
||||
agent_llm=agent_llm,
|
||||
document_summary_llm=document_summary_llm,
|
||||
image_generation_config=image_generation_config,
|
||||
vision_llm_config=vision_llm_config,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
|
|||
surfsense_backend/app/routes/vision_llm_routes.py (new file, 291 lines)
|
|
@ -0,0 +1,291 @@
|
|||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
Permission,
|
||||
User,
|
||||
VisionLLMConfig,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
GlobalVisionLLMConfigRead,
|
||||
VisionLLMConfigCreate,
|
||||
VisionLLMConfigRead,
|
||||
VisionLLMConfigUpdate,
|
||||
)
|
||||
from app.services.vision_model_list_service import get_vision_model_list
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Vision Model Catalogue (from OpenRouter, filtered for image-input models)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class VisionModelListItem(BaseModel):
|
||||
value: str
|
||||
label: str
|
||||
provider: str
|
||||
context_window: str | None = None
|
||||
|
||||
|
||||
@router.get("/vision-models", response_model=list[VisionModelListItem])
|
||||
async def list_vision_models(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Return vision-capable models sourced from OpenRouter (filtered by image input)."""
|
||||
try:
|
||||
return await get_vision_model_list()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch vision model list")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch vision model list: {e!s}"
|
||||
) from e
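A possible client call against this endpoint; the base URL, route prefix and bearer-token auth are assumptions based on a typical FastAPI deployment, not taken from this diff:

# Illustrative httpx call; adjust base URL, prefix and auth to your deployment.
import httpx

async def fetch_vision_models(base_url: str, token: str) -> list[dict]:
    async with httpx.AsyncClient(base_url=base_url, timeout=15) as client:
        resp = await client.get(
            "/vision-models",
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        # Items follow VisionModelListItem: value, label, provider, context_window
        return resp.json()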
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Vision LLM Configs (from YAML)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/global-vision-llm-configs",
|
||||
response_model=list[GlobalVisionLLMConfigRead],
|
||||
)
|
||||
async def get_global_vision_llm_configs(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
global_configs = config.GLOBAL_VISION_LLM_CONFIGS
|
||||
safe_configs = []
|
||||
|
||||
if global_configs and len(global_configs) > 0:
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": 0,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes across available vision LLM providers.",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
}
|
||||
)
|
||||
|
||||
return safe_configs
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch global vision LLM configs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# VisionLLMConfig CRUD
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/vision-llm-configs", response_model=VisionLLMConfigRead)
|
||||
async def create_vision_llm_config(
|
||||
config_data: VisionLLMConfigCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config_data.search_space_id,
|
||||
Permission.VISION_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
db_config = VisionLLMConfig(**config_data.model_dump(), user_id=user.id)
|
||||
session.add(db_config)
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to create VisionLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/vision-llm-configs", response_model=list[VisionLLMConfigRead])
|
||||
async def list_vision_llm_configs(
|
||||
search_space_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.VISION_CONFIGS_READ.value,
|
||||
"You don't have permission to view vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig)
|
||||
.filter(VisionLLMConfig.search_space_id == search_space_id)
|
||||
.order_by(VisionLLMConfig.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list VisionLLMConfigs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead)
|
||||
async def get_vision_llm_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.VISION_CONFIGS_READ.value,
|
||||
"You don't have permission to view vision LLM configs in this search space",
|
||||
)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get VisionLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead)
|
||||
async def update_vision_llm_config(
|
||||
config_id: int,
|
||||
update_data: VisionLLMConfigUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.VISION_CONFIGS_CREATE.value,
|
||||
"You don't have permission to update vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
for key, value in update_data.model_dump(exclude_unset=True).items():
|
||||
setattr(db_config, key, value)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to update VisionLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/vision-llm-configs/{config_id}", response_model=dict)
|
||||
async def delete_vision_llm_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.VISION_CONFIGS_DELETE.value,
|
||||
"You don't have permission to delete vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_config)
|
||||
await session.commit()
|
||||
return {
|
||||
"message": "Vision LLM config deleted successfully",
|
||||
"id": config_id,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to delete VisionLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete config: {e!s}"
|
||||
) from e
|
||||
|
|
@ -125,6 +125,13 @@ from .video_presentations import (
|
|||
VideoPresentationRead,
|
||||
VideoPresentationUpdate,
|
||||
)
|
||||
from .vision_llm import (
|
||||
GlobalVisionLLMConfigRead,
|
||||
VisionLLMConfigCreate,
|
||||
VisionLLMConfigPublic,
|
||||
VisionLLMConfigRead,
|
||||
VisionLLMConfigUpdate,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Folder schemas
|
||||
|
|
@ -163,6 +170,8 @@ __all__ = [
|
|||
"FolderUpdate",
|
||||
"GlobalImageGenConfigRead",
|
||||
"GlobalNewLLMConfigRead",
|
||||
# Vision LLM Config schemas
|
||||
"GlobalVisionLLMConfigRead",
|
||||
"GoogleDriveIndexRequest",
|
||||
"GoogleDriveIndexingOptions",
|
||||
# Base schemas
|
||||
|
|
@ -264,4 +273,8 @@ __all__ = [
|
|||
"VideoPresentationCreate",
|
||||
"VideoPresentationRead",
|
||||
"VideoPresentationUpdate",
|
||||
"VisionLLMConfigCreate",
|
||||
"VisionLLMConfigPublic",
|
||||
"VisionLLMConfigRead",
|
||||
"VisionLLMConfigUpdate",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Pydantic schemas for folder CRUD, move, and reorder operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
|
@ -34,6 +35,9 @@ class FolderRead(BaseModel):
|
|||
created_by_id: UUID | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
metadata: dict[str, Any] | None = Field(
|
||||
default=None, validation_alias="folder_metadata"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
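The alias pattern used here (ORM attribute folder_metadata exposed as metadata on the schema) can be shown in isolation; this is a trimmed sketch, not the full FolderRead:

# Minimal illustration of validation_alias + from_attributes.
from typing import Any

from pydantic import BaseModel, ConfigDict, Field

class _FolderLike:
    folder_metadata = {"color": "blue"}

class _FolderReadSketch(BaseModel):
    metadata: dict[str, Any] | None = Field(default=None, validation_alias="folder_metadata")

    model_config = ConfigDict(from_attributes=True)

print(_FolderReadSketch.model_validate(_FolderLike()).metadata)  # {'color': 'blue'}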
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
Pydantic schemas for the NewLLMConfig API.
|
||||
|
||||
NewLLMConfig combines LLM model settings with prompt configuration:
|
||||
NewLLMConfig combines model settings with prompt configuration:
|
||||
- LLM provider, model, API key, etc.
|
||||
- Configurable system instructions
|
||||
- Citation toggle
|
||||
|
|
@ -26,7 +26,7 @@ class NewLLMConfigBase(BaseModel):
|
|||
None, max_length=500, description="Optional description"
|
||||
)
|
||||
|
||||
# LLM Model Configuration
|
||||
# Model Configuration
|
||||
provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
|
||||
custom_provider: str | None = Field(
|
||||
None, max_length=100, description="Custom provider name when provider is CUSTOM"
|
||||
|
|
@ -71,7 +71,7 @@ class NewLLMConfigUpdate(BaseModel):
|
|||
name: str | None = Field(None, max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
|
||||
# LLM Model Configuration
|
||||
# Model Configuration
|
||||
provider: LiteLLMProvider | None = None
|
||||
custom_provider: str | None = Field(None, max_length=100)
|
||||
model_name: str | None = Field(None, max_length=100)
|
||||
|
|
@ -106,7 +106,7 @@ class NewLLMConfigPublic(BaseModel):
|
|||
name: str
|
||||
description: str | None = None
|
||||
|
||||
# LLM Model Configuration (no api_key)
|
||||
# Model Configuration (no api_key)
|
||||
provider: LiteLLMProvider
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
|
|
@ -149,7 +149,7 @@ class GlobalNewLLMConfigRead(BaseModel):
|
|||
name: str
|
||||
description: str | None = None
|
||||
|
||||
# LLM Model Configuration (no api_key)
|
||||
# Model Configuration (no api_key)
|
||||
provider: str # String because YAML doesn't enforce enum, "AUTO" for Auto mode
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
|
|
@ -182,6 +182,10 @@ class LLMPreferencesRead(BaseModel):
|
|||
image_generation_config_id: int | None = Field(
|
||||
None, description="ID of the image generation config to use"
|
||||
)
|
||||
vision_llm_config_id: int | None = Field(
|
||||
None,
|
||||
description="ID of the vision LLM config to use for vision/screenshot analysis",
|
||||
)
|
||||
agent_llm: dict[str, Any] | None = Field(
|
||||
None, description="Full config for agent LLM"
|
||||
)
|
||||
|
|
@ -191,6 +195,9 @@ class LLMPreferencesRead(BaseModel):
|
|||
image_generation_config: dict[str, Any] | None = Field(
|
||||
None, description="Full config for image generation"
|
||||
)
|
||||
vision_llm_config: dict[str, Any] | None = Field(
|
||||
None, description="Full config for vision LLM"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
@ -207,3 +214,7 @@ class LLMPreferencesUpdate(BaseModel):
|
|||
image_generation_config_id: int | None = Field(
|
||||
None, description="ID of the image generation config to use"
|
||||
)
|
||||
vision_llm_config_id: int | None = Field(
|
||||
None,
|
||||
description="ID of the vision LLM config to use for vision/screenshot analysis",
|
||||
)
|
||||
|
|
|
|||
surfsense_backend/app/schemas/vision_llm.py (new file, 75 lines)
|
|
@ -0,0 +1,75 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.db import VisionProvider
|
||||
|
||||
|
||||
class VisionLLMConfigBase(BaseModel):
|
||||
name: str = Field(..., max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
provider: VisionProvider = Field(...)
|
||||
custom_provider: str | None = Field(None, max_length=100)
|
||||
model_name: str = Field(..., max_length=100)
|
||||
api_key: str = Field(...)
|
||||
api_base: str | None = Field(None, max_length=500)
|
||||
api_version: str | None = Field(None, max_length=50)
|
||||
litellm_params: dict[str, Any] | None = Field(default=None)
|
||||
|
||||
|
||||
class VisionLLMConfigCreate(VisionLLMConfigBase):
|
||||
search_space_id: int = Field(...)
|
||||
|
||||
|
||||
class VisionLLMConfigUpdate(BaseModel):
|
||||
name: str | None = Field(None, max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
provider: VisionProvider | None = None
|
||||
custom_provider: str | None = Field(None, max_length=100)
|
||||
model_name: str | None = Field(None, max_length=100)
|
||||
api_key: str | None = None
|
||||
api_base: str | None = Field(None, max_length=500)
|
||||
api_version: str | None = Field(None, max_length=50)
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class VisionLLMConfigRead(VisionLLMConfigBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class VisionLLMConfigPublic(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: VisionProvider
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class GlobalVisionLLMConfigRead(BaseModel):
|
||||
id: int = Field(...)
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: str
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
is_global: bool = True
|
||||
is_auto_mode: bool = False
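As an example, a create payload against these schemas might look like this; the values are placeholders, and VisionProvider.OPENAI is assumed to be a valid enum member based on the provider map added elsewhere in this commit:

# Hypothetical payload; API key and IDs are placeholders.
payload = VisionLLMConfigCreate(
    name="GPT-4o vision",
    description="Screenshot analysis model",
    provider=VisionProvider.OPENAI,
    model_name="gpt-4o",
    api_key="sk-...",
    search_space_id=7,
)
print(payload.model_dump(exclude={"api_key"}))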
@ -111,9 +111,8 @@ class DoclingService:
|
|||
pipeline_options=pipeline_options, backend=PyPdfiumDocumentBackend
|
||||
)
|
||||
|
||||
# Initialize DocumentConverter
|
||||
self.converter = DocumentConverter(
|
||||
format_options={InputFormat.PDF: pdf_format_option}
|
||||
format_options={InputFormat.PDF: pdf_format_option},
|
||||
)
|
||||
|
||||
acceleration_type = "GPU (WSL2)" if self.use_gpu else "CPU"
|
||||
|
|
|
|||
|
|
@ -405,6 +405,121 @@ async def get_document_summary_llm(
|
|||
)
|
||||
|
||||
|
||||
async def get_vision_llm(
|
||||
session: AsyncSession, search_space_id: int
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""Get the search space's vision LLM instance for screenshot analysis.
|
||||
|
||||
Resolves from the dedicated VisionLLMConfig system:
|
||||
- Auto mode (ID 0): VisionLLMRouterService
|
||||
- Global (negative ID): YAML configs
|
||||
- DB (positive ID): VisionLLMConfig table
|
||||
"""
|
||||
from app.db import VisionLLMConfig
|
||||
from app.services.vision_llm_router_service import (
|
||||
VISION_PROVIDER_MAP,
|
||||
VisionLLMRouterService,
|
||||
get_global_vision_llm_config,
|
||||
is_vision_auto_mode,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
logger.error(f"Search space {search_space_id} not found")
|
||||
return None
|
||||
|
||||
config_id = search_space.vision_llm_config_id
|
||||
if config_id is None:
|
||||
logger.error(f"No vision LLM configured for search space {search_space_id}")
|
||||
return None
|
||||
|
||||
if is_vision_auto_mode(config_id):
|
||||
if not VisionLLMRouterService.is_initialized():
|
||||
logger.error(
|
||||
"Vision Auto mode requested but Vision LLM Router not initialized"
|
||||
)
|
||||
return None
|
||||
try:
|
||||
return ChatLiteLLMRouter(
|
||||
router=VisionLLMRouterService.get_router(),
|
||||
streaming=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create vision ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
|
||||
if config_id < 0:
|
||||
global_cfg = get_global_vision_llm_config(config_id)
|
||||
if not global_cfg:
|
||||
logger.error(f"Global vision LLM config {config_id} not found")
|
||||
return None
|
||||
|
||||
if global_cfg.get("custom_provider"):
|
||||
model_string = (
|
||||
f"{global_cfg['custom_provider']}/{global_cfg['model_name']}"
|
||||
)
|
||||
else:
|
||||
prefix = VISION_PROVIDER_MAP.get(
|
||||
global_cfg["provider"].upper(),
|
||||
global_cfg["provider"].lower(),
|
||||
)
|
||||
model_string = f"{prefix}/{global_cfg['model_name']}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": global_cfg["api_key"],
|
||||
}
|
||||
if global_cfg.get("api_base"):
|
||||
litellm_kwargs["api_base"] = global_cfg["api_base"]
|
||||
if global_cfg.get("litellm_params"):
|
||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||
|
||||
return ChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).where(
|
||||
VisionLLMConfig.id == config_id,
|
||||
VisionLLMConfig.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
vision_cfg = result.scalars().first()
|
||||
if not vision_cfg:
|
||||
logger.error(
|
||||
f"Vision LLM config {config_id} not found in search space {search_space_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
if vision_cfg.custom_provider:
|
||||
model_string = f"{vision_cfg.custom_provider}/{vision_cfg.model_name}"
|
||||
else:
|
||||
prefix = VISION_PROVIDER_MAP.get(
|
||||
vision_cfg.provider.value.upper(),
|
||||
vision_cfg.provider.value.lower(),
|
||||
)
|
||||
model_string = f"{prefix}/{vision_cfg.model_name}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": vision_cfg.api_key,
|
||||
}
|
||||
if vision_cfg.api_base:
|
||||
litellm_kwargs["api_base"] = vision_cfg.api_base
|
||||
if vision_cfg.litellm_params:
|
||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||
|
||||
return ChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting vision LLM for search space {search_space_id}: {e!s}"
|
||||
)
|
||||
return None
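The provider-prefix resolution above reduces to roughly this helper (a sketch for readability; VISION_PROVIDER_MAP is defined in vision_llm_router_service.py later in this commit):

# Sketch of how get_vision_llm derives the LiteLLM model string.
def build_model_string(provider: str, model_name: str, custom_provider: str | None = None) -> str:
    if custom_provider:
        return f"{custom_provider}/{model_name}"
    prefix = VISION_PROVIDER_MAP.get(provider.upper(), provider.lower())
    return f"{prefix}/{model_name}"

build_model_string("GOOGLE", "gemini-2.0-flash")                  # "gemini/gemini-2.0-flash"
build_model_string("OPENAI", "gpt-4o")                            # "openai/gpt-4o"
build_model_string("ANYTHING", "x", custom_provider="my_proxy")   # "my_proxy/x"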
|
||||
|
||||
|
||||
# Backward-compatible alias (LLM preferences are now per-search-space, not per-user)
|
||||
async def get_user_long_context_llm(
|
||||
session: AsyncSession,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Service for fetching and caching the available LLM model list.
|
||||
Service for fetching and caching the available model list.
|
||||
|
||||
Uses the OpenRouter public API as the primary source, with a local
|
||||
fallback JSON file when the API is unreachable.
|
||||
|
|
|
|||
|
|
@ -421,6 +421,7 @@ class ConnectorIndexingNotificationHandler(BaseNotificationHandler):
|
|||
error_message: str | None = None,
|
||||
is_warning: bool = False,
|
||||
skipped_count: int | None = None,
|
||||
unsupported_count: int | None = None,
|
||||
) -> Notification:
|
||||
"""
|
||||
Update notification when connector indexing completes.
|
||||
|
|
@ -428,10 +429,11 @@ class ConnectorIndexingNotificationHandler(BaseNotificationHandler):
|
|||
Args:
|
||||
session: Database session
|
||||
notification: Notification to update
|
||||
indexed_count: Total number of items indexed
|
||||
indexed_count: Total number of files indexed
|
||||
error_message: Error message if indexing failed, or warning message (optional)
|
||||
is_warning: If True, treat error_message as a warning (success case) rather than an error
|
||||
skipped_count: Number of items skipped (e.g., duplicates) - optional
|
||||
skipped_count: Number of files skipped (e.g., unchanged) - optional
|
||||
unsupported_count: Number of files skipped because the ETL parser doesn't support them
|
||||
|
||||
Returns:
|
||||
Updated notification
|
||||
|
|
@ -440,52 +442,45 @@ class ConnectorIndexingNotificationHandler(BaseNotificationHandler):
|
|||
"connector_name", "Connector"
|
||||
)
|
||||
|
||||
# Build the skipped text if there are skipped items
|
||||
skipped_text = ""
|
||||
if skipped_count and skipped_count > 0:
|
||||
skipped_item_text = "item" if skipped_count == 1 else "items"
|
||||
skipped_text = (
|
||||
f" ({skipped_count} {skipped_item_text} skipped - already indexed)"
|
||||
)
|
||||
unsupported_text = ""
|
||||
if unsupported_count and unsupported_count > 0:
|
||||
file_word = "file was" if unsupported_count == 1 else "files were"
|
||||
unsupported_text = f" {unsupported_count} {file_word} not supported."
|
||||
|
||||
# If there's an error message but items were indexed, treat it as a warning (partial success)
|
||||
# If is_warning is True, treat it as success even with 0 items (e.g., duplicates found)
|
||||
# Otherwise, treat it as a failure
|
||||
if error_message:
|
||||
if indexed_count > 0:
|
||||
# Partial success with warnings (e.g., duplicate content from other connectors)
|
||||
title = f"Ready: {connector_name}"
|
||||
item_text = "item" if indexed_count == 1 else "items"
|
||||
message = f"Now searchable! {indexed_count} {item_text} synced{skipped_text}. Note: {error_message}"
|
||||
file_text = "file" if indexed_count == 1 else "files"
|
||||
message = f"Now searchable! {indexed_count} {file_text} synced.{unsupported_text} Note: {error_message}"
|
||||
status = "completed"
|
||||
elif is_warning:
|
||||
# Warning case (e.g., duplicates found) - treat as success
|
||||
title = f"Ready: {connector_name}"
|
||||
message = f"Sync completed{skipped_text}. {error_message}"
|
||||
message = f"Sync complete.{unsupported_text} {error_message}"
|
||||
status = "completed"
|
||||
else:
|
||||
# Complete failure
|
||||
title = f"Failed: {connector_name}"
|
||||
message = f"Sync failed: {error_message}"
|
||||
if unsupported_text:
|
||||
message += unsupported_text
|
||||
status = "failed"
|
||||
else:
|
||||
title = f"Ready: {connector_name}"
|
||||
if indexed_count == 0:
|
||||
if skipped_count and skipped_count > 0:
|
||||
skipped_item_text = "item" if skipped_count == 1 else "items"
|
||||
message = f"Already up to date! {skipped_count} {skipped_item_text} skipped (already indexed)."
|
||||
if unsupported_count and unsupported_count > 0:
|
||||
message = f"Sync complete.{unsupported_text}"
|
||||
else:
|
||||
message = "Already up to date! No new items to sync."
|
||||
message = "Already up to date!"
|
||||
else:
|
||||
item_text = "item" if indexed_count == 1 else "items"
|
||||
message = (
|
||||
f"Now searchable! {indexed_count} {item_text} synced{skipped_text}."
|
||||
)
|
||||
file_text = "file" if indexed_count == 1 else "files"
|
||||
message = f"Now searchable! {indexed_count} {file_text} synced."
|
||||
if unsupported_text:
|
||||
message += unsupported_text
|
||||
status = "completed"
|
||||
|
||||
metadata_updates = {
|
||||
"indexed_count": indexed_count,
|
||||
"skipped_count": skipped_count or 0,
|
||||
"unsupported_count": unsupported_count or 0,
|
||||
"sync_stage": "completed"
|
||||
if (not error_message or is_warning or indexed_count > 0)
|
||||
else "failed",
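Tracing the branches above with a few illustrative inputs gives messages like these (derived from the code, not from runtime logs):

# indexed_count=3, no unsupported, no error -> "Now searchable! 3 files synced."
# indexed_count=1, unsupported_count=2      -> "Now searchable! 1 file synced. 2 files were not supported."
# indexed_count=0, unsupported_count=1      -> "Sync complete. 1 file was not supported."
# indexed_count=0, error_message="bad token", is_warning=False -> "Sync failed: bad token"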
@ -3,7 +3,7 @@ Service for managing user page limits for ETL services.
|
|||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pathlib import Path, PurePosixPath
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -223,10 +223,155 @@ class PageLimitService:
|
|||
# Estimate ~2000 characters per page
|
||||
return max(1, content_length // 2000)
|
||||
|
||||
@staticmethod
|
||||
def estimate_pages_from_metadata(
|
||||
file_name_or_ext: str, file_size: int | str | None = None
|
||||
) -> int:
|
||||
"""Size-based page estimation from file name/extension and byte size.
|
||||
|
||||
Pure function — no file I/O, no database access. Used by cloud
|
||||
connectors (which only have API metadata) and as the internal
|
||||
fallback for :meth:`estimate_pages_before_processing`.
|
||||
|
||||
``file_name_or_ext`` can be a full filename (``"report.pdf"``) or
|
||||
a bare extension (``".pdf"``). ``file_size`` may be an int, a
|
||||
stringified int from a cloud API, or *None*.
|
||||
"""
|
||||
if file_size is not None:
|
||||
try:
|
||||
file_size = int(file_size)
|
||||
except (ValueError, TypeError):
|
||||
file_size = 0
|
||||
else:
|
||||
file_size = 0
|
||||
|
||||
if file_size <= 0:
|
||||
return 1
|
||||
|
||||
ext = PurePosixPath(file_name_or_ext).suffix.lower() if file_name_or_ext else ""
|
||||
if not ext and file_name_or_ext.startswith("."):
|
||||
ext = file_name_or_ext.lower()
|
||||
file_ext = ext
|
||||
|
||||
if file_ext == ".pdf":
|
||||
return max(1, file_size // (100 * 1024))
|
||||
|
||||
if file_ext in {
|
||||
".doc",
|
||||
".docx",
|
||||
".docm",
|
||||
".dot",
|
||||
".dotm",
|
||||
".odt",
|
||||
".ott",
|
||||
".sxw",
|
||||
".stw",
|
||||
".uot",
|
||||
".rtf",
|
||||
".pages",
|
||||
".wpd",
|
||||
".wps",
|
||||
".abw",
|
||||
".zabw",
|
||||
".cwk",
|
||||
".hwp",
|
||||
".lwp",
|
||||
".mcw",
|
||||
".mw",
|
||||
".sdw",
|
||||
".vor",
|
||||
}:
|
||||
return max(1, file_size // (50 * 1024))
|
||||
|
||||
if file_ext in {
|
||||
".ppt",
|
||||
".pptx",
|
||||
".pptm",
|
||||
".pot",
|
||||
".potx",
|
||||
".odp",
|
||||
".otp",
|
||||
".sxi",
|
||||
".sti",
|
||||
".uop",
|
||||
".key",
|
||||
".sda",
|
||||
".sdd",
|
||||
".sdp",
|
||||
}:
|
||||
return max(1, file_size // (200 * 1024))
|
||||
|
||||
if file_ext in {
|
||||
".xls",
|
||||
".xlsx",
|
||||
".xlsm",
|
||||
".xlsb",
|
||||
".xlw",
|
||||
".xlr",
|
||||
".ods",
|
||||
".ots",
|
||||
".fods",
|
||||
".numbers",
|
||||
".123",
|
||||
".wk1",
|
||||
".wk2",
|
||||
".wk3",
|
||||
".wk4",
|
||||
".wks",
|
||||
".wb1",
|
||||
".wb2",
|
||||
".wb3",
|
||||
".wq1",
|
||||
".wq2",
|
||||
".csv",
|
||||
".tsv",
|
||||
".slk",
|
||||
".sylk",
|
||||
".dif",
|
||||
".dbf",
|
||||
".prn",
|
||||
".qpw",
|
||||
".602",
|
||||
".et",
|
||||
".eth",
|
||||
}:
|
||||
return max(1, file_size // (100 * 1024))
|
||||
|
||||
if file_ext in {".epub"}:
|
||||
return max(1, file_size // (50 * 1024))
|
||||
|
||||
if file_ext in {".txt", ".log", ".md", ".markdown", ".htm", ".html", ".xml"}:
|
||||
return max(1, file_size // 3000)
|
||||
|
||||
if file_ext in {
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".gif",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".webp",
|
||||
".svg",
|
||||
".cgm",
|
||||
".odg",
|
||||
".pbd",
|
||||
}:
|
||||
return 1
|
||||
|
||||
if file_ext in {".mp3", ".m4a", ".wav", ".mpga"}:
|
||||
return max(1, file_size // (1024 * 1024))
|
||||
|
||||
if file_ext in {".mp4", ".mpeg", ".webm"}:
|
||||
return max(1, file_size // (5 * 1024 * 1024))
|
||||
|
||||
return max(1, file_size // (80 * 1024))
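Worked examples of the heuristics above (pure arithmetic on the byte sizes, so these follow directly from the code):

PageLimitService.estimate_pages_from_metadata("report.pdf", 512_000)     # 5  (~100 KB per PDF page)
PageLimitService.estimate_pages_from_metadata("slides.pptx", "1048576")  # 5  (stringified size from a cloud API)
PageLimitService.estimate_pages_from_metadata(".png", 4_000_000)         # 1  (images are always one page)
PageLimitService.estimate_pages_from_metadata("unknown.bin", None)       # 1  (missing size falls back to the minimum)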
|
||||
|
||||
def estimate_pages_before_processing(self, file_path: str) -> int:
|
||||
"""
|
||||
Estimate page count from file before processing (to avoid unnecessary API calls).
|
||||
This is called BEFORE sending to ETL services to prevent cost on rejected files.
|
||||
Estimate page count from a local file before processing.
|
||||
|
||||
For PDFs, attempts to read the actual page count via pypdf.
|
||||
For everything else, delegates to :meth:`estimate_pages_from_metadata`.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
|
@ -240,7 +385,6 @@ class PageLimitService:
|
|||
file_ext = Path(file_path).suffix.lower()
|
||||
file_size = os.path.getsize(file_path)
|
||||
|
||||
# PDF files - try to get actual page count
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
import pypdf
|
||||
|
|
@ -249,153 +393,6 @@ class PageLimitService:
|
|||
pdf_reader = pypdf.PdfReader(f)
|
||||
return len(pdf_reader.pages)
|
||||
except Exception:
|
||||
# If PDF reading fails, fall back to size estimation
|
||||
# Typical PDF: ~100KB per page (conservative estimate)
|
||||
return max(1, file_size // (100 * 1024))
|
||||
pass # fall through to size-based estimation
|
||||
|
||||
# Word Processing Documents
|
||||
# Microsoft Word, LibreOffice Writer, WordPerfect, Pages, etc.
|
||||
elif file_ext in [
|
||||
".doc",
|
||||
".docx",
|
||||
".docm",
|
||||
".dot",
|
||||
".dotm", # Microsoft Word
|
||||
".odt",
|
||||
".ott",
|
||||
".sxw",
|
||||
".stw",
|
||||
".uot", # OpenDocument/StarOffice Writer
|
||||
".rtf", # Rich Text Format
|
||||
".pages", # Apple Pages
|
||||
".wpd",
|
||||
".wps", # WordPerfect, Microsoft Works
|
||||
".abw",
|
||||
".zabw", # AbiWord
|
||||
".cwk",
|
||||
".hwp",
|
||||
".lwp",
|
||||
".mcw",
|
||||
".mw",
|
||||
".sdw",
|
||||
".vor", # Other word processors
|
||||
]:
|
||||
# Typical word document: ~50KB per page (conservative)
|
||||
return max(1, file_size // (50 * 1024))
|
||||
|
||||
# Presentation Documents
|
||||
# PowerPoint, Impress, Keynote, etc.
|
||||
elif file_ext in [
|
||||
".ppt",
|
||||
".pptx",
|
||||
".pptm",
|
||||
".pot",
|
||||
".potx", # Microsoft PowerPoint
|
||||
".odp",
|
||||
".otp",
|
||||
".sxi",
|
||||
".sti",
|
||||
".uop", # OpenDocument/StarOffice Impress
|
||||
".key", # Apple Keynote
|
||||
".sda",
|
||||
".sdd",
|
||||
".sdp", # StarOffice Draw/Impress
|
||||
]:
|
||||
# Typical presentation: ~200KB per slide (conservative)
|
||||
return max(1, file_size // (200 * 1024))
|
||||
|
||||
# Spreadsheet Documents
|
||||
# Excel, Calc, Numbers, Lotus, etc.
|
||||
elif file_ext in [
|
||||
".xls",
|
||||
".xlsx",
|
||||
".xlsm",
|
||||
".xlsb",
|
||||
".xlw",
|
||||
".xlr", # Microsoft Excel
|
||||
".ods",
|
||||
".ots",
|
||||
".fods", # OpenDocument Spreadsheet
|
||||
".numbers", # Apple Numbers
|
||||
".123",
|
||||
".wk1",
|
||||
".wk2",
|
||||
".wk3",
|
||||
".wk4",
|
||||
".wks", # Lotus 1-2-3
|
||||
".wb1",
|
||||
".wb2",
|
||||
".wb3",
|
||||
".wq1",
|
||||
".wq2", # Quattro Pro
|
||||
".csv",
|
||||
".tsv",
|
||||
".slk",
|
||||
".sylk",
|
||||
".dif",
|
||||
".dbf",
|
||||
".prn",
|
||||
".qpw", # Data formats
|
||||
".602",
|
||||
".et",
|
||||
".eth", # Other spreadsheets
|
||||
]:
|
||||
# Spreadsheets typically have 1 sheet = 1 page for ETL
|
||||
# Conservative: ~100KB per sheet
|
||||
return max(1, file_size // (100 * 1024))
|
||||
|
||||
# E-books
|
||||
elif file_ext in [".epub"]:
|
||||
# E-books vary widely, estimate by size
|
||||
# Typical e-book: ~50KB per page
|
||||
return max(1, file_size // (50 * 1024))
|
||||
|
||||
# Plain Text and Markup Files
|
||||
elif file_ext in [
|
||||
".txt",
|
||||
".log", # Plain text
|
||||
".md",
|
||||
".markdown", # Markdown
|
||||
".htm",
|
||||
".html",
|
||||
".xml", # Markup
|
||||
]:
|
||||
# Plain text: ~3000 bytes per page
|
||||
return max(1, file_size // 3000)
|
||||
|
||||
# Image Files
|
||||
# Each image is typically processed as 1 page
|
||||
elif file_ext in [
|
||||
".jpg",
|
||||
".jpeg", # JPEG
|
||||
".png", # PNG
|
||||
".gif", # GIF
|
||||
".bmp", # Bitmap
|
||||
".tiff", # TIFF
|
||||
".webp", # WebP
|
||||
".svg", # SVG
|
||||
".cgm", # Computer Graphics Metafile
|
||||
".odg",
|
||||
".pbd", # OpenDocument Graphics
|
||||
]:
|
||||
# Each image = 1 page
|
||||
return 1
|
||||
|
||||
# Audio Files (transcription = typically 1 page per minute)
|
||||
# Note: These should be handled by audio transcription flow, not ETL
|
||||
elif file_ext in [".mp3", ".m4a", ".wav", ".mpga"]:
|
||||
# Audio files: estimate based on duration
|
||||
# Fallback: ~1MB per minute of audio, 1 page per minute transcript
|
||||
return max(1, file_size // (1024 * 1024))
|
||||
|
||||
# Video Files (typically not processed for pages, but just in case)
|
||||
elif file_ext in [".mp4", ".mpeg", ".webm"]:
|
||||
# Video files: very rough estimate
|
||||
# Typically wouldn't be page-based, but use conservative estimate
|
||||
return max(1, file_size // (5 * 1024 * 1024))
|
||||
|
||||
# Other/Unknown Document Types
|
||||
else:
|
||||
# Conservative estimate: ~80KB per page
|
||||
# This catches: .sgl, .sxg, .uof, .uos1, .uos2, .web, and any future formats
|
||||
return max(1, file_size // (80 * 1024))
|
||||
return self.estimate_pages_from_metadata(file_ext, file_size)
|
||||
|
|
|
|||
surfsense_backend/app/services/vision_autocomplete_service.py (new file, 158 lines)
|
|
@ -0,0 +1,158 @@
|
|||
"""Vision autocomplete service — agent-based with scoped filesystem.
|
||||
|
||||
Optimized pipeline:
|
||||
1. Start the SSE stream immediately so the UI shows progress.
|
||||
2. Derive a KB search query from window_title (no separate LLM call).
|
||||
3. Run KB filesystem pre-computation and agent graph compilation in PARALLEL.
|
||||
4. Inject pre-computed KB files as initial state and stream the agent.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.autocomplete import create_autocomplete_agent, stream_autocomplete_agent
|
||||
from app.services.llm_service import get_vision_llm
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PREP_STEP_ID = "autocomplete-prep"
|
||||
|
||||
|
||||
def _derive_kb_query(app_name: str, window_title: str) -> str:
|
||||
parts = [p for p in (window_title, app_name) if p]
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _is_vision_unsupported_error(e: Exception) -> bool:
|
||||
msg = str(e).lower()
|
||||
return "content must be a string" in msg or "does not support image" in msg
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def stream_vision_autocomplete(
|
||||
screenshot_data_url: str,
|
||||
search_space_id: int,
|
||||
session: AsyncSession,
|
||||
*,
|
||||
app_name: str = "",
|
||||
window_title: str = "",
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Analyze a screenshot with a vision-LLM agent and stream a text completion."""
|
||||
streaming = VercelStreamingService()
|
||||
vision_error_msg = (
|
||||
"The selected model does not support vision. "
|
||||
"Please set a vision-capable model (e.g. GPT-4o, Gemini) in your search space settings."
|
||||
)
|
||||
|
||||
llm = await get_vision_llm(session, search_space_id)
|
||||
if not llm:
|
||||
yield streaming.format_message_start()
|
||||
yield streaming.format_error("No Vision LLM configured for this search space")
|
||||
yield streaming.format_done()
|
||||
return
|
||||
|
||||
# Start SSE stream immediately so the UI has something to show
|
||||
yield streaming.format_message_start()
|
||||
|
||||
kb_query = _derive_kb_query(app_name, window_title)
|
||||
|
||||
# Show a preparation step while KB search + agent compile run
|
||||
yield streaming.format_thinking_step(
|
||||
step_id=PREP_STEP_ID,
|
||||
title="Searching knowledge base",
|
||||
status="in_progress",
|
||||
items=[kb_query] if kb_query else [],
|
||||
)
|
||||
|
||||
try:
|
||||
agent, kb = await create_autocomplete_agent(
|
||||
llm,
|
||||
search_space_id=search_space_id,
|
||||
kb_query=kb_query,
|
||||
app_name=app_name,
|
||||
window_title=window_title,
|
||||
)
|
||||
except Exception as e:
|
||||
if _is_vision_unsupported_error(e):
|
||||
logger.warning("Vision autocomplete: model does not support vision: %s", e)
|
||||
yield streaming.format_error(vision_error_msg)
|
||||
yield streaming.format_done()
|
||||
return
|
||||
logger.error("Failed to create autocomplete agent: %s", e, exc_info=True)
|
||||
yield streaming.format_error("Autocomplete failed. Please try again.")
|
||||
yield streaming.format_done()
|
||||
return
|
||||
|
||||
has_kb = kb.has_documents
|
||||
doc_count = len(kb.files) if has_kb else 0 # type: ignore[arg-type]
|
||||
|
||||
yield streaming.format_thinking_step(
|
||||
step_id=PREP_STEP_ID,
|
||||
title="Searching knowledge base",
|
||||
status="complete",
|
||||
items=[f"Found {doc_count} document{'s' if doc_count != 1 else ''}"]
|
||||
if kb_query
|
||||
else ["Skipped"],
|
||||
)
|
||||
|
||||
# Build agent input with pre-computed KB as initial state
|
||||
if has_kb:
|
||||
instruction = (
|
||||
"Analyze this screenshot, then explore the knowledge base documents "
|
||||
"listed above — read the chunk index of any document whose title "
|
||||
"looks relevant and check matched chunks for useful facts. "
|
||||
"Finally, generate a concise autocomplete for the active text area, "
|
||||
"enhanced with any relevant KB information you found."
|
||||
)
|
||||
else:
|
||||
instruction = (
|
||||
"Analyze this screenshot and generate a concise autocomplete "
|
||||
"for the active text area based on what you see."
|
||||
)
|
||||
|
||||
user_message = HumanMessage(
|
||||
content=[
|
||||
{"type": "text", "text": instruction},
|
||||
{"type": "image_url", "image_url": {"url": screenshot_data_url}},
|
||||
]
|
||||
)
|
||||
|
||||
input_data: dict = {"messages": [user_message]}
|
||||
|
||||
if has_kb:
|
||||
input_data["files"] = kb.files
|
||||
input_data["messages"] = [kb.ls_ai_msg, kb.ls_tool_msg, user_message]
|
||||
logger.info(
|
||||
"Autocomplete: injected %d KB files into agent initial state", doc_count
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Autocomplete: no KB documents found, proceeding with screenshot only"
|
||||
)
|
||||
|
||||
# Stream the agent (message_start already sent above)
|
||||
try:
|
||||
async for sse in stream_autocomplete_agent(
|
||||
agent,
|
||||
input_data,
|
||||
streaming,
|
||||
emit_message_start=False,
|
||||
):
|
||||
yield sse
|
||||
except Exception as e:
|
||||
if _is_vision_unsupported_error(e):
|
||||
logger.warning("Vision autocomplete: model does not support vision: %s", e)
|
||||
yield streaming.format_error(vision_error_msg)
|
||||
yield streaming.format_done()
|
||||
else:
|
||||
logger.error("Vision autocomplete streaming error: %s", e, exc_info=True)
|
||||
yield streaming.format_error("Autocomplete failed. Please try again.")
|
||||
yield streaming.format_done()
|
||||
surfsense_backend/app/services/vision_llm_router_service.py (new file, 193 lines)
|
|
@ -0,0 +1,193 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from litellm import Router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VISION_AUTO_MODE_ID = 0
|
||||
|
||||
VISION_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GOOGLE": "gemini",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock",
|
||||
"XAI": "xai",
|
||||
"OPENROUTER": "openrouter",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"GROQ": "groq",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"DEEPSEEK": "openai",
|
||||
"MISTRAL": "mistral",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
|
||||
|
||||
class VisionLLMRouterService:
|
||||
_instance = None
|
||||
_router: Router | None = None
|
||||
_model_list: list[dict] = []
|
||||
_router_settings: dict = {}
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "VisionLLMRouterService":
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def initialize(
|
||||
cls,
|
||||
global_configs: list[dict],
|
||||
router_settings: dict | None = None,
|
||||
) -> None:
|
||||
instance = cls.get_instance()
|
||||
|
||||
if instance._initialized:
|
||||
logger.debug("Vision LLM Router already initialized, skipping")
|
||||
return
|
||||
|
||||
model_list = []
|
||||
for config in global_configs:
|
||||
deployment = cls._config_to_deployment(config)
|
||||
if deployment:
|
||||
model_list.append(deployment)
|
||||
|
||||
if not model_list:
|
||||
logger.warning(
|
||||
"No valid vision LLM configs found for router initialization"
|
||||
)
|
||||
return
|
||||
|
||||
instance._model_list = model_list
|
||||
instance._router_settings = router_settings or {}
|
||||
|
||||
default_settings = {
|
||||
"routing_strategy": "usage-based-routing",
|
||||
"num_retries": 3,
|
||||
"allowed_fails": 3,
|
||||
"cooldown_time": 60,
|
||||
"retry_after": 5,
|
||||
}
|
||||
|
||||
final_settings = {**default_settings, **instance._router_settings}
|
||||
|
||||
try:
|
||||
instance._router = Router(
|
||||
model_list=model_list,
|
||||
routing_strategy=final_settings.get(
|
||||
"routing_strategy", "usage-based-routing"
|
||||
),
|
||||
num_retries=final_settings.get("num_retries", 3),
|
||||
allowed_fails=final_settings.get("allowed_fails", 3),
|
||||
cooldown_time=final_settings.get("cooldown_time", 60),
|
||||
set_verbose=False,
|
||||
)
|
||||
instance._initialized = True
|
||||
logger.info(
|
||||
"Vision LLM Router initialized with %d deployments, strategy: %s",
|
||||
len(model_list),
|
||||
final_settings.get("routing_strategy"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Vision LLM Router: {e}")
|
||||
instance._router = None
|
||||
|
||||
@classmethod
|
||||
def _config_to_deployment(cls, config: dict) -> dict | None:
|
||||
try:
|
||||
if not config.get("model_name") or not config.get("api_key"):
|
||||
return None
|
||||
|
||||
if config.get("custom_provider"):
|
||||
model_string = f"{config['custom_provider']}/{config['model_name']}"
|
||||
else:
|
||||
provider = config.get("provider", "").upper()
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
litellm_params: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
if config.get("api_base"):
|
||||
litellm_params["api_base"] = config["api_base"]
|
||||
|
||||
if config.get("api_version"):
|
||||
litellm_params["api_version"] = config["api_version"]
|
||||
|
||||
if config.get("litellm_params"):
|
||||
litellm_params.update(config["litellm_params"])
|
||||
|
||||
deployment: dict[str, Any] = {
|
||||
"model_name": "auto",
|
||||
"litellm_params": litellm_params,
|
||||
}
|
||||
|
||||
if config.get("rpm"):
|
||||
deployment["rpm"] = config["rpm"]
|
||||
if config.get("tpm"):
|
||||
deployment["tpm"] = config["tpm"]
|
||||
|
||||
return deployment
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert vision config to deployment: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_router(cls) -> Router | None:
|
||||
instance = cls.get_instance()
|
||||
return instance._router
|
||||
|
||||
@classmethod
|
||||
def is_initialized(cls) -> bool:
|
||||
instance = cls.get_instance()
|
||||
return instance._initialized and instance._router is not None
|
||||
|
||||
@classmethod
|
||||
def get_model_count(cls) -> int:
|
||||
instance = cls.get_instance()
|
||||
return len(instance._model_list)
|
||||
|
||||
|
||||
def is_vision_auto_mode(config_id: int | None) -> bool:
|
||||
return config_id == VISION_AUTO_MODE_ID
|
||||
|
||||
|
||||
def build_vision_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
prefix = VISION_PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
return f"{prefix}/{model_name}"
|
||||
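For example, build_vision_model_string maps a stored provider/model pair onto the litellm naming scheme via VISION_PROVIDER_MAP:

# Worked examples (outputs follow directly from the mapping above).
build_vision_model_string("GOOGLE", "gemini-1.5-flash", None)    # -> "gemini/gemini-1.5-flash"
build_vision_model_string("OLLAMA", "llava", None)               # -> "ollama_chat/llava"
build_vision_model_string("CUSTOM", "my-model", "my_provider")   # -> "my_provider/my-model"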
|
||||
|
||||
def get_global_vision_llm_config(config_id: int) -> dict | None:
|
||||
from app.config import config
|
||||
|
||||
if config_id == VISION_AUTO_MODE_ID:
|
||||
return {
|
||||
"id": VISION_AUTO_MODE_ID,
|
||||
"name": "Auto (Fastest)",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
if config_id > 0:
|
||||
return None
|
||||
for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return cfg
|
||||
return None
|
||||
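A short sketch of how the singleton might be used at startup and at request time. The config values are placeholders and the call site is an assumption; Router.acompletion is the standard litellm entry point for the shared "auto" alias that every deployment registers under:

# Hypothetical startup + request-time usage (config values are placeholders).
VisionLLMRouterService.initialize(
    global_configs=[
        {"id": -1, "provider": "OPENAI", "model_name": "gpt-4o-mini", "api_key": "sk-..."},
        {"id": -2, "provider": "GOOGLE", "model_name": "gemini-1.5-flash", "api_key": "..."},
    ],
)

async def describe_screenshot(prompt: str) -> str | None:
    router = VisionLLMRouterService.get_router()
    if router is None:
        return None
    # Routing strategy picks one of the registered deployments behind "auto".
    response = await router.acompletion(
        model="auto",
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content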
134 surfsense_backend/app/services/vision_model_list_service.py Normal file
@@ -0,0 +1,134 @@
"""
|
||||
Service for fetching and caching the vision-capable model list.
|
||||
|
||||
Reuses the same OpenRouter public API and local fallback as the LLM model
|
||||
list service, but filters for models that accept image input.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
|
||||
FALLBACK_FILE = (
|
||||
Path(__file__).parent.parent / "config" / "vision_model_list_fallback.json"
|
||||
)
|
||||
CACHE_TTL_SECONDS = 86400 # 24 hours
|
||||
|
||||
_cache: list[dict] | None = None
|
||||
_cache_timestamp: float = 0
|
||||
|
||||
OPENROUTER_SLUG_TO_VISION_PROVIDER: dict[str, str] = {
|
||||
"openai": "OPENAI",
|
||||
"anthropic": "ANTHROPIC",
|
||||
"google": "GOOGLE",
|
||||
"mistralai": "MISTRAL",
|
||||
"x-ai": "XAI",
|
||||
}
|
||||
|
||||
|
||||
def _format_context_length(length: int | None) -> str | None:
|
||||
if not length:
|
||||
return None
|
||||
if length >= 1_000_000:
|
||||
return f"{length / 1_000_000:g}M"
|
||||
if length >= 1_000:
|
||||
return f"{length / 1_000:g}K"
|
||||
return str(length)
|
||||
|
||||
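Worked examples of the formatter (the %g-style formatting drops trailing zeros):

_format_context_length(None)        # -> None
_format_context_length(500)         # -> "500"
_format_context_length(8192)        # -> "8.192K"
_format_context_length(200_000)     # -> "200K"
_format_context_length(1_000_000)   # -> "1M"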
|
||||
async def _fetch_from_openrouter() -> list[dict] | None:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
response = await client.get(OPENROUTER_API_URL)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("data", [])
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch from OpenRouter API for vision models: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _load_fallback() -> list[dict]:
|
||||
try:
|
||||
with open(FALLBACK_FILE, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load vision model fallback list: %s", e)
|
||||
return []
|
||||
|
||||
|
||||
def _is_vision_model(model: dict) -> bool:
|
||||
"""Return True if the model accepts image input and outputs text."""
|
||||
arch = model.get("architecture", {})
|
||||
input_mods = arch.get("input_modalities", [])
|
||||
output_mods = arch.get("output_modalities", [])
|
||||
return "image" in input_mods and "text" in output_mods
|
||||
|
||||
|
||||
def _process_vision_models(raw_models: list[dict]) -> list[dict]:
|
||||
processed: list[dict] = []
|
||||
|
||||
for model in raw_models:
|
||||
model_id: str = model.get("id", "")
|
||||
name: str = model.get("name", "")
|
||||
context_length = model.get("context_length")
|
||||
|
||||
if "/" not in model_id:
|
||||
continue
|
||||
|
||||
if not _is_vision_model(model):
|
||||
continue
|
||||
|
||||
provider_slug, model_name = model_id.split("/", 1)
|
||||
context_window = _format_context_length(context_length)
|
||||
|
||||
processed.append(
|
||||
{
|
||||
"value": model_id,
|
||||
"label": name,
|
||||
"provider": "OPENROUTER",
|
||||
"context_window": context_window,
|
||||
}
|
||||
)
|
||||
|
||||
native_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug)
|
||||
if native_provider:
|
||||
if native_provider == "GOOGLE" and not model_name.startswith("gemini-"):
|
||||
continue
|
||||
|
||||
processed.append(
|
||||
{
|
||||
"value": model_name,
|
||||
"label": name,
|
||||
"provider": native_provider,
|
||||
"context_window": context_window,
|
||||
}
|
||||
)
|
||||
|
||||
return processed
|
||||
|
||||
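For a raw OpenRouter entry such as the one below, _process_vision_models emits an OPENROUTER entry plus a native-provider entry when the slug is recognised. The sample payload is illustrative, trimmed to the fields the code reads:

raw = {
    "id": "anthropic/claude-3-5-sonnet",
    "name": "Anthropic: Claude 3.5 Sonnet",
    "context_length": 200_000,
    "architecture": {"input_modalities": ["text", "image"], "output_modalities": ["text"]},
}
_process_vision_models([raw])
# -> [
#   {"value": "anthropic/claude-3-5-sonnet", "label": "Anthropic: Claude 3.5 Sonnet",
#    "provider": "OPENROUTER", "context_window": "200K"},
#   {"value": "claude-3-5-sonnet", "label": "Anthropic: Claude 3.5 Sonnet",
#    "provider": "ANTHROPIC", "context_window": "200K"},
# ]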
|
||||
async def get_vision_model_list() -> list[dict]:
|
||||
global _cache, _cache_timestamp
|
||||
|
||||
if _cache is not None and (time.time() - _cache_timestamp) < CACHE_TTL_SECONDS:
|
||||
return _cache
|
||||
|
||||
raw_models = await _fetch_from_openrouter()
|
||||
|
||||
if raw_models is None:
|
||||
logger.info("Using fallback vision model list")
|
||||
return _load_fallback()
|
||||
|
||||
processed = _process_vision_models(raw_models)
|
||||
|
||||
_cache = processed
|
||||
_cache_timestamp = time.time()
|
||||
|
||||
return processed
|
||||
|
|
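Callers can therefore await get_vision_model_list() freely: after the first successful fetch the in-process cache answers for 24 hours, and an OpenRouter outage falls back to the bundled JSON list. A minimal sketch of a route handler using it (the route itself is an assumption, not part of this diff):

# Hypothetical endpoint exposing the cached vision model list.
@router.get("/vision-models")
async def list_vision_models() -> list[dict]:
    return await get_vision_model_list()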
@ -1,6 +1,7 @@
|
|||
"""Celery tasks for document processing."""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
from uuid import UUID
|
||||
|
|
@ -10,6 +11,7 @@ from app.config import config
|
|||
from app.services.notification_service import NotificationService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.connector_indexers.local_folder_indexer import index_local_folder
|
||||
from app.tasks.document_processors import (
|
||||
add_extension_received_document,
|
||||
add_youtube_video_document,
|
||||
|
|
@ -141,21 +143,30 @@ async def _delete_document_background(document_id: int) -> None:
|
|||
retry_backoff_max=300,
|
||||
max_retries=5,
|
||||
)
|
||||
def delete_folder_documents_task(self, document_ids: list[int]):
|
||||
"""Celery task to batch-delete documents orphaned by folder deletion."""
|
||||
def delete_folder_documents_task(
|
||||
self,
|
||||
document_ids: list[int],
|
||||
folder_subtree_ids: list[int] | None = None,
|
||||
):
|
||||
"""Celery task to delete documents first, then the folder rows."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(_delete_folder_documents(document_ids))
|
||||
loop.run_until_complete(
|
||||
_delete_folder_documents(document_ids, folder_subtree_ids)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _delete_folder_documents(document_ids: list[int]) -> None:
|
||||
"""Delete chunks in batches, then document rows for each orphaned document."""
|
||||
async def _delete_folder_documents(
|
||||
document_ids: list[int],
|
||||
folder_subtree_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""Delete chunks in batches, then document rows, then folder rows."""
|
||||
from sqlalchemy import delete as sa_delete, select
|
||||
|
||||
from app.db import Chunk, Document
|
||||
from app.db import Chunk, Document, Folder
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
batch_size = 500
|
||||
|
|
@ -177,6 +188,12 @@ async def _delete_folder_documents(document_ids: list[int]) -> None:
|
|||
await session.delete(doc)
|
||||
await session.commit()
|
||||
|
||||
if folder_subtree_ids:
|
||||
await session.execute(
|
||||
sa_delete(Folder).where(Folder.id.in_(folder_subtree_ids))
|
||||
)
|
||||
await session.commit()
|
||||
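A sketch of how the task might be queued when a folder tree is removed; the calling site and id values are illustrative, only the keyword names come from the signature above:

# Hypothetical dispatch: delete orphaned documents, then the folder rows themselves.
delete_folder_documents_task.delay(
    document_ids=[101, 102, 103],
    folder_subtree_ids=[7, 8],  # the deleted folder and its descendants
)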
|
||||
|
||||
@celery_app.task(
|
||||
name="delete_search_space_background",
|
||||
|
|
@ -1243,3 +1260,154 @@ async def _process_circleback_meeting(
|
|||
heartbeat_task.cancel()
|
||||
if notification:
|
||||
_stop_heartbeat(notification.id)
|
||||
|
||||
|
||||
# ===== Local folder indexing task =====
|
||||
|
||||
|
||||
@celery_app.task(name="index_local_folder", bind=True)
|
||||
def index_local_folder_task(
|
||||
self,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
folder_path: str,
|
||||
folder_name: str,
|
||||
exclude_patterns: list[str] | None = None,
|
||||
file_extensions: list[str] | None = None,
|
||||
root_folder_id: int | None = None,
|
||||
enable_summary: bool = False,
|
||||
target_file_paths: list[str] | None = None,
|
||||
):
|
||||
"""Celery task to index a local folder. Config is passed directly — no connector row."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_local_folder_async(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_path=folder_path,
|
||||
folder_name=folder_name,
|
||||
exclude_patterns=exclude_patterns,
|
||||
file_extensions=file_extensions,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
target_file_paths=target_file_paths,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_local_folder_async(
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
folder_path: str,
|
||||
folder_name: str,
|
||||
exclude_patterns: list[str] | None = None,
|
||||
file_extensions: list[str] | None = None,
|
||||
root_folder_id: int | None = None,
|
||||
enable_summary: bool = False,
|
||||
target_file_paths: list[str] | None = None,
|
||||
):
|
||||
"""Run local folder indexing with notification + heartbeat."""
|
||||
is_batch = bool(target_file_paths)
|
||||
is_full_scan = not target_file_paths
|
||||
file_count = len(target_file_paths) if target_file_paths else None
|
||||
|
||||
if is_batch:
|
||||
doc_name = f"{folder_name} ({file_count} file{'s' if file_count != 1 else ''})"
|
||||
else:
|
||||
doc_name = folder_name
|
||||
|
||||
notification = None
|
||||
notification_id: int | None = None
|
||||
heartbeat_task = None
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
try:
|
||||
notification = (
|
||||
await NotificationService.document_processing.notify_processing_started(
|
||||
session=session,
|
||||
user_id=UUID(user_id),
|
||||
document_type="LOCAL_FOLDER_FILE",
|
||||
document_name=doc_name,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
)
|
||||
notification_id = notification.id
|
||||
_start_heartbeat(notification_id)
|
||||
heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification_id))
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to create notification for local folder indexing",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _heartbeat_progress(completed_count: int) -> None:
|
||||
"""Refresh heartbeat and optionally update notification progress."""
|
||||
if notification:
|
||||
with contextlib.suppress(Exception):
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session=session,
|
||||
notification=notification,
|
||||
stage="indexing",
|
||||
stage_message=f"Syncing files ({completed_count}/{file_count or '?'})",
|
||||
)
|
||||
|
||||
try:
|
||||
_indexed, _skipped_or_failed, _rfid, err = await index_local_folder(
|
||||
session=session,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_path=folder_path,
|
||||
folder_name=folder_name,
|
||||
exclude_patterns=exclude_patterns,
|
||||
file_extensions=file_extensions,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
target_file_paths=target_file_paths,
|
||||
on_heartbeat_callback=_heartbeat_progress
|
||||
if (is_batch or is_full_scan)
|
||||
else None,
|
||||
)
|
||||
|
||||
if notification:
|
||||
try:
|
||||
await session.refresh(notification)
|
||||
if err:
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message=err,
|
||||
)
|
||||
else:
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to update notification after local folder indexing",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Local folder indexing failed: {e}")
|
||||
if notification:
|
||||
try:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message=str(e)[:200],
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
finally:
|
||||
if heartbeat_task:
|
||||
heartbeat_task.cancel()
|
||||
if notification_id is not None:
|
||||
_stop_heartbeat(notification_id)
|
||||
|
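A sketch of how this task might be enqueued for an initial full scan versus a targeted batch re-sync; the argument values are placeholders, only the keyword names come from the signature above:

# Hypothetical dispatch: full scan of a local folder (no target_file_paths).
index_local_folder_task.delay(
    search_space_id=1,
    user_id="c0ffee00-0000-0000-0000-000000000001",
    folder_path="/home/alice/notes",
    folder_name="notes",
    file_extensions=[".md", ".pdf"],
    enable_summary=True,
)

# Hypothetical dispatch: re-sync just two changed files under an existing root folder.
index_local_folder_task.delay(
    search_space_id=1,
    user_id="c0ffee00-0000-0000-0000-000000000001",
    folder_path="/home/alice/notes",
    folder_name="notes",
    root_folder_id=42,
    target_file_paths=["/home/alice/notes/todo.md", "/home/alice/notes/meeting.md"],
)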
|
|
|||
|
|
@ -42,9 +42,9 @@ from .jira_indexer import index_jira_issues
|
|||
|
||||
# Issue tracking and project management
|
||||
from .linear_indexer import index_linear_issues
|
||||
from .luma_indexer import index_luma_events
|
||||
|
||||
# Documentation and knowledge management
|
||||
from .luma_indexer import index_luma_events
|
||||
from .notion_indexer import index_notion_pages
|
||||
from .obsidian_indexer import index_obsidian_vault
|
||||
from .slack_indexer import index_slack_messages
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from app.indexing_pipeline.connector_document import ConnectorDocument
|
|||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
|
|
@ -50,7 +51,10 @@ async def _should_skip_file(
|
|||
file_id = file.get("id", "")
|
||||
file_name = file.get("name", "Unknown")
|
||||
|
||||
if skip_item(file):
|
||||
skip, unsup_ext = skip_item(file)
|
||||
if skip:
|
||||
if unsup_ext:
|
||||
return True, f"unsupported:{unsup_ext}"
|
||||
return True, "folder/non-downloadable"
|
||||
if not file_id:
|
||||
return True, "missing file_id"
|
||||
|
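The change above assumes skip_item now returns a (skip, unsupported_extension) pair rather than a bare bool. A minimal illustration of that assumed contract; the example file dicts and extensions are made up:

skip, unsup_ext = skip_item({"name": "photo.heic", ".tag": "file"})
# -> (True, ".heic")   skipped and counted as unsupported

skip, unsup_ext = skip_item({"name": "report.pdf", ".tag": "file"})
# -> (False, None)     proceeds to download and indexing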
|
@ -250,6 +254,121 @@ async def _download_and_index(
|
|||
return batch_indexed, download_failed + batch_failed
|
||||
|
||||
|
||||
async def _remove_document(session: AsyncSession, file_id: str, search_space_id: int):
|
||||
"""Remove a document that was deleted in Dropbox."""
|
||||
primary_hash = compute_identifier_hash(
|
||||
DocumentType.DROPBOX_FILE.value, file_id, search_space_id
|
||||
)
|
||||
existing = await check_document_by_unique_identifier(session, primary_hash)
|
||||
|
||||
if not existing:
|
||||
result = await session.execute(
|
||||
select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.DROPBOX_FILE,
|
||||
cast(Document.document_metadata["dropbox_file_id"], String) == file_id,
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
await session.delete(existing)
|
||||
|
||||
|
||||
async def _index_with_delta_sync(
|
||||
dropbox_client: DropboxClient,
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
cursor: str,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry: object,
|
||||
max_files: int,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int, int, str]:
|
||||
"""Delta sync using Dropbox cursor-based change tracking.
|
||||
|
||||
Returns (indexed_count, skipped_count, unsupported_count, new_cursor).
|
||||
"""
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Starting delta sync from cursor: {cursor[:20]}...",
|
||||
{"stage": "delta_sync", "cursor_prefix": cursor[:20]},
|
||||
)
|
||||
|
||||
entries, new_cursor, error = await dropbox_client.get_changes(cursor)
|
||||
if error:
|
||||
err_lower = error.lower()
|
||||
if "401" in error or "authentication expired" in err_lower:
|
||||
raise Exception(
|
||||
f"Dropbox authentication failed. Please re-authenticate. (Error: {error})"
|
||||
)
|
||||
raise Exception(f"Failed to fetch Dropbox changes: {error}")
|
||||
|
||||
if not entries:
|
||||
logger.info("No changes detected since last sync")
|
||||
return 0, 0, 0, new_cursor or cursor
|
||||
|
||||
logger.info(f"Processing {len(entries)} change entries")
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
files_to_download: list[dict] = []
|
||||
files_processed = 0
|
||||
|
||||
for entry in entries:
|
||||
if files_processed >= max_files:
|
||||
break
|
||||
files_processed += 1
|
||||
|
||||
tag = entry.get(".tag")
|
||||
|
||||
if tag == "deleted":
|
||||
path_lower = entry.get("path_lower", "")
|
||||
name = entry.get("name", "")
|
||||
file_id = entry.get("id", "")
|
||||
if file_id:
|
||||
await _remove_document(session, file_id, search_space_id)
|
||||
logger.debug(f"Processed deletion: {name or path_lower}")
|
||||
continue
|
||||
|
||||
if tag != "file":
|
||||
continue
|
||||
|
||||
skip, msg = await _should_skip_file(session, entry, search_space_id)
|
||||
if skip:
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
files_to_download.append(entry)
|
||||
|
||||
batch_indexed, failed = await _download_and_index(
|
||||
dropbox_client,
|
||||
session,
|
||||
files_to_download,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=enable_summary,
|
||||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Delta sync complete: {indexed} indexed, {skipped} skipped, "
|
||||
f"{unsupported_count} unsupported, {failed} failed"
|
||||
)
|
||||
return indexed, skipped, unsupported_count, new_cursor or cursor
|
||||
|
||||
|
||||
async def _index_full_scan(
|
||||
dropbox_client: DropboxClient,
|
||||
session: AsyncSession,
|
||||
|
|
@ -265,8 +384,11 @@ async def _index_full_scan(
|
|||
incremental_sync: bool = True,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int]:
|
||||
"""Full scan indexing of a folder."""
|
||||
) -> tuple[int, int, int]:
|
||||
"""Full scan indexing of a folder.
|
||||
|
||||
Returns (indexed, skipped, unsupported_count).
|
||||
"""
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Starting full scan of folder: {folder_name}",
|
||||
|
|
@ -278,8 +400,15 @@ async def _index_full_scan(
|
|||
},
|
||||
)
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
files_to_download: list[dict] = []
|
||||
|
||||
all_files, error = await get_files_in_folder(
|
||||
|
|
@ -299,14 +428,36 @@ async def _index_full_scan(
|
|||
if incremental_sync:
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
elif skip_item(file):
|
||||
else:
|
||||
item_skip, item_unsup = skip_item(file)
|
||||
if item_skip:
|
||||
if item_unsup:
|
||||
unsupported_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during Dropbox full scan, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, failed = await _download_and_index(
|
||||
|
|
@ -320,11 +471,20 @@ async def _index_full_scan(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, "
|
||||
f"{unsupported_count} unsupported, {failed} failed"
|
||||
)
|
||||
return indexed, skipped
|
||||
return indexed, skipped, unsupported_count
|
||||
|
||||
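The deduction pro-rates the pre-download page estimate by the fraction of files that actually indexed, with a floor of one page. For example, with batch_estimated_pages = 40, len(files_to_download) = 10 and batch_indexed = 7:

pages_to_deduct = max(1, 40 * 7 // 10)  # -> 28 pages charged against the quota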
|
||||
async def _index_selected_files(
|
||||
|
|
@ -338,12 +498,18 @@ async def _index_selected_files(
|
|||
enable_summary: bool,
|
||||
incremental_sync: bool = True,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
) -> tuple[int, int, int, list[str]]:
|
||||
"""Index user-selected files using the parallel pipeline."""
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
|
||||
files_to_download: list[dict] = []
|
||||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
|
||||
for file_path, file_name in file_paths:
|
||||
file, error = await get_file_by_path(dropbox_client, file_path)
|
||||
|
|
@ -355,15 +521,31 @@ async def _index_selected_files(
|
|||
if incremental_sync:
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
elif skip_item(file):
|
||||
skipped += 1
|
||||
else:
|
||||
item_skip, item_unsup = skip_item(file)
|
||||
if item_skip:
|
||||
if item_unsup:
|
||||
unsupported_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
display = file_name or file_path
|
||||
errors.append(f"File '{display}': page limit would be exceeded")
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, _failed = await _download_and_index(
|
||||
|
|
@ -377,7 +559,15 @@ async def _index_selected_files(
|
|||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, unsupported_count, errors
|
||||
|
||||
|
||||
async def index_dropbox_files(
|
||||
|
|
@ -386,7 +576,7 @@ async def index_dropbox_files(
|
|||
search_space_id: int,
|
||||
user_id: str,
|
||||
items_dict: dict,
|
||||
) -> tuple[int, int, str | None]:
|
||||
) -> tuple[int, int, str | None, int]:
|
||||
"""Index Dropbox files for a specific connector.
|
||||
|
||||
items_dict format:
|
||||
|
|
@ -417,7 +607,7 @@ async def index_dropbox_files(
|
|||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
return 0, 0, error_msg, 0
|
||||
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted and not config.SECRET_KEY:
|
||||
|
|
@ -428,7 +618,7 @@ async def index_dropbox_files(
|
|||
"Missing SECRET_KEY",
|
||||
{"error_type": "MissingSecretKey"},
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
return 0, 0, error_msg, 0
|
||||
|
||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||
dropbox_client = DropboxClient(session, connector_id)
|
||||
|
|
@ -437,9 +627,13 @@ async def index_dropbox_files(
|
|||
max_files = indexing_options.get("max_files", 500)
|
||||
incremental_sync = indexing_options.get("incremental_sync", True)
|
||||
include_subfolders = indexing_options.get("include_subfolders", True)
|
||||
use_delta_sync = indexing_options.get("use_delta_sync", True)
|
||||
|
||||
folder_cursors: dict = connector.config.get("folder_cursors", {})
|
||||
|
||||
total_indexed = 0
|
||||
total_skipped = 0
|
||||
total_unsupported = 0
|
||||
|
||||
selected_files = items_dict.get("files", [])
|
||||
if selected_files:
|
||||
|
|
@ -447,7 +641,7 @@ async def index_dropbox_files(
|
|||
(f.get("path", f.get("path_lower", f.get("id", ""))), f.get("name"))
|
||||
for f in selected_files
|
||||
]
|
||||
indexed, skipped, file_errors = await _index_selected_files(
|
||||
indexed, skipped, unsupported, file_errors = await _index_selected_files(
|
||||
dropbox_client,
|
||||
session,
|
||||
file_tuples,
|
||||
|
|
@ -459,6 +653,7 @@ async def index_dropbox_files(
|
|||
)
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
total_unsupported += unsupported
|
||||
if file_errors:
|
||||
logger.warning(
|
||||
f"File indexing errors for connector {connector_id}: {file_errors}"
|
||||
|
|
@ -471,25 +666,66 @@ async def index_dropbox_files(
|
|||
)
|
||||
folder_name = folder.get("name", "Root")
|
||||
|
||||
logger.info(f"Using full scan for folder {folder_name}")
|
||||
indexed, skipped = await _index_full_scan(
|
||||
dropbox_client,
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
folder_path,
|
||||
folder_name,
|
||||
task_logger,
|
||||
log_entry,
|
||||
max_files,
|
||||
include_subfolders,
|
||||
incremental_sync=incremental_sync,
|
||||
enable_summary=connector_enable_summary,
|
||||
saved_cursor = folder_cursors.get(folder_path)
|
||||
can_use_delta = (
|
||||
use_delta_sync and saved_cursor and connector.last_indexed_at
|
||||
)
|
||||
|
||||
if can_use_delta:
|
||||
logger.info(f"Using delta sync for folder {folder_name}")
|
||||
indexed, skipped, unsup, new_cursor = await _index_with_delta_sync(
|
||||
dropbox_client,
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
saved_cursor,
|
||||
task_logger,
|
||||
log_entry,
|
||||
max_files,
|
||||
enable_summary=connector_enable_summary,
|
||||
)
|
||||
folder_cursors[folder_path] = new_cursor
|
||||
total_unsupported += unsup
|
||||
else:
|
||||
logger.info(f"Using full scan for folder {folder_name}")
|
||||
indexed, skipped, unsup = await _index_full_scan(
|
||||
dropbox_client,
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
folder_path,
|
||||
folder_name,
|
||||
task_logger,
|
||||
log_entry,
|
||||
max_files,
|
||||
include_subfolders,
|
||||
incremental_sync=incremental_sync,
|
||||
enable_summary=connector_enable_summary,
|
||||
)
|
||||
total_unsupported += unsup
|
||||
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
|
||||
# Persist latest cursor for this folder
|
||||
try:
|
||||
latest_cursor, cursor_err = await dropbox_client.get_latest_cursor(
|
||||
folder_path
|
||||
)
|
||||
if latest_cursor and not cursor_err:
|
||||
folder_cursors[folder_path] = latest_cursor
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get latest cursor for {folder_path}: {e}")
|
||||
|
||||
# Persist folder cursors to connector config
|
||||
if folders:
|
||||
cfg = dict(connector.config)
|
||||
cfg["folder_cursors"] = folder_cursors
|
||||
connector.config = cfg
|
||||
flag_modified(connector, "config")
|
||||
|
||||
if total_indexed > 0 or folders:
|
||||
await update_connector_last_indexed(session, connector, True)
|
||||
|
||||
|
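Copying the config dict, reassigning it, and calling flag_modified is what makes the cursor update stick: SQLAlchemy does not track in-place mutation of a plain JSON column, so the code replaces the dict and flags the attribute dirty. A condensed restatement of the pattern used above:

from sqlalchemy.orm.attributes import flag_modified

cfg = dict(connector.config)            # copy, never mutate the tracked dict in place
cfg["folder_cursors"] = folder_cursors  # latest cursor per folder path
connector.config = cfg
flag_modified(connector, "config")      # mark the JSON column dirty for the next commit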
|
@ -498,12 +734,18 @@ async def index_dropbox_files(
|
|||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed Dropbox indexing for connector {connector_id}",
|
||||
{"files_processed": total_indexed, "files_skipped": total_skipped},
|
||||
{
|
||||
"files_processed": total_indexed,
|
||||
"files_skipped": total_skipped,
|
||||
"files_unsupported": total_unsupported,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"Dropbox indexing completed: {total_indexed} indexed, {total_skipped} skipped"
|
||||
f"Dropbox indexing completed: {total_indexed} indexed, "
|
||||
f"{total_skipped} skipped, {total_unsupported} unsupported"
|
||||
)
|
||||
return total_indexed, total_skipped, None
|
||||
|
||||
return total_indexed, total_skipped, None, total_unsupported
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
@ -514,7 +756,7 @@ async def index_dropbox_files(
|
|||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
return 0, 0, f"Database error: {db_error!s}", 0
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -524,4 +766,4 @@ async def index_dropbox_files(
|
|||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index Dropbox files: {e!s}", exc_info=True)
|
||||
return 0, 0, f"Failed to index Dropbox files: {e!s}"
|
||||
return 0, 0, f"Failed to index Dropbox files: {e!s}", 0
|
||||
|
|
|
|||
|
|
@ -25,7 +25,11 @@ from app.connectors.google_drive import (
|
|||
get_files_in_folder,
|
||||
get_start_page_token,
|
||||
)
|
||||
from app.connectors.google_drive.file_types import should_skip_file as skip_mime
|
||||
from app.connectors.google_drive.file_types import (
|
||||
is_google_workspace_file,
|
||||
should_skip_by_extension,
|
||||
should_skip_file as skip_mime,
|
||||
)
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
|
|
@ -34,6 +38,7 @@ from app.indexing_pipeline.indexing_pipeline_service import (
|
|||
PlaceholderInfo,
|
||||
)
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
|
|
@ -77,6 +82,10 @@ async def _should_skip_file(
|
|||
|
||||
if skip_mime(mime_type):
|
||||
return True, "folder/shortcut"
|
||||
if not is_google_workspace_file(mime_type):
|
||||
ext_skip, unsup_ext = should_skip_by_extension(file_name)
|
||||
if ext_skip:
|
||||
return True, f"unsupported:{unsup_ext}"
|
||||
if not file_id:
|
||||
return True, "missing file_id"
|
||||
|
||||
|
|
@ -327,6 +336,12 @@ async def _process_single_file(
|
|||
return 1, 0, 0
|
||||
return 0, 1, 0
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
estimated_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file_name, file.get("size")
|
||||
)
|
||||
await page_limit_service.check_page_limit(user_id, estimated_pages)
|
||||
|
||||
markdown, drive_metadata, error = await download_and_extract_content(
|
||||
drive_client, file
|
||||
)
|
||||
|
|
@ -363,6 +378,9 @@ async def _process_single_file(
|
|||
)
|
||||
await pipeline.index(document, connector_doc, user_llm)
|
||||
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, estimated_pages, allow_exceed=True
|
||||
)
|
||||
logger.info(f"Successfully indexed Google Drive file: {file_name}")
|
||||
return 1, 0, 0
|
||||
|
||||
|
|
@ -458,18 +476,24 @@ async def _index_selected_files(
|
|||
user_id: str,
|
||||
enable_summary: bool,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
) -> tuple[int, int, int, list[str]]:
|
||||
"""Index user-selected files using the parallel pipeline.
|
||||
|
||||
Phase 1 (serial): fetch metadata + skip checks.
|
||||
Phase 2+3 (parallel): download, ETL, index via _download_and_index.
|
||||
|
||||
Returns (indexed_count, skipped_count, errors).
|
||||
Returns (indexed_count, skipped_count, unsupported_count, errors).
|
||||
"""
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
|
||||
files_to_download: list[dict] = []
|
||||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
|
||||
for file_id, file_name in file_ids:
|
||||
file, error = await get_file_by_id(drive_client, file_id)
|
||||
|
|
@ -480,12 +504,23 @@ async def _index_selected_files(
|
|||
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
display = file_name or file_id
|
||||
errors.append(f"File '{display}': page limit would be exceeded")
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
await _create_drive_placeholders(
|
||||
|
|
@ -507,7 +542,15 @@ async def _index_selected_files(
|
|||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, unsupported_count, errors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -530,8 +573,11 @@ async def _index_full_scan(
|
|||
include_subfolders: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int]:
|
||||
"""Full scan indexing of a folder."""
|
||||
) -> tuple[int, int, int]:
|
||||
"""Full scan indexing of a folder.
|
||||
|
||||
Returns (indexed, skipped, unsupported_count).
|
||||
"""
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Starting full scan of folder: {folder_name} (include_subfolders={include_subfolders})",
|
||||
|
|
@ -545,8 +591,15 @@ async def _index_full_scan(
|
|||
# ------------------------------------------------------------------
|
||||
# Phase 1 (serial): collect files, run skip checks, track renames
|
||||
# ------------------------------------------------------------------
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
files_processed = 0
|
||||
files_to_download: list[dict] = []
|
||||
folders_to_process = [(folder_id, folder_name)]
|
||||
|
|
@ -587,12 +640,28 @@ async def _index_full_scan(
|
|||
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during Google Drive full scan, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
page_token = next_token
|
||||
|
|
@ -636,11 +705,20 @@ async def _index_full_scan(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, "
|
||||
f"{unsupported_count} unsupported, {failed} failed"
|
||||
)
|
||||
return indexed, skipped
|
||||
return indexed, skipped, unsupported_count
|
||||
|
||||
|
||||
async def _index_with_delta_sync(
|
||||
|
|
@ -658,8 +736,11 @@ async def _index_with_delta_sync(
|
|||
include_subfolders: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int]:
|
||||
"""Delta sync using change tracking."""
|
||||
) -> tuple[int, int, int]:
|
||||
"""Delta sync using change tracking.
|
||||
|
||||
Returns (indexed, skipped, unsupported_count).
|
||||
"""
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Starting delta sync from token: {start_page_token[:20]}...",
|
||||
|
|
@ -679,15 +760,22 @@ async def _index_with_delta_sync(
|
|||
|
||||
if not changes:
|
||||
logger.info("No changes detected since last sync")
|
||||
return 0, 0
|
||||
return 0, 0, 0
|
||||
|
||||
logger.info(f"Processing {len(changes)} changes")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1 (serial): handle removals, collect files for download
|
||||
# ------------------------------------------------------------------
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
files_to_download: list[dict] = []
|
||||
files_processed = 0
|
||||
|
||||
|
|
@ -709,12 +797,28 @@ async def _index_with_delta_sync(
|
|||
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during Google Drive delta sync, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -742,11 +846,20 @@ async def _index_with_delta_sync(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
f"Delta sync complete: {indexed} indexed, {skipped} skipped, "
|
||||
f"{unsupported_count} unsupported, {failed} failed"
|
||||
)
|
||||
return indexed, skipped
|
||||
return indexed, skipped, unsupported_count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -766,8 +879,11 @@ async def index_google_drive_files(
|
|||
max_files: int = 500,
|
||||
include_subfolders: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, str | None]:
|
||||
"""Index Google Drive files for a specific connector."""
|
||||
) -> tuple[int, int, str | None, int]:
|
||||
"""Index Google Drive files for a specific connector.
|
||||
|
||||
Returns (indexed, skipped, error_or_none, unsupported_count).
|
||||
"""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="google_drive_files_indexing",
|
||||
|
|
@ -793,7 +909,7 @@ async def index_google_drive_files(
|
|||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
return 0, 0, error_msg, 0
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
|
|
@ -812,7 +928,7 @@ async def index_google_drive_files(
|
|||
"Missing Composio account",
|
||||
{"error_type": "MissingComposioAccount"},
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
return 0, 0, error_msg, 0
|
||||
pre_built_credentials = build_composio_credentials(connected_account_id)
|
||||
else:
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
|
|
@ -827,6 +943,7 @@ async def index_google_drive_files(
|
|||
0,
|
||||
0,
|
||||
"SECRET_KEY not configured but credentials are marked as encrypted",
|
||||
0,
|
||||
)
|
||||
|
||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||
|
|
@ -839,7 +956,7 @@ async def index_google_drive_files(
|
|||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, {"error_type": "MissingParameter"}
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
return 0, 0, error_msg, 0
|
||||
|
||||
target_folder_id = folder_id
|
||||
target_folder_name = folder_name or "Selected Folder"
|
||||
|
|
@ -850,9 +967,11 @@ async def index_google_drive_files(
|
|||
use_delta_sync and start_page_token and connector.last_indexed_at
|
||||
)
|
||||
|
||||
documents_unsupported = 0
|
||||
|
||||
if can_use_delta:
|
||||
logger.info(f"Using delta sync for connector {connector_id}")
|
||||
documents_indexed, documents_skipped = await _index_with_delta_sync(
|
||||
documents_indexed, documents_skipped, du = await _index_with_delta_sync(
|
||||
drive_client,
|
||||
session,
|
||||
connector,
|
||||
|
|
@ -868,8 +987,9 @@ async def index_google_drive_files(
|
|||
on_heartbeat_callback,
|
||||
connector_enable_summary,
|
||||
)
|
||||
documents_unsupported += du
|
||||
logger.info("Running reconciliation scan after delta sync")
|
||||
ri, rs = await _index_full_scan(
|
||||
ri, rs, ru = await _index_full_scan(
|
||||
drive_client,
|
||||
session,
|
||||
connector,
|
||||
|
|
@ -887,9 +1007,14 @@ async def index_google_drive_files(
|
|||
)
|
||||
documents_indexed += ri
|
||||
documents_skipped += rs
|
||||
documents_unsupported += ru
|
||||
else:
|
||||
logger.info(f"Using full scan for connector {connector_id}")
|
||||
documents_indexed, documents_skipped = await _index_full_scan(
|
||||
(
|
||||
documents_indexed,
|
||||
documents_skipped,
|
||||
documents_unsupported,
|
||||
) = await _index_full_scan(
|
||||
drive_client,
|
||||
session,
|
||||
connector,
|
||||
|
|
@ -924,14 +1049,17 @@ async def index_google_drive_files(
|
|||
{
|
||||
"files_processed": documents_indexed,
|
||||
"files_skipped": documents_skipped,
|
||||
"files_unsupported": documents_unsupported,
|
||||
"sync_type": "delta" if can_use_delta else "full",
|
||||
"folder": target_folder_name,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"Google Drive indexing completed: {documents_indexed} indexed, {documents_skipped} skipped"
|
||||
f"Google Drive indexing completed: {documents_indexed} indexed, "
|
||||
f"{documents_skipped} skipped, {documents_unsupported} unsupported"
|
||||
)
|
||||
return documents_indexed, documents_skipped, None
|
||||
|
||||
return documents_indexed, documents_skipped, None, documents_unsupported
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
@ -942,7 +1070,7 @@ async def index_google_drive_files(
|
|||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
return 0, 0, f"Database error: {db_error!s}", 0
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -952,7 +1080,7 @@ async def index_google_drive_files(
|
|||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index Google Drive files: {e!s}", exc_info=True)
|
||||
return 0, 0, f"Failed to index Google Drive files: {e!s}"
|
||||
return 0, 0, f"Failed to index Google Drive files: {e!s}", 0
|
||||
|
||||
|
||||
async def index_google_drive_single_file(
|
||||
|
|
@ -1154,7 +1282,7 @@ async def index_google_drive_selected_files(
|
|||
session, connector_id, credentials=pre_built_credentials
|
||||
)
|
||||
|
||||
indexed, skipped, errors = await _index_selected_files(
|
||||
indexed, skipped, unsupported, errors = await _index_selected_files(
|
||||
drive_client,
|
||||
session,
|
||||
files,
|
||||
|
|
@ -1165,6 +1293,11 @@ async def index_google_drive_selected_files(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if unsupported > 0:
|
||||
file_text = "file was" if unsupported == 1 else "files were"
|
||||
unsup_msg = f"{unsupported} {file_text} not supported"
|
||||
errors.append(unsup_msg)
|
||||
|
||||
await session.commit()
|
||||
|
||||
if errors:
|
||||
|
|
@ -1172,7 +1305,12 @@ async def index_google_drive_selected_files(
|
|||
log_entry,
|
||||
f"Batch file indexing completed with {len(errors)} error(s)",
|
||||
"; ".join(errors),
|
||||
{"indexed": indexed, "skipped": skipped, "error_count": len(errors)},
|
||||
{
|
||||
"indexed": indexed,
|
||||
"skipped": skipped,
|
||||
"unsupported": unsupported,
|
||||
"error_count": len(errors),
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
|
|
|
|||
File diff suppressed because it is too large
|
|
@ -28,6 +28,7 @@ from app.indexing_pipeline.connector_document import ConnectorDocument
|
|||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
|
|
@ -55,7 +56,10 @@ async def _should_skip_file(
|
|||
file_id = file.get("id")
|
||||
file_name = file.get("name", "Unknown")
|
||||
|
||||
if skip_item(file):
|
||||
skip, unsup_ext = skip_item(file)
|
||||
if skip:
|
||||
if unsup_ext:
|
||||
return True, f"unsupported:{unsup_ext}"
|
||||
return True, "folder/onenote/remote"
|
||||
if not file_id:
|
||||
return True, "missing file_id"
|
||||
|
|
@ -289,12 +293,18 @@ async def _index_selected_files(
|
|||
user_id: str,
|
||||
enable_summary: bool,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
) -> tuple[int, int, int, list[str]]:
|
||||
"""Index user-selected files using the parallel pipeline."""
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
|
||||
files_to_download: list[dict] = []
|
||||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
|
||||
for file_id, file_name in file_ids:
|
||||
file, error = await get_file_by_id(onedrive_client, file_id)
|
||||
|
|
@ -305,12 +315,23 @@ async def _index_selected_files(
|
|||
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
display = file_name or file_id
|
||||
errors.append(f"File '{display}': page limit would be exceeded")
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, _failed = await _download_and_index(
|
||||
|
|
@ -324,7 +345,15 @@ async def _index_selected_files(
|
|||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, unsupported_count, errors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -346,8 +375,11 @@ async def _index_full_scan(
|
|||
include_subfolders: bool = True,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int]:
|
||||
"""Full scan indexing of a folder."""
|
||||
) -> tuple[int, int, int]:
|
||||
"""Full scan indexing of a folder.
|
||||
|
||||
Returns (indexed, skipped, unsupported_count).
|
||||
"""
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Starting full scan of folder: {folder_name}",
|
||||
|
|
@ -358,8 +390,15 @@ async def _index_full_scan(
|
|||
},
|
||||
)
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
files_to_download: list[dict] = []
|
||||
|
||||
all_files, error = await get_files_in_folder(
|
||||
|
|
@ -378,11 +417,28 @@ async def _index_full_scan(
|
|||
for file in all_files[:max_files]:
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during OneDrive full scan, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, failed = await _download_and_index(
|
||||
|
|
@ -396,11 +452,20 @@ async def _index_full_scan(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, "
|
||||
f"{unsupported_count} unsupported, {failed} failed"
|
||||
)
|
||||
return indexed, skipped
|
||||
return indexed, skipped, unsupported_count
|
||||
|
||||
|
||||
async def _index_with_delta_sync(
|
||||
|
|
@ -416,8 +481,11 @@ async def _index_with_delta_sync(
|
|||
max_files: int,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int, str | None]:
|
||||
"""Delta sync using OneDrive change tracking. Returns (indexed, skipped, new_delta_link)."""
|
||||
) -> tuple[int, int, int, str | None]:
|
||||
"""Delta sync using OneDrive change tracking.
|
||||
|
||||
Returns (indexed, skipped, unsupported_count, new_delta_link).
|
||||
"""
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
"Starting delta sync",
|
||||
|
|
@ -437,12 +505,19 @@ async def _index_with_delta_sync(
|
|||
|
||||
if not changes:
|
||||
logger.info("No changes detected since last sync")
|
||||
return 0, 0, new_delta_link
|
||||
return 0, 0, 0, new_delta_link
|
||||
|
||||
logger.info(f"Processing {len(changes)} delta changes")
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
files_to_download: list[dict] = []
|
||||
files_processed = 0
|
||||
|
||||
|
|
@ -465,12 +540,28 @@ async def _index_with_delta_sync(
|
|||
|
||||
skip, msg = await _should_skip_file(session, change, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
change.get("name", ""), change.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during OneDrive delta sync, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(change)
|
||||
|
||||
batch_indexed, failed = await _download_and_index(
|
||||
|
|
@ -484,11 +575,20 @@ async def _index_with_delta_sync(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
f"Delta sync complete: {indexed} indexed, {skipped} skipped, "
|
||||
f"{unsupported_count} unsupported, {failed} failed"
|
||||
)
|
||||
return indexed, skipped, new_delta_link
|
||||
return indexed, skipped, unsupported_count, new_delta_link
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -502,7 +602,7 @@ async def index_onedrive_files(
|
|||
search_space_id: int,
|
||||
user_id: str,
|
||||
items_dict: dict,
|
||||
) -> tuple[int, int, str | None]:
|
||||
) -> tuple[int, int, str | None, int]:
|
||||
"""Index OneDrive files for a specific connector.
|
||||
|
||||
items_dict format:
|
||||
|
|
@ -529,7 +629,7 @@ async def index_onedrive_files(
|
|||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
return 0, 0, error_msg, 0
|
||||
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted and not config.SECRET_KEY:
|
||||
|
|
@ -540,7 +640,7 @@ async def index_onedrive_files(
|
|||
"Missing SECRET_KEY",
|
||||
{"error_type": "MissingSecretKey"},
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
return 0, 0, error_msg, 0
|
||||
|
||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||
onedrive_client = OneDriveClient(session, connector_id)
|
||||
|
|
@ -552,12 +652,13 @@ async def index_onedrive_files(
|
|||
|
||||
total_indexed = 0
|
||||
total_skipped = 0
|
||||
total_unsupported = 0
|
||||
|
||||
# Index selected individual files
|
||||
selected_files = items_dict.get("files", [])
|
||||
if selected_files:
|
||||
file_tuples = [(f["id"], f.get("name")) for f in selected_files]
|
||||
indexed, skipped, _errors = await _index_selected_files(
|
||||
indexed, skipped, unsupported, _errors = await _index_selected_files(
|
||||
onedrive_client,
|
||||
session,
|
||||
file_tuples,
|
||||
|
|
@ -568,6 +669,7 @@ async def index_onedrive_files(
|
|||
)
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
total_unsupported += unsupported
|
||||
|
||||
# Index selected folders
|
||||
folders = items_dict.get("folders", [])
|
||||
|
|
@ -581,7 +683,7 @@ async def index_onedrive_files(
|
|||
|
||||
if can_use_delta:
|
||||
logger.info(f"Using delta sync for folder {folder_name}")
|
||||
indexed, skipped, new_delta_link = await _index_with_delta_sync(
|
||||
indexed, skipped, unsup, new_delta_link = await _index_with_delta_sync(
|
||||
onedrive_client,
|
||||
session,
|
||||
connector_id,
|
||||
|
|
@ -596,6 +698,7 @@ async def index_onedrive_files(
|
|||
)
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
total_unsupported += unsup
|
||||
|
||||
if new_delta_link:
|
||||
await session.refresh(connector)
|
||||
|
|
@ -605,7 +708,7 @@ async def index_onedrive_files(
|
|||
flag_modified(connector, "config")
|
||||
|
||||
# Reconciliation full scan
|
||||
ri, rs = await _index_full_scan(
|
||||
ri, rs, ru = await _index_full_scan(
|
||||
onedrive_client,
|
||||
session,
|
||||
connector_id,
|
||||
|
|
@ -621,9 +724,10 @@ async def index_onedrive_files(
|
|||
)
|
||||
total_indexed += ri
|
||||
total_skipped += rs
|
||||
total_unsupported += ru
|
||||
else:
|
||||
logger.info(f"Using full scan for folder {folder_name}")
|
||||
indexed, skipped = await _index_full_scan(
|
||||
indexed, skipped, unsup = await _index_full_scan(
|
||||
onedrive_client,
|
||||
session,
|
||||
connector_id,
|
||||
|
|
@ -639,6 +743,7 @@ async def index_onedrive_files(
|
|||
)
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
total_unsupported += unsup
|
||||
|
||||
# Store new delta link for this folder
|
||||
_, new_delta_link, _ = await onedrive_client.get_delta(folder_id=folder_id)
|
||||
|
|
@ -657,12 +762,18 @@ async def index_onedrive_files(
|
|||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed OneDrive indexing for connector {connector_id}",
|
||||
{"files_processed": total_indexed, "files_skipped": total_skipped},
|
||||
{
|
||||
"files_processed": total_indexed,
|
||||
"files_skipped": total_skipped,
|
||||
"files_unsupported": total_unsupported,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"OneDrive indexing completed: {total_indexed} indexed, {total_skipped} skipped"
|
||||
f"OneDrive indexing completed: {total_indexed} indexed, "
|
||||
f"{total_skipped} skipped, {total_unsupported} unsupported"
|
||||
)
|
||||
return total_indexed, total_skipped, None
|
||||
|
||||
return total_indexed, total_skipped, None, total_unsupported
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
@ -673,7 +784,7 @@ async def index_onedrive_files(
|
|||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
return 0, 0, f"Database error: {db_error!s}", 0
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -683,4 +794,4 @@ async def index_onedrive_files(
|
|||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index OneDrive files: {e!s}", exc_info=True)
|
||||
return 0, 0, f"Failed to index OneDrive files: {e!s}"
|
||||
return 0, 0, f"Failed to index OneDrive files: {e!s}", 0
|
||||
|
|
|
|||
|
|
@@ -1,41 +1,17 @@
"""
Document processors module for background tasks.

This module provides a collection of document processors for different content types
and sources. Each processor is responsible for handling a specific type of document
processing task in the background.

Available processors:
- Extension processor: Handle documents from browser extension
- Markdown processor: Process markdown files
- File processors: Handle files using different ETL services (Unstructured, LlamaCloud, Docling)
- YouTube processor: Process YouTube videos and extract transcripts
Content extraction is handled by ``app.etl_pipeline.EtlPipelineService``.
This package keeps orchestration (save, notify, page-limit) and
non-ETL processors (extension, markdown, youtube).
"""

# Extension processor
# File processors (backward-compatible re-exports from _save)
from ._save import (
    add_received_file_document_using_docling,
    add_received_file_document_using_llamacloud,
    add_received_file_document_using_unstructured,
)
from .extension_processor import add_extension_received_document

# Markdown processor
from .markdown_processor import add_received_markdown_file_document

# YouTube processor
from .youtube_processor import add_youtube_video_document

__all__ = [
    # Extension processing
    "add_extension_received_document",
    # File processing with different ETL services
    "add_received_file_document_using_docling",
    "add_received_file_document_using_llamacloud",
    "add_received_file_document_using_unstructured",
    # Markdown file processing
    "add_received_markdown_file_document",
    # YouTube video processing
    "add_youtube_video_document",
]
|
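The `_save` re-exports above keep the old import paths working after the refactor. A small sketch; the absolute package path is an assumption for illustration, since only relative module names appear in this file:

```python
# Both imports resolve to the same callable after this change.
# The absolute package path below is assumed, not taken from the diff.
from app.tasks.background_tasks.document_processors import (
    add_received_file_document_using_docling,
)
from app.tasks.background_tasks.document_processors._save import (
    add_received_file_document_using_docling as docling_direct,
)

assert add_received_file_document_using_docling is docling_direct
```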
|
|
|||
|
|
@@ -1,74 +0,0 @@
"""
Constants for file document processing.

Centralizes file type classification, LlamaCloud retry configuration,
and timeout calculation parameters.
"""

import ssl
from enum import Enum

import httpx

# ---------------------------------------------------------------------------
# File type classification
# ---------------------------------------------------------------------------

MARKDOWN_EXTENSIONS = (".md", ".markdown", ".txt")
AUDIO_EXTENSIONS = (".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")
DIRECT_CONVERT_EXTENSIONS = (".csv", ".tsv", ".html", ".htm")


class FileCategory(Enum):
    MARKDOWN = "markdown"
    AUDIO = "audio"
    DIRECT_CONVERT = "direct_convert"
    DOCUMENT = "document"


def classify_file(filename: str) -> FileCategory:
    """Classify a file by its extension into a processing category."""
    lower = filename.lower()
    if lower.endswith(MARKDOWN_EXTENSIONS):
        return FileCategory.MARKDOWN
    if lower.endswith(AUDIO_EXTENSIONS):
        return FileCategory.AUDIO
    if lower.endswith(DIRECT_CONVERT_EXTENSIONS):
        return FileCategory.DIRECT_CONVERT
    return FileCategory.DOCUMENT


# ---------------------------------------------------------------------------
# LlamaCloud retry configuration
# ---------------------------------------------------------------------------

LLAMACLOUD_MAX_RETRIES = 5
LLAMACLOUD_BASE_DELAY = 10  # seconds (exponential backoff base)
LLAMACLOUD_MAX_DELAY = 120  # max delay between retries (2 minutes)
LLAMACLOUD_RETRYABLE_EXCEPTIONS = (
    ssl.SSLError,
    httpx.ConnectError,
    httpx.ConnectTimeout,
    httpx.ReadError,
    httpx.ReadTimeout,
    httpx.WriteError,
    httpx.WriteTimeout,
    httpx.RemoteProtocolError,
    httpx.LocalProtocolError,
    ConnectionError,
    ConnectionResetError,
    TimeoutError,
    OSError,
)

# ---------------------------------------------------------------------------
# Timeout calculation constants
# ---------------------------------------------------------------------------

UPLOAD_BYTES_PER_SECOND_SLOW = (
    100 * 1024
)  # 100 KB/s (conservative for slow connections)
MIN_UPLOAD_TIMEOUT = 120  # Minimum 2 minutes for any file
MAX_UPLOAD_TIMEOUT = 1800  # Maximum 30 minutes for very large files
BASE_JOB_TIMEOUT = 600  # 10 minutes base for job processing
PER_PAGE_JOB_TIMEOUT = 60  # 1 minute per page for processing
||||
|
|
@@ -4,8 +4,8 @@ Lossless file-to-markdown converters for text-based formats.
These converters handle file types that can be faithfully represented as
markdown without any external ETL/OCR service:

- CSV / TSV → markdown table (stdlib ``csv``)
- HTML / HTM → markdown (``markdownify``)
- CSV / TSV → markdown table (stdlib ``csv``)
- HTML / HTM / XHTML → markdown (``markdownify``)
"""

from __future__ import annotations

@@ -73,6 +73,7 @@ _CONVERTER_MAP: dict[str, Callable[..., str]] = {
    ".tsv": tsv_to_markdown,
    ".html": html_to_markdown,
    ".htm": html_to_markdown,
    ".xhtml": html_to_markdown,
}
||||
|
||||
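With `.xhtml` added, `_CONVERTER_MAP` now routes csv/tsv and html/htm/xhtml files to lossless converters. A minimal dispatch sketch for use inside this module; the converter call shape is assumed (the map is typed `Callable[..., str]`), and only the extension keys come from the hunk above:

```python
from pathlib import Path


def convert_if_supported(file_path: str) -> str | None:
    """Return markdown for directly convertible files, or None to fall back to ETL (sketch)."""
    ext = Path(file_path).suffix.lower()
    converter = _CONVERTER_MAP.get(ext)
    if converter is None:
        return None
    # Assumed call shape; the real converters may accept extra keyword arguments.
    return converter(file_path)
```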
|
||||
|
|
|
|||
|
|
@ -1,209 +0,0 @@
|
|||
"""
|
||||
ETL parsing strategies for different document processing services.
|
||||
|
||||
Provides parse functions for Unstructured, LlamaCloud, and Docling, along with
|
||||
LlamaCloud retry logic and dynamic timeout calculations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from logging import ERROR, getLogger
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import Log
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
|
||||
from ._constants import (
|
||||
LLAMACLOUD_BASE_DELAY,
|
||||
LLAMACLOUD_MAX_DELAY,
|
||||
LLAMACLOUD_MAX_RETRIES,
|
||||
LLAMACLOUD_RETRYABLE_EXCEPTIONS,
|
||||
PER_PAGE_JOB_TIMEOUT,
|
||||
)
|
||||
from ._helpers import calculate_job_timeout, calculate_upload_timeout
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LlamaCloud parsing with retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def parse_with_llamacloud_retry(
|
||||
file_path: str,
|
||||
estimated_pages: int,
|
||||
task_logger: TaskLoggingService | None = None,
|
||||
log_entry: Log | None = None,
|
||||
):
|
||||
"""
|
||||
Parse a file with LlamaCloud with retry logic for transient SSL/connection errors.
|
||||
|
||||
Uses dynamic timeout calculations based on file size and page count to handle
|
||||
very large files reliably.
|
||||
|
||||
Returns:
|
||||
LlamaParse result object
|
||||
|
||||
Raises:
|
||||
Exception: If all retries fail
|
||||
"""
|
||||
from llama_cloud_services import LlamaParse
|
||||
from llama_cloud_services.parse.utils import ResultType
|
||||
|
||||
file_size_bytes = os.path.getsize(file_path)
|
||||
file_size_mb = file_size_bytes / (1024 * 1024)
|
||||
|
||||
upload_timeout = calculate_upload_timeout(file_size_bytes)
|
||||
job_timeout = calculate_job_timeout(estimated_pages, file_size_bytes)
|
||||
|
||||
custom_timeout = httpx.Timeout(
|
||||
connect=120.0,
|
||||
read=upload_timeout,
|
||||
write=upload_timeout,
|
||||
pool=120.0,
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"LlamaCloud upload configured: file_size={file_size_mb:.1f}MB, "
|
||||
f"pages={estimated_pages}, upload_timeout={upload_timeout:.0f}s, "
|
||||
f"job_timeout={job_timeout:.0f}s"
|
||||
)
|
||||
|
||||
last_exception = None
|
||||
attempt_errors: list[str] = []
|
||||
|
||||
for attempt in range(1, LLAMACLOUD_MAX_RETRIES + 1):
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=custom_timeout) as custom_client:
|
||||
parser = LlamaParse(
|
||||
api_key=app_config.LLAMA_CLOUD_API_KEY,
|
||||
num_workers=1,
|
||||
verbose=True,
|
||||
language="en",
|
||||
result_type=ResultType.MD,
|
||||
max_timeout=int(max(2000, job_timeout + upload_timeout)),
|
||||
job_timeout_in_seconds=job_timeout,
|
||||
job_timeout_extra_time_per_page_in_seconds=PER_PAGE_JOB_TIMEOUT,
|
||||
custom_client=custom_client,
|
||||
)
|
||||
result = await parser.aparse(file_path)
|
||||
|
||||
if attempt > 1:
|
||||
logging.info(
|
||||
f"LlamaCloud upload succeeded on attempt {attempt} after "
|
||||
f"{len(attempt_errors)} failures"
|
||||
)
|
||||
return result
|
||||
|
||||
except LLAMACLOUD_RETRYABLE_EXCEPTIONS as e:
|
||||
last_exception = e
|
||||
error_type = type(e).__name__
|
||||
error_msg = str(e)[:200]
|
||||
attempt_errors.append(f"Attempt {attempt}: {error_type} - {error_msg}")
|
||||
|
||||
if attempt < LLAMACLOUD_MAX_RETRIES:
|
||||
base_delay = min(
|
||||
LLAMACLOUD_BASE_DELAY * (2 ** (attempt - 1)),
|
||||
LLAMACLOUD_MAX_DELAY,
|
||||
)
|
||||
jitter = base_delay * 0.25 * (2 * random.random() - 1)
|
||||
delay = base_delay + jitter
|
||||
|
||||
if task_logger and log_entry:
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"LlamaCloud upload failed "
|
||||
f"(attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}), "
|
||||
f"retrying in {delay:.0f}s",
|
||||
{
|
||||
"error_type": error_type,
|
||||
"error_message": error_msg,
|
||||
"attempt": attempt,
|
||||
"retry_delay": delay,
|
||||
"file_size_mb": round(file_size_mb, 1),
|
||||
"upload_timeout": upload_timeout,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
f"LlamaCloud upload failed "
|
||||
f"(attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}): "
|
||||
f"{error_type}. File: {file_size_mb:.1f}MB. "
|
||||
f"Retrying in {delay:.0f}s..."
|
||||
)
|
||||
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
logging.error(
|
||||
f"LlamaCloud upload failed after {LLAMACLOUD_MAX_RETRIES} "
|
||||
f"attempts. File size: {file_size_mb:.1f}MB, "
|
||||
f"Pages: {estimated_pages}. "
|
||||
f"Errors: {'; '.join(attempt_errors)}"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
raise last_exception or RuntimeError(
|
||||
f"LlamaCloud parsing failed after {LLAMACLOUD_MAX_RETRIES} retries. "
|
||||
f"File size: {file_size_mb:.1f}MB"
|
||||
)
|
||||
|
||||
|
||||
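For reference, the retry schedule of this (now removed) helper: with a 10 s base, a 120 s cap and plus-or-minus 25 % jitter, sleeps occur only after attempts 1 to 4, at roughly 10, 20, 40 and 80 seconds; the fifth failure raises. A standalone sketch of the same calculation:

```python
import random

LLAMACLOUD_MAX_RETRIES = 5
LLAMACLOUD_BASE_DELAY = 10   # seconds
LLAMACLOUD_MAX_DELAY = 120   # seconds

# A sleep happens only after attempts 1..4; the fifth failure raises instead.
for attempt in range(1, LLAMACLOUD_MAX_RETRIES):
    base_delay = min(LLAMACLOUD_BASE_DELAY * (2 ** (attempt - 1)), LLAMACLOUD_MAX_DELAY)
    jitter = base_delay * 0.25 * (2 * random.random() - 1)  # uniform within +/-25%
    print(f"after attempt {attempt}: sleep ~{base_delay}s (with jitter {base_delay + jitter:.1f}s)")
# Base sleeps: 10, 20, 40, 80 seconds; the 120 s cap only bites for longer schedules.
```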
# ---------------------------------------------------------------------------
|
||||
# Per-service parse functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def parse_with_unstructured(file_path: str):
|
||||
"""
|
||||
Parse a file using the Unstructured ETL service.
|
||||
|
||||
Returns:
|
||||
List of LangChain Document elements.
|
||||
"""
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
|
||||
loader = UnstructuredLoader(
|
||||
file_path,
|
||||
mode="elements",
|
||||
post_processors=[],
|
||||
languages=["eng"],
|
||||
include_orig_elements=False,
|
||||
include_metadata=False,
|
||||
strategy="auto",
|
||||
)
|
||||
return await loader.aload()
|
||||
|
||||
|
||||
async def parse_with_docling(file_path: str, filename: str) -> str:
|
||||
"""
|
||||
Parse a file using the Docling ETL service (via the Docling service wrapper).
|
||||
|
||||
Returns:
|
||||
Markdown content string.
|
||||
"""
|
||||
from app.services.docling_service import create_docling_service
|
||||
|
||||
docling_service = create_docling_service()
|
||||
|
||||
pdfminer_logger = getLogger("pdfminer")
|
||||
original_level = pdfminer_logger.level
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="pdfminer")
|
||||
warnings.filterwarnings(
|
||||
"ignore", message=".*Cannot set gray non-stroke color.*"
|
||||
)
|
||||
warnings.filterwarnings("ignore", message=".*invalid float value.*")
|
||||
pdfminer_logger.setLevel(ERROR)
|
||||
|
||||
try:
|
||||
result = await docling_service.process_document(file_path, filename)
|
||||
finally:
|
||||
pdfminer_logger.setLevel(original_level)
|
||||
|
||||
return result["content"]
|
||||
|
|
@@ -11,13 +11,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentStatus, DocumentType
from app.utils.document_converters import generate_unique_identifier_hash

from ._constants import (
    BASE_JOB_TIMEOUT,
    MAX_UPLOAD_TIMEOUT,
    MIN_UPLOAD_TIMEOUT,
    PER_PAGE_JOB_TIMEOUT,
    UPLOAD_BYTES_PER_SECOND_SLOW,
)
from .base import (
    check_document_by_unique_identifier,
    check_duplicate_document,
@@ -198,21 +191,3 @@ async def update_document_from_connector(
    if "connector_id" in connector:
        document.connector_id = connector["connector_id"]
    await session.commit()


# ---------------------------------------------------------------------------
# Timeout calculations
# ---------------------------------------------------------------------------


def calculate_upload_timeout(file_size_bytes: int) -> float:
    """Calculate upload timeout based on file size (conservative for slow connections)."""
    estimated_time = (file_size_bytes / UPLOAD_BYTES_PER_SECOND_SLOW) * 1.5
    return max(MIN_UPLOAD_TIMEOUT, min(estimated_time, MAX_UPLOAD_TIMEOUT))


def calculate_job_timeout(estimated_pages: int, file_size_bytes: int) -> float:
    """Calculate job processing timeout based on page count and file size."""
    page_based_timeout = BASE_JOB_TIMEOUT + (estimated_pages * PER_PAGE_JOB_TIMEOUT)
    size_based_timeout = BASE_JOB_TIMEOUT + (file_size_bytes / (10 * 1024 * 1024)) * 60
    return max(page_based_timeout, size_based_timeout)
||||
|
|
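Worked through with the constants from `_constants`, the two helpers removed above give, for an illustrative 50 MiB file estimated at 40 pages, an upload timeout of 768 s and a job timeout of 3000 s (the page-based term dominates). A quick standalone check with assumed inputs:

```python
UPLOAD_BYTES_PER_SECOND_SLOW = 100 * 1024
MIN_UPLOAD_TIMEOUT, MAX_UPLOAD_TIMEOUT = 120, 1800
BASE_JOB_TIMEOUT, PER_PAGE_JOB_TIMEOUT = 600, 60

file_size_bytes = 50 * 1024 * 1024  # illustrative 50 MiB upload
estimated_pages = 40                # illustrative page estimate

upload = max(
    MIN_UPLOAD_TIMEOUT,
    min((file_size_bytes / UPLOAD_BYTES_PER_SECOND_SLOW) * 1.5, MAX_UPLOAD_TIMEOUT),
)
job = max(
    BASE_JOB_TIMEOUT + estimated_pages * PER_PAGE_JOB_TIMEOUT,
    BASE_JOB_TIMEOUT + (file_size_bytes / (10 * 1024 * 1024)) * 60,
)
print(upload, job)  # 768.0 and 3000
```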
|
|||
|
|
@@ -1,14 +1,9 @@
"""
Unified document save/update logic for file processors.

Replaces the three nearly-identical ``add_received_file_document_using_*``
functions with a single ``save_file_document`` function plus thin wrappers
for backward compatibility.
"""

import logging

from langchain_core.documents import Document as LangChainDocument
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
||||
|
||||
|
|
@ -207,79 +202,3 @@ async def save_file_document(
|
|||
raise RuntimeError(
|
||||
f"Failed to process file document using {etl_service}: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backward-compatible wrapper functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def add_received_file_document_using_unstructured(
|
||||
session: AsyncSession,
|
||||
file_name: str,
|
||||
unstructured_processed_elements: list[LangChainDocument],
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
connector: dict | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> Document | None:
|
||||
"""Process and store a file document using the Unstructured service."""
|
||||
from app.utils.document_converters import convert_document_to_markdown
|
||||
|
||||
markdown_content = await convert_document_to_markdown(
|
||||
unstructured_processed_elements
|
||||
)
|
||||
return await save_file_document(
|
||||
session,
|
||||
file_name,
|
||||
markdown_content,
|
||||
search_space_id,
|
||||
user_id,
|
||||
"UNSTRUCTURED",
|
||||
connector,
|
||||
enable_summary,
|
||||
)
|
||||
|
||||
|
||||
async def add_received_file_document_using_llamacloud(
|
||||
session: AsyncSession,
|
||||
file_name: str,
|
||||
llamacloud_markdown_document: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
connector: dict | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> Document | None:
|
||||
"""Process and store document content parsed by LlamaCloud."""
|
||||
return await save_file_document(
|
||||
session,
|
||||
file_name,
|
||||
llamacloud_markdown_document,
|
||||
search_space_id,
|
||||
user_id,
|
||||
"LLAMACLOUD",
|
||||
connector,
|
||||
enable_summary,
|
||||
)
|
||||
|
||||
|
||||
async def add_received_file_document_using_docling(
|
||||
session: AsyncSession,
|
||||
file_name: str,
|
||||
docling_markdown_document: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
connector: dict | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> Document | None:
|
||||
"""Process and store document content parsed by Docling."""
|
||||
return await save_file_document(
|
||||
session,
|
||||
file_name,
|
||||
docling_markdown_document,
|
||||
search_space_id,
|
||||
user_id,
|
||||
"DOCLING",
|
||||
connector,
|
||||
enable_summary,
|
||||
)
|
||||
|
|
|
|||
|
|
@@ -1,14 +1,8 @@
"""
File document processors orchestrating content extraction and indexing.

This module is the public entry point for file processing. It delegates to
specialised sub-modules that each own a single concern:

- ``_constants`` — file type classification and configuration constants
- ``_helpers`` — document deduplication, migration, connector helpers
- ``_direct_converters`` — lossless file-to-markdown for csv/tsv/html
- ``_etl`` — ETL parsing strategies (Unstructured, LlamaCloud, Docling)
- ``_save`` — unified document creation / update logic
Delegates content extraction to ``app.etl_pipeline.EtlPipelineService`` and
keeps only orchestration concerns (notifications, logging, page limits, saving).
"""

from __future__ import annotations
@@ -17,38 +11,19 @@ import contextlib
import logging
import os
from dataclasses import dataclass, field
from logging import ERROR, getLogger

from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import config as app_config
from app.db import Document, Log, Notification
from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService

from ._constants import FileCategory, classify_file
from ._direct_converters import convert_file_directly
from ._etl import (
    parse_with_docling,
    parse_with_llamacloud_retry,
    parse_with_unstructured,
)
from ._helpers import update_document_from_connector
from ._save import (
    add_received_file_document_using_docling,
    add_received_file_document_using_llamacloud,
    add_received_file_document_using_unstructured,
    save_file_document,
)
from ._save import save_file_document
from .markdown_processor import add_received_markdown_file_document

# Re-export public API so existing ``from file_processors import …`` keeps working.
__all__ = [
    "add_received_file_document_using_docling",
    "add_received_file_document_using_llamacloud",
    "add_received_file_document_using_unstructured",
    "parse_with_llamacloud_retry",
    "process_file_in_background",
    "process_file_in_background_with_document",
    "save_file_document",
||||
|
|
@ -142,35 +117,31 @@ async def _log_page_divergence(
|
|||
# ===================================================================
|
||||
|
||||
|
||||
async def _process_markdown_upload(ctx: _ProcessingContext) -> Document | None:
|
||||
"""Read a markdown / text file and create or update a document."""
|
||||
await _notify(ctx, "parsing", "Reading file")
|
||||
async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | None:
|
||||
"""Extract content from a non-document file (plaintext/direct_convert/audio) via the unified ETL pipeline."""
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
await _notify(ctx, "parsing", "Processing file")
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Processing markdown/text file: {ctx.filename}",
|
||||
{"file_type": "markdown", "processing_stage": "reading_file"},
|
||||
f"Processing file: {ctx.filename}",
|
||||
{"processing_stage": "extracting"},
|
||||
)
|
||||
|
||||
with open(ctx.file_path, encoding="utf-8") as f:
|
||||
markdown_content = f.read()
|
||||
etl_result = await EtlPipelineService().extract(
|
||||
EtlRequest(file_path=ctx.file_path, filename=ctx.filename)
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(ctx.file_path)
|
||||
|
||||
await _notify(ctx, "chunking")
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Creating document from markdown content: {ctx.filename}",
|
||||
{
|
||||
"processing_stage": "creating_document",
|
||||
"content_length": len(markdown_content),
|
||||
},
|
||||
)
|
||||
|
||||
result = await add_received_markdown_file_document(
|
||||
ctx.session,
|
||||
ctx.filename,
|
||||
markdown_content,
|
||||
etl_result.markdown_content,
|
||||
ctx.search_space_id,
|
||||
ctx.user_id,
|
||||
ctx.connector,
|
||||
|
|
@ -181,179 +152,19 @@ async def _process_markdown_upload(ctx: _ProcessingContext) -> Document | None:
|
|||
if result:
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Successfully processed markdown file: {ctx.filename}",
|
||||
f"Successfully processed file: {ctx.filename}",
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "markdown",
|
||||
"file_type": etl_result.content_type,
|
||||
"etl_service": etl_result.etl_service,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Markdown file already exists (duplicate): {ctx.filename}",
|
||||
{"duplicate_detected": True, "file_type": "markdown"},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def _process_direct_convert_upload(ctx: _ProcessingContext) -> Document | None:
|
||||
"""Convert a text-based file (csv/tsv/html) to markdown without ETL."""
|
||||
await _notify(ctx, "parsing", "Converting file")
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Direct-converting file to markdown: {ctx.filename}",
|
||||
{"file_type": "direct_convert", "processing_stage": "converting"},
|
||||
)
|
||||
|
||||
markdown_content = convert_file_directly(ctx.file_path, ctx.filename)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(ctx.file_path)
|
||||
|
||||
await _notify(ctx, "chunking")
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Creating document from converted content: {ctx.filename}",
|
||||
{
|
||||
"processing_stage": "creating_document",
|
||||
"content_length": len(markdown_content),
|
||||
},
|
||||
)
|
||||
|
||||
result = await add_received_markdown_file_document(
|
||||
ctx.session,
|
||||
ctx.filename,
|
||||
markdown_content,
|
||||
ctx.search_space_id,
|
||||
ctx.user_id,
|
||||
ctx.connector,
|
||||
)
|
||||
if ctx.connector:
|
||||
await update_document_from_connector(result, ctx.connector, ctx.session)
|
||||
|
||||
if result:
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Successfully direct-converted file: {ctx.filename}",
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "direct_convert",
|
||||
},
|
||||
)
|
||||
else:
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Direct-converted file already exists (duplicate): {ctx.filename}",
|
||||
{"duplicate_detected": True, "file_type": "direct_convert"},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def _process_audio_upload(ctx: _ProcessingContext) -> Document | None:
|
||||
"""Transcribe an audio file and create or update a document."""
|
||||
await _notify(ctx, "parsing", "Transcribing audio")
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Processing audio file for transcription: {ctx.filename}",
|
||||
{"file_type": "audio", "processing_stage": "starting_transcription"},
|
||||
)
|
||||
|
||||
stt_service_type = (
|
||||
"local"
|
||||
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
|
||||
else "external"
|
||||
)
|
||||
|
||||
if stt_service_type == "local":
|
||||
from app.services.stt_service import stt_service
|
||||
|
||||
try:
|
||||
stt_result = stt_service.transcribe_file(ctx.file_path)
|
||||
transcribed_text = stt_result.get("text", "")
|
||||
if not transcribed_text:
|
||||
raise ValueError("Transcription returned empty text")
|
||||
transcribed_text = (
|
||||
f"# Transcription of {ctx.filename}\n\n{transcribed_text}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Failed to transcribe audio file {ctx.filename}: {e!s}",
|
||||
) from e
|
||||
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Local STT transcription completed: {ctx.filename}",
|
||||
{
|
||||
"processing_stage": "local_transcription_complete",
|
||||
"language": stt_result.get("language"),
|
||||
"confidence": stt_result.get("language_probability"),
|
||||
"duration": stt_result.get("duration"),
|
||||
},
|
||||
)
|
||||
else:
|
||||
from litellm import atranscription
|
||||
|
||||
with open(ctx.file_path, "rb") as audio_file:
|
||||
transcription_kwargs: dict = {
|
||||
"model": app_config.STT_SERVICE,
|
||||
"file": audio_file,
|
||||
"api_key": app_config.STT_SERVICE_API_KEY,
|
||||
}
|
||||
if app_config.STT_SERVICE_API_BASE:
|
||||
transcription_kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
|
||||
|
||||
transcription_response = await atranscription(**transcription_kwargs)
|
||||
transcribed_text = transcription_response.get("text", "")
|
||||
if not transcribed_text:
|
||||
raise ValueError("Transcription returned empty text")
|
||||
|
||||
transcribed_text = f"# Transcription of {ctx.filename}\n\n{transcribed_text}"
|
||||
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Transcription completed, creating document: {ctx.filename}",
|
||||
{
|
||||
"processing_stage": "transcription_complete",
|
||||
"transcript_length": len(transcribed_text),
|
||||
},
|
||||
)
|
||||
|
||||
await _notify(ctx, "chunking")
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(ctx.file_path)
|
||||
|
||||
result = await add_received_markdown_file_document(
|
||||
ctx.session,
|
||||
ctx.filename,
|
||||
transcribed_text,
|
||||
ctx.search_space_id,
|
||||
ctx.user_id,
|
||||
ctx.connector,
|
||||
)
|
||||
if ctx.connector:
|
||||
await update_document_from_connector(result, ctx.connector, ctx.session)
|
||||
|
||||
if result:
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Successfully transcribed and processed audio file: {ctx.filename}",
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "audio",
|
||||
"transcript_length": len(transcribed_text),
|
||||
"stt_service": stt_service_type,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Audio file transcript already exists (duplicate): {ctx.filename}",
|
||||
{"duplicate_detected": True, "file_type": "audio"},
|
||||
f"File already exists (duplicate): {ctx.filename}",
|
||||
{"duplicate_detected": True, "file_type": etl_result.content_type},
|
||||
)
|
||||
return result
|
||||
|
||||
|
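The non-document path above funnels everything through the same `EtlPipelineService.extract()` call used later for documents; from the fields referenced in this diff, the result exposes at least `markdown_content`, `content_type` and `etl_service`. A hedged usage sketch (field names taken from the diff, everything else about the service is assumed):

```python
from app.etl_pipeline.etl_document import EtlRequest
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService


async def extract_markdown(file_path: str, filename: str) -> str:
    """Minimal sketch of the unified ETL pipeline call (assumed defaults)."""
    etl_result = await EtlPipelineService().extract(
        EtlRequest(file_path=file_path, filename=filename)
    )
    # Fields observed in this diff: markdown_content, content_type, etl_service.
    print(f"extracted via {etl_result.etl_service} as {etl_result.content_type}")
    return etl_result.markdown_content
```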
|
@ -363,279 +174,10 @@ async def _process_audio_upload(ctx: _ProcessingContext) -> Document | None:
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _etl_unstructured(
|
||||
ctx: _ProcessingContext,
|
||||
page_limit_service,
|
||||
estimated_pages: int,
|
||||
) -> Document | None:
|
||||
"""Parse and save via the Unstructured ETL service."""
|
||||
await _notify(ctx, "parsing", "Extracting content")
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Processing file with Unstructured ETL: {ctx.filename}",
|
||||
{
|
||||
"file_type": "document",
|
||||
"etl_service": "UNSTRUCTURED",
|
||||
"processing_stage": "loading",
|
||||
},
|
||||
)
|
||||
|
||||
docs = await parse_with_unstructured(ctx.file_path)
|
||||
|
||||
await _notify(ctx, "chunking", chunks_count=len(docs))
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Unstructured ETL completed, creating document: {ctx.filename}",
|
||||
{"processing_stage": "etl_complete", "elements_count": len(docs)},
|
||||
)
|
||||
|
||||
actual_pages = page_limit_service.estimate_pages_from_elements(docs)
|
||||
final_pages = max(estimated_pages, actual_pages)
|
||||
await _log_page_divergence(
|
||||
ctx.task_logger,
|
||||
ctx.log_entry,
|
||||
ctx.filename,
|
||||
estimated_pages,
|
||||
actual_pages,
|
||||
final_pages,
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(ctx.file_path)
|
||||
|
||||
result = await add_received_file_document_using_unstructured(
|
||||
ctx.session,
|
||||
ctx.filename,
|
||||
docs,
|
||||
ctx.search_space_id,
|
||||
ctx.user_id,
|
||||
ctx.connector,
|
||||
enable_summary=ctx.enable_summary,
|
||||
)
|
||||
if ctx.connector:
|
||||
await update_document_from_connector(result, ctx.connector, ctx.session)
|
||||
|
||||
if result:
|
||||
await page_limit_service.update_page_usage(
|
||||
ctx.user_id, final_pages, allow_exceed=True
|
||||
)
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Successfully processed file with Unstructured: {ctx.filename}",
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "document",
|
||||
"etl_service": "UNSTRUCTURED",
|
||||
"pages_processed": final_pages,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Document already exists (duplicate): {ctx.filename}",
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"file_type": "document",
|
||||
"etl_service": "UNSTRUCTURED",
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def _etl_llamacloud(
|
||||
ctx: _ProcessingContext,
|
||||
page_limit_service,
|
||||
estimated_pages: int,
|
||||
) -> Document | None:
|
||||
"""Parse and save via the LlamaCloud ETL service."""
|
||||
await _notify(ctx, "parsing", "Extracting content")
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Processing file with LlamaCloud ETL: {ctx.filename}",
|
||||
{
|
||||
"file_type": "document",
|
||||
"etl_service": "LLAMACLOUD",
|
||||
"processing_stage": "parsing",
|
||||
"estimated_pages": estimated_pages,
|
||||
},
|
||||
)
|
||||
|
||||
raw_result = await parse_with_llamacloud_retry(
|
||||
file_path=ctx.file_path,
|
||||
estimated_pages=estimated_pages,
|
||||
task_logger=ctx.task_logger,
|
||||
log_entry=ctx.log_entry,
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(ctx.file_path)
|
||||
|
||||
markdown_documents = await raw_result.aget_markdown_documents(split_by_page=False)
|
||||
|
||||
await _notify(ctx, "chunking", chunks_count=len(markdown_documents))
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"LlamaCloud parsing completed, creating documents: {ctx.filename}",
|
||||
{
|
||||
"processing_stage": "parsing_complete",
|
||||
"documents_count": len(markdown_documents),
|
||||
},
|
||||
)
|
||||
|
||||
if not markdown_documents:
|
||||
await ctx.task_logger.log_task_failure(
|
||||
ctx.log_entry,
|
||||
f"LlamaCloud parsing returned no documents: {ctx.filename}",
|
||||
"ETL service returned empty document list",
|
||||
{"error_type": "EmptyDocumentList", "etl_service": "LLAMACLOUD"},
|
||||
)
|
||||
raise ValueError(f"LlamaCloud parsing returned no documents for {ctx.filename}")
|
||||
|
||||
actual_pages = page_limit_service.estimate_pages_from_markdown(markdown_documents)
|
||||
final_pages = max(estimated_pages, actual_pages)
|
||||
await _log_page_divergence(
|
||||
ctx.task_logger,
|
||||
ctx.log_entry,
|
||||
ctx.filename,
|
||||
estimated_pages,
|
||||
actual_pages,
|
||||
final_pages,
|
||||
)
|
||||
|
||||
any_created = False
|
||||
last_doc: Document | None = None
|
||||
|
||||
for doc in markdown_documents:
|
||||
doc_result = await add_received_file_document_using_llamacloud(
|
||||
ctx.session,
|
||||
ctx.filename,
|
||||
llamacloud_markdown_document=doc.text,
|
||||
search_space_id=ctx.search_space_id,
|
||||
user_id=ctx.user_id,
|
||||
connector=ctx.connector,
|
||||
enable_summary=ctx.enable_summary,
|
||||
)
|
||||
if doc_result:
|
||||
any_created = True
|
||||
last_doc = doc_result
|
||||
|
||||
if any_created:
|
||||
await page_limit_service.update_page_usage(
|
||||
ctx.user_id, final_pages, allow_exceed=True
|
||||
)
|
||||
if ctx.connector:
|
||||
await update_document_from_connector(last_doc, ctx.connector, ctx.session)
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Successfully processed file with LlamaCloud: {ctx.filename}",
|
||||
{
|
||||
"document_id": last_doc.id,
|
||||
"content_hash": last_doc.content_hash,
|
||||
"file_type": "document",
|
||||
"etl_service": "LLAMACLOUD",
|
||||
"pages_processed": final_pages,
|
||||
"documents_count": len(markdown_documents),
|
||||
},
|
||||
)
|
||||
return last_doc
|
||||
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Document already exists (duplicate): {ctx.filename}",
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"file_type": "document",
|
||||
"etl_service": "LLAMACLOUD",
|
||||
"documents_count": len(markdown_documents),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def _etl_docling(
|
||||
ctx: _ProcessingContext,
|
||||
page_limit_service,
|
||||
estimated_pages: int,
|
||||
) -> Document | None:
|
||||
"""Parse and save via the Docling ETL service."""
|
||||
await _notify(ctx, "parsing", "Extracting content")
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Processing file with Docling ETL: {ctx.filename}",
|
||||
{
|
||||
"file_type": "document",
|
||||
"etl_service": "DOCLING",
|
||||
"processing_stage": "parsing",
|
||||
},
|
||||
)
|
||||
|
||||
content = await parse_with_docling(ctx.file_path, ctx.filename)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(ctx.file_path)
|
||||
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Docling parsing completed, creating document: {ctx.filename}",
|
||||
{"processing_stage": "parsing_complete", "content_length": len(content)},
|
||||
)
|
||||
|
||||
actual_pages = page_limit_service.estimate_pages_from_content_length(len(content))
|
||||
final_pages = max(estimated_pages, actual_pages)
|
||||
await _log_page_divergence(
|
||||
ctx.task_logger,
|
||||
ctx.log_entry,
|
||||
ctx.filename,
|
||||
estimated_pages,
|
||||
actual_pages,
|
||||
final_pages,
|
||||
)
|
||||
|
||||
await _notify(ctx, "chunking")
|
||||
|
||||
result = await add_received_file_document_using_docling(
|
||||
ctx.session,
|
||||
ctx.filename,
|
||||
docling_markdown_document=content,
|
||||
search_space_id=ctx.search_space_id,
|
||||
user_id=ctx.user_id,
|
||||
connector=ctx.connector,
|
||||
enable_summary=ctx.enable_summary,
|
||||
)
|
||||
|
||||
if result:
|
||||
await page_limit_service.update_page_usage(
|
||||
ctx.user_id, final_pages, allow_exceed=True
|
||||
)
|
||||
if ctx.connector:
|
||||
await update_document_from_connector(result, ctx.connector, ctx.session)
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Successfully processed file with Docling: {ctx.filename}",
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "document",
|
||||
"etl_service": "DOCLING",
|
||||
"pages_processed": final_pages,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Document already exists (duplicate): {ctx.filename}",
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"file_type": "document",
|
||||
"etl_service": "DOCLING",
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
|
||||
"""Route a document file to the configured ETL service."""
|
||||
"""Route a document file to the configured ETL service via the unified pipeline."""
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
from app.services.page_limit_service import PageLimitExceededError, PageLimitService
|
||||
|
||||
page_limit_service = PageLimitService(ctx.session)
|
||||
|
|
@ -665,16 +207,60 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
|
|||
os.unlink(ctx.file_path)
|
||||
raise HTTPException(status_code=403, detail=str(e)) from e
|
||||
|
||||
etl_dispatch = {
|
||||
"UNSTRUCTURED": _etl_unstructured,
|
||||
"LLAMACLOUD": _etl_llamacloud,
|
||||
"DOCLING": _etl_docling,
|
||||
}
|
||||
handler = etl_dispatch.get(app_config.ETL_SERVICE)
|
||||
if handler is None:
|
||||
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
|
||||
await _notify(ctx, "parsing", "Extracting content")
|
||||
|
||||
return await handler(ctx, page_limit_service, estimated_pages)
|
||||
etl_result = await EtlPipelineService().extract(
|
||||
EtlRequest(
|
||||
file_path=ctx.file_path,
|
||||
filename=ctx.filename,
|
||||
estimated_pages=estimated_pages,
|
||||
)
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(ctx.file_path)
|
||||
|
||||
await _notify(ctx, "chunking")
|
||||
|
||||
result = await save_file_document(
|
||||
ctx.session,
|
||||
ctx.filename,
|
||||
etl_result.markdown_content,
|
||||
ctx.search_space_id,
|
||||
ctx.user_id,
|
||||
etl_result.etl_service,
|
||||
ctx.connector,
|
||||
enable_summary=ctx.enable_summary,
|
||||
)
|
||||
|
||||
if result:
|
||||
await page_limit_service.update_page_usage(
|
||||
ctx.user_id, estimated_pages, allow_exceed=True
|
||||
)
|
||||
if ctx.connector:
|
||||
await update_document_from_connector(result, ctx.connector, ctx.session)
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Successfully processed file: {ctx.filename}",
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "document",
|
||||
"etl_service": etl_result.etl_service,
|
||||
"pages_processed": estimated_pages,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Document already exists (duplicate): {ctx.filename}",
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"file_type": "document",
|
||||
"etl_service": etl_result.etl_service,
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# ===================================================================
|
||||
|
|
@ -706,15 +292,16 @@ async def process_file_in_background(
|
|||
)
|
||||
|
||||
try:
|
||||
category = classify_file(filename)
|
||||
from app.etl_pipeline.file_classifier import (
|
||||
FileCategory as EtlFileCategory,
|
||||
classify_file as etl_classify,
|
||||
)
|
||||
|
||||
if category == FileCategory.MARKDOWN:
|
||||
return await _process_markdown_upload(ctx)
|
||||
if category == FileCategory.DIRECT_CONVERT:
|
||||
return await _process_direct_convert_upload(ctx)
|
||||
if category == FileCategory.AUDIO:
|
||||
return await _process_audio_upload(ctx)
|
||||
return await _process_document_upload(ctx)
|
||||
category = etl_classify(filename)
|
||||
|
||||
if category == EtlFileCategory.DOCUMENT:
|
||||
return await _process_document_upload(ctx)
|
||||
return await _process_non_document_upload(ctx)
|
||||
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
|
|
@ -758,201 +345,64 @@ async def _extract_file_content(
|
|||
Returns:
|
||||
Tuple of (markdown_content, etl_service_name).
|
||||
"""
|
||||
category = classify_file(filename)
|
||||
|
||||
if category == FileCategory.MARKDOWN:
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Reading file",
|
||||
)
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing markdown/text file: {filename}",
|
||||
{"file_type": "markdown", "processing_stage": "reading_file"},
|
||||
)
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(file_path)
|
||||
return content, "MARKDOWN"
|
||||
|
||||
if category == FileCategory.DIRECT_CONVERT:
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Converting file",
|
||||
)
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Direct-converting file to markdown: {filename}",
|
||||
{"file_type": "direct_convert", "processing_stage": "converting"},
|
||||
)
|
||||
content = convert_file_directly(file_path, filename)
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(file_path)
|
||||
return content, "DIRECT_CONVERT"
|
||||
|
||||
if category == FileCategory.AUDIO:
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Transcribing audio",
|
||||
)
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing audio file for transcription: {filename}",
|
||||
{"file_type": "audio", "processing_stage": "starting_transcription"},
|
||||
)
|
||||
transcribed_text = await _transcribe_audio(file_path, filename)
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(file_path)
|
||||
return transcribed_text, "AUDIO_TRANSCRIPTION"
|
||||
|
||||
# Document file — use ETL service
|
||||
return await _extract_document_content(
|
||||
file_path,
|
||||
filename,
|
||||
session,
|
||||
user_id,
|
||||
task_logger,
|
||||
log_entry,
|
||||
notification,
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
from app.etl_pipeline.file_classifier import (
|
||||
FileCategory,
|
||||
classify_file as etl_classify,
|
||||
)
|
||||
|
||||
|
||||
async def _transcribe_audio(file_path: str, filename: str) -> str:
|
||||
"""Transcribe an audio file and return formatted markdown text."""
|
||||
stt_service_type = (
|
||||
"local"
|
||||
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
|
||||
else "external"
|
||||
)
|
||||
|
||||
if stt_service_type == "local":
|
||||
from app.services.stt_service import stt_service
|
||||
|
||||
result = stt_service.transcribe_file(file_path)
|
||||
text = result.get("text", "")
|
||||
if not text:
|
||||
raise ValueError("Transcription returned empty text")
|
||||
else:
|
||||
from litellm import atranscription
|
||||
|
||||
with open(file_path, "rb") as audio_file:
|
||||
kwargs: dict = {
|
||||
"model": app_config.STT_SERVICE,
|
||||
"file": audio_file,
|
||||
"api_key": app_config.STT_SERVICE_API_KEY,
|
||||
}
|
||||
if app_config.STT_SERVICE_API_BASE:
|
||||
kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
|
||||
response = await atranscription(**kwargs)
|
||||
text = response.get("text", "")
|
||||
if not text:
|
||||
raise ValueError("Transcription returned empty text")
|
||||
|
||||
return f"# Transcription of {filename}\n\n{text}"
|
||||
|
||||
|
||||
async def _extract_document_content(
|
||||
file_path: str,
|
||||
filename: str,
|
||||
session: AsyncSession,
|
||||
user_id: str,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry: Log,
|
||||
notification: Notification | None,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Parse a document file via the configured ETL service.
|
||||
|
||||
Returns:
|
||||
Tuple of (markdown_content, etl_service_name).
|
||||
"""
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
|
||||
try:
|
||||
estimated_pages = page_limit_service.estimate_pages_before_processing(file_path)
|
||||
except Exception:
|
||||
file_size = os.path.getsize(file_path)
|
||||
estimated_pages = max(1, file_size // (80 * 1024))
|
||||
|
||||
await page_limit_service.check_page_limit(user_id, estimated_pages)
|
||||
|
||||
etl_service = app_config.ETL_SERVICE
|
||||
markdown_content: str | None = None
|
||||
category = etl_classify(filename)
|
||||
estimated_pages = 0
|
||||
|
||||
if notification:
|
||||
stage_messages = {
|
||||
FileCategory.PLAINTEXT: "Reading file",
|
||||
FileCategory.DIRECT_CONVERT: "Converting file",
|
||||
FileCategory.AUDIO: "Transcribing audio",
|
||||
FileCategory.UNSUPPORTED: "Unsupported file type",
|
||||
FileCategory.DOCUMENT: "Extracting content",
|
||||
}
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Extracting content",
|
||||
stage_message=stage_messages.get(category, "Processing"),
|
||||
)
|
||||
|
||||
if etl_service == "UNSTRUCTURED":
|
||||
from app.utils.document_converters import convert_document_to_markdown
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing {category.value} file: {filename}",
|
||||
{"file_type": category.value, "processing_stage": "extracting"},
|
||||
)
|
||||
|
||||
docs = await parse_with_unstructured(file_path)
|
||||
markdown_content = await convert_document_to_markdown(docs)
|
||||
actual_pages = page_limit_service.estimate_pages_from_elements(docs)
|
||||
final_pages = max(estimated_pages, actual_pages)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, final_pages, allow_exceed=True
|
||||
)
|
||||
if category == FileCategory.DOCUMENT:
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
|
||||
elif etl_service == "LLAMACLOUD":
|
||||
raw_result = await parse_with_llamacloud_retry(
|
||||
page_limit_service = PageLimitService(session)
|
||||
estimated_pages = _estimate_pages_safe(page_limit_service, file_path)
|
||||
await page_limit_service.check_page_limit(user_id, estimated_pages)
|
||||
|
||||
result = await EtlPipelineService().extract(
|
||||
EtlRequest(
|
||||
file_path=file_path,
|
||||
filename=filename,
|
||||
estimated_pages=estimated_pages,
|
||||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
)
|
||||
markdown_documents = await raw_result.aget_markdown_documents(
|
||||
split_by_page=False
|
||||
)
|
||||
if not markdown_documents:
|
||||
raise RuntimeError(f"LlamaCloud parsing returned no documents: {filename}")
|
||||
markdown_content = markdown_documents[0].text
|
||||
)
|
||||
|
||||
if category == FileCategory.DOCUMENT:
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, estimated_pages, allow_exceed=True
|
||||
)
|
||||
|
||||
elif etl_service == "DOCLING":
|
||||
getLogger("docling.pipeline.base_pipeline").setLevel(ERROR)
|
||||
getLogger("docling.document_converter").setLevel(ERROR)
|
||||
getLogger("docling_core.transforms.chunker.hierarchical_chunker").setLevel(
|
||||
ERROR
|
||||
)
|
||||
|
||||
from docling.document_converter import DocumentConverter
|
||||
|
||||
converter = DocumentConverter()
|
||||
result = converter.convert(file_path)
|
||||
markdown_content = result.document.export_to_markdown()
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, estimated_pages, allow_exceed=True
|
||||
)
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Unknown ETL_SERVICE: {etl_service}")
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(file_path)
|
||||
|
||||
if not markdown_content:
|
||||
if not result.markdown_content:
|
||||
raise RuntimeError(f"Failed to extract content from file: {filename}")
|
||||
|
||||
return markdown_content, etl_service
|
||||
return result.markdown_content, result.etl_service
|
||||
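Illustrative only: a minimal sketch of how a background task might invoke the extraction helper above. Every surrounding name is a placeholder; only the signature and the (markdown, etl_service) return shape come from the code above.

# Hedged sketch (not in the diff): calling _extract_document_content from an async task.
# process_upload, session, user, task_logger and log_entry are placeholders for
# objects the real background task already has in scope.
async def process_upload(session, user, task_logger, log_entry):
    markdown, etl_service = await _extract_document_content(
        file_path="/tmp/uploads/report.pdf",
        filename="report.pdf",
        session=session,
        user_id=str(user.id),
        task_logger=task_logger,
        log_entry=log_entry,
        notification=None,  # progress notifications are optional
    )
    # markdown is the extracted content; etl_service names the parser that produced it.
    return markdown, etl_service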
|
||||
|
||||
async def process_file_in_background_with_document(
|
||||
|
|
|
|||
107 surfsense_backend/app/utils/document_versioning.py (new file)
|
|
@ -0,0 +1,107 @@
|
|||
"""Document versioning: snapshot creation and cleanup.
|
||||
|
||||
Rules:
|
||||
- 30-minute debounce window: if the latest version was created < 30 min ago,
|
||||
overwrite it instead of creating a new row.
|
||||
- Maximum 20 versions per document.
|
||||
- Versions older than 90 days are cleaned up.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentVersion
|
||||
|
||||
MAX_VERSIONS_PER_DOCUMENT = 20
|
||||
DEBOUNCE_MINUTES = 30
|
||||
RETENTION_DAYS = 90
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
async def create_version_snapshot(
|
||||
session: AsyncSession,
|
||||
document: Document,
|
||||
) -> DocumentVersion | None:
|
||||
"""Snapshot the document's current state into a DocumentVersion row.
|
||||
|
||||
Returns the created/updated DocumentVersion, or None if nothing was done.
|
||||
"""
|
||||
now = _now()
|
||||
|
||||
latest = (
|
||||
await session.execute(
|
||||
select(DocumentVersion)
|
||||
.where(DocumentVersion.document_id == document.id)
|
||||
.order_by(DocumentVersion.version_number.desc())
|
||||
.limit(1)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if latest is not None:
|
||||
age = now - latest.created_at.replace(tzinfo=UTC)
|
||||
if age < timedelta(minutes=DEBOUNCE_MINUTES):
|
||||
latest.source_markdown = document.source_markdown
|
||||
latest.content_hash = document.content_hash
|
||||
latest.title = document.title
|
||||
latest.created_at = now
|
||||
await session.flush()
|
||||
return latest
|
||||
|
||||
max_num = (
|
||||
await session.execute(
|
||||
select(func.coalesce(func.max(DocumentVersion.version_number), 0)).where(
|
||||
DocumentVersion.document_id == document.id
|
||||
)
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
version = DocumentVersion(
|
||||
document_id=document.id,
|
||||
version_number=max_num + 1,
|
||||
source_markdown=document.source_markdown,
|
||||
content_hash=document.content_hash,
|
||||
title=document.title,
|
||||
created_at=now,
|
||||
)
|
||||
session.add(version)
|
||||
await session.flush()
|
||||
|
||||
# Cleanup: remove versions older than 90 days
|
||||
cutoff = now - timedelta(days=RETENTION_DAYS)
|
||||
await session.execute(
|
||||
delete(DocumentVersion).where(
|
||||
DocumentVersion.document_id == document.id,
|
||||
DocumentVersion.created_at < cutoff,
|
||||
)
|
||||
)
|
||||
|
||||
# Cleanup: cap at MAX_VERSIONS_PER_DOCUMENT
|
||||
count = (
|
||||
await session.execute(
|
||||
select(func.count())
|
||||
.select_from(DocumentVersion)
|
||||
.where(DocumentVersion.document_id == document.id)
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
if count > MAX_VERSIONS_PER_DOCUMENT:
|
||||
excess = count - MAX_VERSIONS_PER_DOCUMENT
|
||||
oldest_ids_result = await session.execute(
|
||||
select(DocumentVersion.id)
|
||||
.where(DocumentVersion.document_id == document.id)
|
||||
.order_by(DocumentVersion.version_number.asc())
|
||||
.limit(excess)
|
||||
)
|
||||
oldest_ids = [row[0] for row in oldest_ids_result.all()]
|
||||
if oldest_ids:
|
||||
await session.execute(
|
||||
delete(DocumentVersion).where(DocumentVersion.id.in_(oldest_ids))
|
||||
)
|
||||
|
||||
await session.flush()
|
||||
return version
|
||||
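For orientation, a minimal usage sketch of the snapshot helper above, assuming a hypothetical editor save path; only create_version_snapshot and its debounce/cap/retention behaviour come from this file.

# Hedged sketch (not part of the diff): save_document_edit and its arguments are placeholders.
from app.utils.document_versioning import create_version_snapshot

async def save_document_edit(session, document, new_markdown: str) -> None:
    # Persist the edit on the document itself first.
    document.source_markdown = new_markdown
    # Within 30 minutes of the previous snapshot this overwrites it; otherwise a new
    # version row is added (capped at 20 per document, versions older than 90 days pruned).
    await create_version_snapshot(session, document)
    await session.commit()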
153 surfsense_backend/app/utils/file_extensions.py (new file)
|
|
@ -0,0 +1,153 @@
|
|||
"""Per-parser document extension sets for the ETL pipeline.
|
||||
|
||||
Every consumer (file_classifier, connector-level skip checks, ETL pipeline
|
||||
validation) imports from here so there is a single source of truth.
|
||||
|
||||
Extensions already covered by PLAINTEXT_EXTENSIONS, AUDIO_EXTENSIONS, or
|
||||
DIRECT_CONVERT_EXTENSIONS in file_classifier are NOT repeated here -- these
|
||||
sets are exclusively for the "document" ETL path (Docling / LlamaParse /
|
||||
Unstructured).
|
||||
"""
|
||||
|
||||
from pathlib import PurePosixPath
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-parser document extension sets (from official documentation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DOCLING_DOCUMENT_EXTENSIONS: frozenset[str] = frozenset(
|
||||
{
|
||||
".pdf",
|
||||
".docx",
|
||||
".xlsx",
|
||||
".pptx",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".tiff",
|
||||
".tif",
|
||||
".bmp",
|
||||
".webp",
|
||||
}
|
||||
)
|
||||
|
||||
LLAMAPARSE_DOCUMENT_EXTENSIONS: frozenset[str] = frozenset(
|
||||
{
|
||||
".pdf",
|
||||
".docx",
|
||||
".doc",
|
||||
".xlsx",
|
||||
".xls",
|
||||
".pptx",
|
||||
".ppt",
|
||||
".docm",
|
||||
".dot",
|
||||
".dotm",
|
||||
".pptm",
|
||||
".pot",
|
||||
".potx",
|
||||
".xlsm",
|
||||
".xlsb",
|
||||
".xlw",
|
||||
".rtf",
|
||||
".epub",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".tif",
|
||||
".webp",
|
||||
".svg",
|
||||
".odt",
|
||||
".ods",
|
||||
".odp",
|
||||
".hwp",
|
||||
".hwpx",
|
||||
}
|
||||
)
|
||||
|
||||
UNSTRUCTURED_DOCUMENT_EXTENSIONS: frozenset[str] = frozenset(
|
||||
{
|
||||
".pdf",
|
||||
".docx",
|
||||
".doc",
|
||||
".xlsx",
|
||||
".xls",
|
||||
".pptx",
|
||||
".ppt",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".tif",
|
||||
".heic",
|
||||
".rtf",
|
||||
".epub",
|
||||
".odt",
|
||||
".eml",
|
||||
".msg",
|
||||
".p7s",
|
||||
}
|
||||
)
|
||||
|
||||
AZURE_DI_DOCUMENT_EXTENSIONS: frozenset[str] = frozenset(
|
||||
{
|
||||
".pdf",
|
||||
".docx",
|
||||
".xlsx",
|
||||
".pptx",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".tif",
|
||||
".heif",
|
||||
}
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Union (used by classify_file for routing) + service lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DOCUMENT_EXTENSIONS: frozenset[str] = (
|
||||
DOCLING_DOCUMENT_EXTENSIONS
|
||||
| LLAMAPARSE_DOCUMENT_EXTENSIONS
|
||||
| UNSTRUCTURED_DOCUMENT_EXTENSIONS
|
||||
| AZURE_DI_DOCUMENT_EXTENSIONS
|
||||
)
|
||||
|
||||
_SERVICE_MAP: dict[str, frozenset[str]] = {
|
||||
"DOCLING": DOCLING_DOCUMENT_EXTENSIONS,
|
||||
"LLAMACLOUD": LLAMAPARSE_DOCUMENT_EXTENSIONS,
|
||||
"UNSTRUCTURED": UNSTRUCTURED_DOCUMENT_EXTENSIONS,
|
||||
}
|
||||
|
||||
|
||||
def get_document_extensions_for_service(etl_service: str | None) -> frozenset[str]:
|
||||
"""Return the document extensions supported by *etl_service*.
|
||||
|
||||
When *etl_service* is ``LLAMACLOUD`` and Azure Document Intelligence
|
||||
credentials are configured, the set is dynamically expanded to include
|
||||
Azure DI's supported extensions (e.g. ``.heif``).
|
||||
|
||||
Falls back to the full union when the service is ``None`` or unknown.
|
||||
"""
|
||||
extensions = _SERVICE_MAP.get(etl_service or "", DOCUMENT_EXTENSIONS)
|
||||
if etl_service == "LLAMACLOUD":
|
||||
from app.config import config as app_config
|
||||
|
||||
if getattr(app_config, "AZURE_DI_ENDPOINT", None) and getattr(
|
||||
app_config, "AZURE_DI_KEY", None
|
||||
):
|
||||
extensions = extensions | AZURE_DI_DOCUMENT_EXTENSIONS
|
||||
return extensions
|
||||
|
||||
|
||||
def is_supported_document_extension(filename: str) -> bool:
|
||||
"""Return True if the file's extension is in the supported document set."""
|
||||
suffix = PurePosixPath(filename).suffix.lower()
|
||||
return suffix in DOCUMENT_EXTENSIONS
|
||||
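A brief, hedged usage sketch of the lookup helpers above; the service names mirror the _SERVICE_MAP keys, and the printed values follow from the extension sets defined in this file.

# Hedged usage sketch of the helpers defined above.
from app.utils.file_extensions import (
    get_document_extensions_for_service,
    is_supported_document_extension,
)

docling_exts = get_document_extensions_for_service("DOCLING")
print(".pdf" in docling_exts)    # True
print(".heic" in docling_exts)   # False -- .heic is only in the Unstructured set

# Union check across every parser, e.g. for a connector-level skip decision:
print(is_supported_document_extension("notes.xyz"))  # False -> skip the file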
|
|
@ -11,6 +11,8 @@ import hmac
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from random import SystemRandom
|
||||
from string import ascii_letters, digits
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
|
@ -18,6 +20,25 @@ from fastapi import HTTPException
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PKCE_CHARS = ascii_letters + digits + "-._~"
|
||||
_PKCE_RNG = SystemRandom()
|
||||
|
||||
|
||||
def generate_code_verifier(length: int = 128) -> str:
|
||||
"""Generate a PKCE code_verifier (RFC 7636, 43-128 unreserved chars)."""
|
||||
return "".join(_PKCE_RNG.choice(_PKCE_CHARS) for _ in range(length))
|
||||
|
||||
|
||||
def generate_pkce_pair(length: int = 128) -> tuple[str, str]:
|
||||
"""Generate a PKCE code_verifier and its S256 code_challenge."""
|
||||
verifier = generate_code_verifier(length)
|
||||
challenge = (
|
||||
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest())
|
||||
.decode()
|
||||
.rstrip("=")
|
||||
)
|
||||
return verifier, challenge
|
||||
|
||||
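For reference, a hedged sketch of the server-side S256 check that pairs with generate_pkce_pair above; the verification helper itself is illustrative and not part of this diff.

# Hedged sketch: verifying a PKCE S256 challenge, mirroring generate_pkce_pair above
# (RFC 7636 section 4.6: BASE64URL(SHA256(code_verifier)) without padding).
import base64
import hashlib

def challenge_matches(code_verifier: str, code_challenge: str) -> bool:
    recomputed = (
        base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
        .decode()
        .rstrip("=")
    )
    return recomputed == code_challenge

# verifier, challenge = generate_pkce_pair()
# assert challenge_matches(verifier, challenge)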
|
||||
class OAuthStateManager:
|
||||
"""Manages secure OAuth state parameters with HMAC signatures."""
|
||||
|
|
|
|||
|
|
@ -46,8 +46,6 @@ dependencies = [
|
|||
"redis>=5.2.1",
|
||||
"firecrawl-py>=4.9.0",
|
||||
"boto3>=1.35.0",
|
||||
"litellm>=1.80.10",
|
||||
"langchain-litellm>=0.3.5",
|
||||
"fake-useragent>=2.2.0",
|
||||
"trafilatura>=2.0.0",
|
||||
"fastapi-users[oauth,sqlalchemy]>=15.0.3",
|
||||
|
|
@ -75,6 +73,9 @@ dependencies = [
|
|||
"langchain-community>=0.4.1",
|
||||
"deepagents>=0.4.12",
|
||||
"stripe>=15.0.0",
|
||||
"azure-ai-documentintelligence>=1.0.2",
|
||||
"litellm>=1.83.0",
|
||||
"langchain-litellm>=0.6.4",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
Prerequisites: PostgreSQL + pgvector only.
|
||||
|
||||
External system boundaries are mocked:
|
||||
- ETL parsing — LlamaParse (external API) and Docling (heavy library)
|
||||
- LLM summarization, text embedding, text chunking (external APIs)
|
||||
- Redis heartbeat (external infrastructure)
|
||||
- Task dispatch is swapped via DI (InlineTaskDispatcher)
|
||||
|
|
@ -11,6 +12,7 @@ External system boundaries are mocked:
|
|||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
|
@ -298,3 +300,59 @@ def _mock_redis_heartbeat(monkeypatch):
|
|||
"app.tasks.celery_tasks.document_tasks._run_heartbeat_loop",
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
|
||||
_MOCK_ETL_MARKDOWN = "# Mocked Document\n\nThis is mocked ETL content."
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_etl_parsing(monkeypatch):
|
||||
"""Mock ETL parsing services — LlamaParse and Docling are external boundaries.
|
||||
|
||||
Preserves the real contract: empty/corrupt files raise an error just like
|
||||
the actual services would, so tests covering failure paths keep working.
|
||||
"""
|
||||
|
||||
def _reject_empty(file_path: str) -> None:
|
||||
if os.path.getsize(file_path) == 0:
|
||||
raise RuntimeError(f"Cannot parse empty file: {file_path}")
|
||||
|
||||
# -- LlamaParse mock (external API) --------------------------------
|
||||
|
||||
async def _fake_llamacloud_parse(file_path: str, estimated_pages: int) -> str:
|
||||
_reject_empty(file_path)
|
||||
return _MOCK_ETL_MARKDOWN
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.etl_pipeline.parsers.llamacloud.parse_with_llamacloud",
|
||||
_fake_llamacloud_parse,
|
||||
)
|
||||
|
||||
# -- Docling mock (heavy library boundary) -------------------------
|
||||
|
||||
async def _fake_docling_parse(file_path: str, filename: str) -> str:
|
||||
_reject_empty(file_path)
|
||||
return _MOCK_ETL_MARKDOWN
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.etl_pipeline.parsers.docling.parse_with_docling",
|
||||
_fake_docling_parse,
|
||||
)
|
||||
|
||||
class _FakeDoclingResult:
|
||||
class Document:
|
||||
@staticmethod
|
||||
def export_to_markdown():
|
||||
return _MOCK_ETL_MARKDOWN
|
||||
|
||||
document = Document()
|
||||
|
||||
class _FakeDocumentConverter:
|
||||
def convert(self, file_path):
|
||||
_reject_empty(file_path)
|
||||
return _FakeDoclingResult()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"docling.document_converter.DocumentConverter",
|
||||
_FakeDocumentConverter,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ async def test_composio_connector_without_account_id_returns_error(
|
|||
|
||||
maker = make_session_factory(async_engine)
|
||||
async with maker() as session:
|
||||
count, _skipped, error = await index_google_drive_files(
|
||||
count, _skipped, error, _unsupported = await index_google_drive_files(
|
||||
session=session,
|
||||
connector_id=data["connector_id"],
|
||||
search_space_id=data["search_space_id"],
|
||||
|
|
|
|||
File diff suppressed because it is too large
167 surfsense_backend/tests/integration/test_document_versioning.py (new file)
|
|
@ -0,0 +1,167 @@
|
|||
"""Integration tests for document versioning snapshot + cleanup."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType, DocumentVersion, SearchSpace, User
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_document(
|
||||
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
|
||||
) -> Document:
|
||||
doc = Document(
|
||||
title="Test Doc",
|
||||
document_type=DocumentType.LOCAL_FOLDER_FILE,
|
||||
document_metadata={},
|
||||
content="Summary of test doc.",
|
||||
content_hash="abc123",
|
||||
unique_identifier_hash="local_folder:test-folder:test.md",
|
||||
source_markdown="# Test\n\nOriginal content.",
|
||||
search_space_id=db_search_space.id,
|
||||
created_by_id=db_user.id,
|
||||
)
|
||||
db_session.add(doc)
|
||||
await db_session.flush()
|
||||
return doc
|
||||
|
||||
|
||||
async def _version_count(session: AsyncSession, document_id: int) -> int:
|
||||
result = await session.execute(
|
||||
select(func.count())
|
||||
.select_from(DocumentVersion)
|
||||
.where(DocumentVersion.document_id == document_id)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
async def _get_versions(
|
||||
session: AsyncSession, document_id: int
|
||||
) -> list[DocumentVersion]:
|
||||
result = await session.execute(
|
||||
select(DocumentVersion)
|
||||
.where(DocumentVersion.document_id == document_id)
|
||||
.order_by(DocumentVersion.version_number)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
class TestCreateVersionSnapshot:
|
||||
"""V1-V5: TDD slices for create_version_snapshot."""
|
||||
|
||||
async def test_v1_creates_first_version(self, db_session, db_document):
|
||||
"""V1: First snapshot creates version 1 with the document's current state."""
|
||||
from app.utils.document_versioning import create_version_snapshot
|
||||
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
|
||||
versions = await _get_versions(db_session, db_document.id)
|
||||
assert len(versions) == 1
|
||||
assert versions[0].version_number == 1
|
||||
assert versions[0].source_markdown == "# Test\n\nOriginal content."
|
||||
assert versions[0].content_hash == "abc123"
|
||||
assert versions[0].title == "Test Doc"
|
||||
assert versions[0].document_id == db_document.id
|
||||
|
||||
async def test_v2_creates_version_2_after_30_min(
|
||||
self, db_session, db_document, monkeypatch
|
||||
):
|
||||
"""V2: After 30+ minutes, a new version is created (not overwritten)."""
|
||||
from app.utils.document_versioning import create_version_snapshot
|
||||
|
||||
t0 = datetime(2025, 1, 1, 12, 0, 0, tzinfo=UTC)
|
||||
monkeypatch.setattr("app.utils.document_versioning._now", lambda: t0)
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
|
||||
# Simulate content change and time passing
|
||||
db_document.source_markdown = "# Test\n\nUpdated content."
|
||||
db_document.content_hash = "def456"
|
||||
t1 = t0 + timedelta(minutes=31)
|
||||
monkeypatch.setattr("app.utils.document_versioning._now", lambda: t1)
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
|
||||
versions = await _get_versions(db_session, db_document.id)
|
||||
assert len(versions) == 2
|
||||
assert versions[0].version_number == 1
|
||||
assert versions[1].version_number == 2
|
||||
assert versions[1].source_markdown == "# Test\n\nUpdated content."
|
||||
|
||||
async def test_v3_overwrites_within_30_min(
|
||||
self, db_session, db_document, monkeypatch
|
||||
):
|
||||
"""V3: Within 30 minutes, the latest version is overwritten."""
|
||||
from app.utils.document_versioning import create_version_snapshot
|
||||
|
||||
t0 = datetime(2025, 1, 1, 12, 0, 0, tzinfo=UTC)
|
||||
monkeypatch.setattr("app.utils.document_versioning._now", lambda: t0)
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
count_after_first = await _version_count(db_session, db_document.id)
|
||||
assert count_after_first == 1
|
||||
|
||||
# Simulate quick edit within 30 minutes
|
||||
db_document.source_markdown = "# Test\n\nQuick edit."
|
||||
db_document.content_hash = "quick123"
|
||||
t1 = t0 + timedelta(minutes=10)
|
||||
monkeypatch.setattr("app.utils.document_versioning._now", lambda: t1)
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
|
||||
count_after_second = await _version_count(db_session, db_document.id)
|
||||
assert count_after_second == 1 # still 1, not 2
|
||||
|
||||
versions = await _get_versions(db_session, db_document.id)
|
||||
assert versions[0].source_markdown == "# Test\n\nQuick edit."
|
||||
assert versions[0].content_hash == "quick123"
|
||||
|
||||
async def test_v4_cleanup_90_day_old_versions(
|
||||
self, db_session, db_document, monkeypatch
|
||||
):
|
||||
"""V4: Versions older than 90 days are cleaned up."""
|
||||
from app.utils.document_versioning import create_version_snapshot
|
||||
|
||||
base = datetime(2025, 1, 1, 12, 0, 0, tzinfo=UTC)
|
||||
|
||||
# Create 5 versions spread across time: 3 older than 90 days, 2 recent
|
||||
for i in range(5):
|
||||
db_document.source_markdown = f"Content v{i + 1}"
|
||||
db_document.content_hash = f"hash_{i + 1}"
|
||||
t = base + timedelta(days=i) if i < 3 else base + timedelta(days=100 + i)
|
||||
monkeypatch.setattr("app.utils.document_versioning._now", lambda _t=t: _t)
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
|
||||
# Now trigger cleanup from a "current" time that makes the first 3 versions > 90 days old
|
||||
now = base + timedelta(days=200)
|
||||
monkeypatch.setattr("app.utils.document_versioning._now", lambda: now)
|
||||
db_document.source_markdown = "Content v6"
|
||||
db_document.content_hash = "hash_6"
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
|
||||
versions = await _get_versions(db_session, db_document.id)
|
||||
# The first 3 (old) should be cleaned up; versions 4, 5, 6 remain
|
||||
for v in versions:
|
||||
age = now - v.created_at.replace(tzinfo=UTC)
|
||||
assert age <= timedelta(days=90), f"Version {v.version_number} is too old"
|
||||
|
||||
async def test_v5_cap_at_20_versions(self, db_session, db_document, monkeypatch):
|
||||
"""V5: More than 20 versions triggers cap — oldest gets deleted."""
|
||||
from app.utils.document_versioning import create_version_snapshot
|
||||
|
||||
base = datetime(2025, 6, 1, 12, 0, 0, tzinfo=UTC)
|
||||
|
||||
# Create 21 versions (all within 90 days, each 31 min apart)
|
||||
for i in range(21):
|
||||
db_document.source_markdown = f"Content v{i + 1}"
|
||||
db_document.content_hash = f"hash_{i + 1}"
|
||||
t = base + timedelta(minutes=31 * i)
|
||||
monkeypatch.setattr("app.utils.document_versioning._now", lambda _t=t: _t)
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
|
||||
versions = await _get_versions(db_session, db_document.id)
|
||||
assert len(versions) == 20
|
||||
# The lowest version_number should be 2 (version 1 was the oldest and got capped)
|
||||
assert versions[0].version_number == 2
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
"""Tests that each cloud connector's download_and_extract_content correctly
|
||||
produces markdown from a real file via the unified ETL pipeline.
|
||||
|
||||
Only the cloud client is mocked (system boundary). The ETL pipeline runs for
|
||||
real so we know the full path from "cloud gives us bytes" to "we get markdown
|
||||
back" actually works.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_TXT_CONTENT = "Hello from the cloud connector test."
|
||||
_CSV_CONTENT = "name,age\nAlice,30\nBob,25\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _write_file(dest_path: str, content: str) -> None:
|
||||
"""Simulate a cloud client writing downloaded bytes to disk."""
|
||||
with open(dest_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def _make_download_side_effect(content: str):
|
||||
"""Return an async side-effect that writes *content* to the dest path
|
||||
and returns ``None`` (success)."""
|
||||
|
||||
async def _side_effect(*args):
|
||||
dest_path = args[-1]
|
||||
await _write_file(dest_path, content)
|
||||
return None
|
||||
|
||||
return _side_effect
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Google Drive
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestGoogleDriveContentExtraction:
|
||||
async def test_txt_file_returns_markdown(self):
|
||||
from app.connectors.google_drive.content_extractor import (
|
||||
download_and_extract_content,
|
||||
)
|
||||
|
||||
client = MagicMock()
|
||||
client.download_file_to_disk = AsyncMock(
|
||||
side_effect=_make_download_side_effect(_TXT_CONTENT),
|
||||
)
|
||||
|
||||
file = {"id": "f1", "name": "notes.txt", "mimeType": "text/plain"}
|
||||
|
||||
markdown, metadata, error = await download_and_extract_content(client, file)
|
||||
|
||||
assert error is None
|
||||
assert _TXT_CONTENT in markdown
|
||||
assert metadata["google_drive_file_id"] == "f1"
|
||||
assert metadata["google_drive_file_name"] == "notes.txt"
|
||||
|
||||
async def test_csv_file_returns_markdown_table(self):
|
||||
from app.connectors.google_drive.content_extractor import (
|
||||
download_and_extract_content,
|
||||
)
|
||||
|
||||
client = MagicMock()
|
||||
client.download_file_to_disk = AsyncMock(
|
||||
side_effect=_make_download_side_effect(_CSV_CONTENT),
|
||||
)
|
||||
|
||||
file = {"id": "f2", "name": "data.csv", "mimeType": "text/csv"}
|
||||
|
||||
markdown, _metadata, error = await download_and_extract_content(client, file)
|
||||
|
||||
assert error is None
|
||||
assert "Alice" in markdown
|
||||
assert "Bob" in markdown
|
||||
assert "|" in markdown
|
||||
|
||||
async def test_download_error_returns_error_message(self):
|
||||
from app.connectors.google_drive.content_extractor import (
|
||||
download_and_extract_content,
|
||||
)
|
||||
|
||||
client = MagicMock()
|
||||
client.download_file_to_disk = AsyncMock(return_value="Network timeout")
|
||||
|
||||
file = {"id": "f3", "name": "doc.txt", "mimeType": "text/plain"}
|
||||
|
||||
markdown, _metadata, error = await download_and_extract_content(client, file)
|
||||
|
||||
assert markdown is None
|
||||
assert error == "Network timeout"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# OneDrive
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestOneDriveContentExtraction:
|
||||
async def test_txt_file_returns_markdown(self):
|
||||
from app.connectors.onedrive.content_extractor import (
|
||||
download_and_extract_content,
|
||||
)
|
||||
|
||||
client = MagicMock()
|
||||
client.download_file_to_disk = AsyncMock(
|
||||
side_effect=_make_download_side_effect(_TXT_CONTENT),
|
||||
)
|
||||
|
||||
file = {
|
||||
"id": "od-1",
|
||||
"name": "report.txt",
|
||||
"file": {"mimeType": "text/plain"},
|
||||
}
|
||||
|
||||
markdown, metadata, error = await download_and_extract_content(client, file)
|
||||
|
||||
assert error is None
|
||||
assert _TXT_CONTENT in markdown
|
||||
assert metadata["onedrive_file_id"] == "od-1"
|
||||
assert metadata["onedrive_file_name"] == "report.txt"
|
||||
|
||||
async def test_csv_file_returns_markdown_table(self):
|
||||
from app.connectors.onedrive.content_extractor import (
|
||||
download_and_extract_content,
|
||||
)
|
||||
|
||||
client = MagicMock()
|
||||
client.download_file_to_disk = AsyncMock(
|
||||
side_effect=_make_download_side_effect(_CSV_CONTENT),
|
||||
)
|
||||
|
||||
file = {
|
||||
"id": "od-2",
|
||||
"name": "data.csv",
|
||||
"file": {"mimeType": "text/csv"},
|
||||
}
|
||||
|
||||
markdown, _metadata, error = await download_and_extract_content(client, file)
|
||||
|
||||
assert error is None
|
||||
assert "Alice" in markdown
|
||||
assert "|" in markdown
|
||||
|
||||
async def test_download_error_returns_error_message(self):
|
||||
from app.connectors.onedrive.content_extractor import (
|
||||
download_and_extract_content,
|
||||
)
|
||||
|
||||
client = MagicMock()
|
||||
client.download_file_to_disk = AsyncMock(return_value="403 Forbidden")
|
||||
|
||||
file = {
|
||||
"id": "od-3",
|
||||
"name": "secret.txt",
|
||||
"file": {"mimeType": "text/plain"},
|
||||
}
|
||||
|
||||
markdown, _metadata, error = await download_and_extract_content(client, file)
|
||||
|
||||
assert markdown is None
|
||||
assert error == "403 Forbidden"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Dropbox
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestDropboxContentExtraction:
|
||||
async def test_txt_file_returns_markdown(self):
|
||||
from app.connectors.dropbox.content_extractor import (
|
||||
download_and_extract_content,
|
||||
)
|
||||
|
||||
client = MagicMock()
|
||||
client.download_file_to_disk = AsyncMock(
|
||||
side_effect=_make_download_side_effect(_TXT_CONTENT),
|
||||
)
|
||||
|
||||
file = {
|
||||
"id": "dbx-1",
|
||||
"name": "memo.txt",
|
||||
".tag": "file",
|
||||
"path_lower": "/memo.txt",
|
||||
}
|
||||
|
||||
markdown, metadata, error = await download_and_extract_content(client, file)
|
||||
|
||||
assert error is None
|
||||
assert _TXT_CONTENT in markdown
|
||||
assert metadata["dropbox_file_id"] == "dbx-1"
|
||||
assert metadata["dropbox_file_name"] == "memo.txt"
|
||||
|
||||
async def test_csv_file_returns_markdown_table(self):
|
||||
from app.connectors.dropbox.content_extractor import (
|
||||
download_and_extract_content,
|
||||
)
|
||||
|
||||
client = MagicMock()
|
||||
client.download_file_to_disk = AsyncMock(
|
||||
side_effect=_make_download_side_effect(_CSV_CONTENT),
|
||||
)
|
||||
|
||||
file = {
|
||||
"id": "dbx-2",
|
||||
"name": "data.csv",
|
||||
".tag": "file",
|
||||
"path_lower": "/data.csv",
|
||||
}
|
||||
|
||||
markdown, _metadata, error = await download_and_extract_content(client, file)
|
||||
|
||||
assert error is None
|
||||
assert "Alice" in markdown
|
||||
assert "|" in markdown
|
||||
|
||||
async def test_download_error_returns_error_message(self):
|
||||
from app.connectors.dropbox.content_extractor import (
|
||||
download_and_extract_content,
|
||||
)
|
||||
|
||||
client = MagicMock()
|
||||
client.download_file_to_disk = AsyncMock(return_value="Rate limited")
|
||||
|
||||
file = {
|
||||
"id": "dbx-3",
|
||||
"name": "big.txt",
|
||||
".tag": "file",
|
||||
"path_lower": "/big.txt",
|
||||
}
|
||||
|
||||
markdown, _metadata, error = await download_and_extract_content(client, file)
|
||||
|
||||
assert markdown is None
|
||||
assert error == "Rate limited"
|
||||
|
|
@ -8,6 +8,10 @@ import pytest
|
|||
from app.db import DocumentType
|
||||
from app.tasks.connector_indexers.dropbox_indexer import (
|
||||
_download_files_parallel,
|
||||
_index_full_scan,
|
||||
_index_selected_files,
|
||||
_index_with_delta_sync,
|
||||
index_dropbox_files,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
|
@ -234,3 +238,610 @@ async def test_heartbeat_fires_during_parallel_downloads(
|
|||
assert len(docs) == 3
|
||||
assert failed == 0
|
||||
assert len(heartbeat_calls) >= 1, "Heartbeat should have fired at least once"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# D1-D2: _index_full_scan tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _folder_dict(name: str) -> dict:
|
||||
return {".tag": "folder", "name": name}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def full_scan_mocks(mock_dropbox_client, monkeypatch):
|
||||
"""Wire up mocks for _index_full_scan in isolation."""
|
||||
import app.tasks.connector_indexers.dropbox_indexer as _mod
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
mock_log_entry = MagicMock()
|
||||
|
||||
skip_results: dict[str, tuple[bool, str | None]] = {}
|
||||
|
||||
monkeypatch.setattr("app.config.config.ETL_SERVICE", "LLAMACLOUD")
|
||||
|
||||
async def _fake_skip(session, file, search_space_id):
|
||||
from app.connectors.dropbox.file_types import should_skip_file as _skip
|
||||
|
||||
item_skip, unsup_ext = _skip(file)
|
||||
if item_skip:
|
||||
if unsup_ext:
|
||||
return True, f"unsupported:{unsup_ext}"
|
||||
return True, "folder/non-downloadable"
|
||||
return skip_results.get(file.get("id", ""), (False, None))
|
||||
|
||||
monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip)
|
||||
|
||||
download_and_index_mock = AsyncMock(return_value=(0, 0))
|
||||
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
|
||||
|
||||
from app.services.page_limit_service import PageLimitService as _RealPLS
|
||||
|
||||
mock_page_limit_instance = MagicMock()
|
||||
mock_page_limit_instance.get_page_usage = AsyncMock(return_value=(0, 999_999))
|
||||
mock_page_limit_instance.update_page_usage = AsyncMock()
|
||||
|
||||
class _MockPageLimitService:
|
||||
estimate_pages_from_metadata = staticmethod(
|
||||
_RealPLS.estimate_pages_from_metadata
|
||||
)
|
||||
|
||||
def __init__(self, session):
|
||||
self.get_page_usage = mock_page_limit_instance.get_page_usage
|
||||
self.update_page_usage = mock_page_limit_instance.update_page_usage
|
||||
|
||||
monkeypatch.setattr(_mod, "PageLimitService", _MockPageLimitService)
|
||||
|
||||
return {
|
||||
"dropbox_client": mock_dropbox_client,
|
||||
"session": mock_session,
|
||||
"task_logger": mock_task_logger,
|
||||
"log_entry": mock_log_entry,
|
||||
"skip_results": skip_results,
|
||||
"download_and_index_mock": download_and_index_mock,
|
||||
}
|
||||
|
||||
|
||||
async def _run_full_scan(mocks, monkeypatch, page_files, *, max_files=500):
|
||||
import app.tasks.connector_indexers.dropbox_indexer as _mod
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod,
|
||||
"get_files_in_folder",
|
||||
AsyncMock(return_value=(page_files, None)),
|
||||
)
|
||||
return await _index_full_scan(
|
||||
mocks["dropbox_client"],
|
||||
mocks["session"],
|
||||
_CONNECTOR_ID,
|
||||
_SEARCH_SPACE_ID,
|
||||
_USER_ID,
|
||||
"",
|
||||
"Root",
|
||||
mocks["task_logger"],
|
||||
mocks["log_entry"],
|
||||
max_files,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
|
||||
"""Skipped files excluded, renames counted as indexed, new files downloaded."""
|
||||
page_files = [
|
||||
_folder_dict("SubFolder"),
|
||||
_make_file_dict("skip1", "unchanged.txt"),
|
||||
_make_file_dict("rename1", "renamed.txt"),
|
||||
_make_file_dict("new1", "new1.txt"),
|
||||
_make_file_dict("new2", "new2.txt"),
|
||||
]
|
||||
|
||||
full_scan_mocks["skip_results"]["skip1"] = (True, "unchanged")
|
||||
full_scan_mocks["skip_results"]["rename1"] = (
|
||||
True,
|
||||
"File renamed: 'old' -> 'renamed.txt'",
|
||||
)
|
||||
|
||||
full_scan_mocks["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
indexed, skipped, _unsupported = await _run_full_scan(
|
||||
full_scan_mocks, monkeypatch, page_files
|
||||
)
|
||||
|
||||
assert indexed == 3 # 1 renamed + 2 from batch
|
||||
assert skipped == 2 # 1 folder + 1 unchanged
|
||||
|
||||
call_args = full_scan_mocks["download_and_index_mock"].call_args
|
||||
call_files = call_args[0][2]
|
||||
assert len(call_files) == 2
|
||||
assert {f["id"] for f in call_files} == {"new1", "new2"}
|
||||
|
||||
|
||||
async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
|
||||
"""Only max_files non-folder items are considered."""
|
||||
page_files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(10)]
|
||||
|
||||
full_scan_mocks["download_and_index_mock"].return_value = (3, 0)
|
||||
|
||||
await _run_full_scan(full_scan_mocks, monkeypatch, page_files, max_files=3)
|
||||
|
||||
call_files = full_scan_mocks["download_and_index_mock"].call_args[0][2]
|
||||
assert len(call_files) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# D3-D5: _index_selected_files tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def selected_files_mocks(mock_dropbox_client, monkeypatch):
|
||||
"""Wire up mocks for _index_selected_files tests."""
|
||||
import app.tasks.connector_indexers.dropbox_indexer as _mod
|
||||
|
||||
mock_session = AsyncMock()
|
||||
|
||||
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
|
||||
|
||||
async def _fake_get_file(client, path):
|
||||
return get_file_results.get(path, (None, f"Not configured: {path}"))
|
||||
|
||||
monkeypatch.setattr(_mod, "get_file_by_path", _fake_get_file)
|
||||
|
||||
skip_results: dict[str, tuple[bool, str | None]] = {}
|
||||
|
||||
async def _fake_skip(session, file, search_space_id):
|
||||
return skip_results.get(file["id"], (False, None))
|
||||
|
||||
monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip)
|
||||
|
||||
download_and_index_mock = AsyncMock(return_value=(0, 0))
|
||||
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
|
||||
|
||||
from app.services.page_limit_service import PageLimitService as _RealPLS
|
||||
|
||||
mock_page_limit_instance = MagicMock()
|
||||
mock_page_limit_instance.get_page_usage = AsyncMock(return_value=(0, 999_999))
|
||||
mock_page_limit_instance.update_page_usage = AsyncMock()
|
||||
|
||||
class _MockPageLimitService:
|
||||
estimate_pages_from_metadata = staticmethod(
|
||||
_RealPLS.estimate_pages_from_metadata
|
||||
)
|
||||
|
||||
def __init__(self, session):
|
||||
self.get_page_usage = mock_page_limit_instance.get_page_usage
|
||||
self.update_page_usage = mock_page_limit_instance.update_page_usage
|
||||
|
||||
monkeypatch.setattr(_mod, "PageLimitService", _MockPageLimitService)
|
||||
|
||||
return {
|
||||
"dropbox_client": mock_dropbox_client,
|
||||
"session": mock_session,
|
||||
"get_file_results": get_file_results,
|
||||
"skip_results": skip_results,
|
||||
"download_and_index_mock": download_and_index_mock,
|
||||
}
|
||||
|
||||
|
||||
async def _run_selected(mocks, file_tuples):
|
||||
return await _index_selected_files(
|
||||
mocks["dropbox_client"],
|
||||
mocks["session"],
|
||||
file_tuples,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_selected_files_single_file_indexed(selected_files_mocks):
|
||||
selected_files_mocks["get_file_results"]["/report.pdf"] = (
|
||||
_make_file_dict("f1", "report.pdf"),
|
||||
None,
|
||||
)
|
||||
selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
|
||||
|
||||
indexed, skipped, _unsupported, errors = await _run_selected(
|
||||
selected_files_mocks,
|
||||
[("/report.pdf", "report.pdf")],
|
||||
)
|
||||
|
||||
assert indexed == 1
|
||||
assert skipped == 0
|
||||
assert errors == []
|
||||
|
||||
|
||||
async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
|
||||
selected_files_mocks["get_file_results"]["/first.txt"] = (
|
||||
_make_file_dict("f1", "first.txt"),
|
||||
None,
|
||||
)
|
||||
selected_files_mocks["get_file_results"]["/mid.txt"] = (None, "HTTP 404")
|
||||
selected_files_mocks["get_file_results"]["/third.txt"] = (
|
||||
_make_file_dict("f3", "third.txt"),
|
||||
None,
|
||||
)
|
||||
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
indexed, skipped, _unsupported, errors = await _run_selected(
|
||||
selected_files_mocks,
|
||||
[
|
||||
("/first.txt", "first.txt"),
|
||||
("/mid.txt", "mid.txt"),
|
||||
("/third.txt", "third.txt"),
|
||||
],
|
||||
)
|
||||
|
||||
assert indexed == 2
|
||||
assert skipped == 0
|
||||
assert len(errors) == 1
|
||||
assert "mid.txt" in errors[0]
|
||||
|
||||
|
||||
async def test_selected_files_skip_rename_counting(selected_files_mocks):
|
||||
for path, fid, fname in [
|
||||
("/unchanged.txt", "s1", "unchanged.txt"),
|
||||
("/renamed.txt", "r1", "renamed.txt"),
|
||||
("/new1.txt", "n1", "new1.txt"),
|
||||
("/new2.txt", "n2", "new2.txt"),
|
||||
]:
|
||||
selected_files_mocks["get_file_results"][path] = (
|
||||
_make_file_dict(fid, fname),
|
||||
None,
|
||||
)
|
||||
|
||||
selected_files_mocks["skip_results"]["s1"] = (True, "unchanged")
|
||||
selected_files_mocks["skip_results"]["r1"] = (
|
||||
True,
|
||||
"File renamed: 'old' -> 'renamed.txt'",
|
||||
)
|
||||
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
indexed, skipped, _unsupported, errors = await _run_selected(
|
||||
selected_files_mocks,
|
||||
[
|
||||
("/unchanged.txt", "unchanged.txt"),
|
||||
("/renamed.txt", "renamed.txt"),
|
||||
("/new1.txt", "new1.txt"),
|
||||
("/new2.txt", "new2.txt"),
|
||||
],
|
||||
)
|
||||
|
||||
assert indexed == 3 # 1 renamed + 2 batch
|
||||
assert skipped == 1
|
||||
assert errors == []
|
||||
|
||||
mock = selected_files_mocks["download_and_index_mock"]
|
||||
call_files = mock.call_args[0][2]
|
||||
assert len(call_files) == 2
|
||||
assert {f["id"] for f in call_files} == {"n1", "n2"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E1-E4: _index_with_delta_sync tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_delta_sync_deletions_call_remove_document(monkeypatch):
|
||||
"""E1: deleted entries are processed via _remove_document."""
|
||||
import app.tasks.connector_indexers.dropbox_indexer as _mod
|
||||
|
||||
entries = [
|
||||
{
|
||||
".tag": "deleted",
|
||||
"name": "gone.txt",
|
||||
"path_lower": "/gone.txt",
|
||||
"id": "id:del1",
|
||||
},
|
||||
{
|
||||
".tag": "deleted",
|
||||
"name": "also_gone.pdf",
|
||||
"path_lower": "/also_gone.pdf",
|
||||
"id": "id:del2",
|
||||
},
|
||||
]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_changes = AsyncMock(return_value=(entries, "new-cursor", None))
|
||||
|
||||
remove_calls: list[str] = []
|
||||
|
||||
async def _fake_remove(session, file_id, search_space_id):
|
||||
remove_calls.append(file_id)
|
||||
|
||||
monkeypatch.setattr(_mod, "_remove_document", _fake_remove)
|
||||
monkeypatch.setattr(_mod, "_download_and_index", AsyncMock(return_value=(0, 0)))
|
||||
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
||||
_indexed, _skipped, _unsupported, cursor = await _index_with_delta_sync(
|
||||
mock_client,
|
||||
AsyncMock(),
|
||||
_CONNECTOR_ID,
|
||||
_SEARCH_SPACE_ID,
|
||||
_USER_ID,
|
||||
"old-cursor",
|
||||
mock_task_logger,
|
||||
MagicMock(),
|
||||
max_files=500,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert sorted(remove_calls) == ["id:del1", "id:del2"]
|
||||
assert cursor == "new-cursor"
|
||||
|
||||
|
||||
async def test_delta_sync_upserts_filtered_and_downloaded(monkeypatch):
|
||||
"""E2: modified/new file entries go through skip filter then download+index."""
|
||||
import app.tasks.connector_indexers.dropbox_indexer as _mod
|
||||
|
||||
entries = [
|
||||
_make_file_dict("mod1", "modified1.txt"),
|
||||
_make_file_dict("mod2", "modified2.txt"),
|
||||
]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_changes = AsyncMock(return_value=(entries, "cursor-v2", None))
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "_should_skip_file", AsyncMock(return_value=(False, None))
|
||||
)
|
||||
|
||||
download_mock = AsyncMock(return_value=(2, 0))
|
||||
monkeypatch.setattr(_mod, "_download_and_index", download_mock)
|
||||
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
||||
indexed, skipped, _unsupported, cursor = await _index_with_delta_sync(
|
||||
mock_client,
|
||||
AsyncMock(),
|
||||
_CONNECTOR_ID,
|
||||
_SEARCH_SPACE_ID,
|
||||
_USER_ID,
|
||||
"cursor-v1",
|
||||
mock_task_logger,
|
||||
MagicMock(),
|
||||
max_files=500,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert indexed == 2
|
||||
assert skipped == 0
|
||||
assert cursor == "cursor-v2"
|
||||
|
||||
downloaded_files = download_mock.call_args[0][2]
|
||||
assert len(downloaded_files) == 2
|
||||
assert {f["id"] for f in downloaded_files} == {"mod1", "mod2"}
|
||||
|
||||
|
||||
async def test_delta_sync_mix_deletions_and_upserts(monkeypatch):
|
||||
"""E3: deletions processed, then remaining upserts filtered and indexed."""
|
||||
import app.tasks.connector_indexers.dropbox_indexer as _mod
|
||||
|
||||
entries = [
|
||||
{
|
||||
".tag": "deleted",
|
||||
"name": "removed.txt",
|
||||
"path_lower": "/removed.txt",
|
||||
"id": "id:del1",
|
||||
},
|
||||
{
|
||||
".tag": "deleted",
|
||||
"name": "trashed.pdf",
|
||||
"path_lower": "/trashed.pdf",
|
||||
"id": "id:del2",
|
||||
},
|
||||
_make_file_dict("mod1", "updated.txt"),
|
||||
_make_file_dict("new1", "brandnew.docx"),
|
||||
]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_changes = AsyncMock(return_value=(entries, "final-cursor", None))
|
||||
|
||||
remove_calls: list[str] = []
|
||||
|
||||
async def _fake_remove(session, file_id, search_space_id):
|
||||
remove_calls.append(file_id)
|
||||
|
||||
monkeypatch.setattr(_mod, "_remove_document", _fake_remove)
|
||||
monkeypatch.setattr(
|
||||
_mod, "_should_skip_file", AsyncMock(return_value=(False, None))
|
||||
)
|
||||
|
||||
download_mock = AsyncMock(return_value=(2, 0))
|
||||
monkeypatch.setattr(_mod, "_download_and_index", download_mock)
|
||||
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
||||
indexed, skipped, _unsupported, cursor = await _index_with_delta_sync(
|
||||
mock_client,
|
||||
AsyncMock(),
|
||||
_CONNECTOR_ID,
|
||||
_SEARCH_SPACE_ID,
|
||||
_USER_ID,
|
||||
"old-cursor",
|
||||
mock_task_logger,
|
||||
MagicMock(),
|
||||
max_files=500,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert sorted(remove_calls) == ["id:del1", "id:del2"]
|
||||
assert indexed == 2
|
||||
assert skipped == 0
|
||||
assert cursor == "final-cursor"
|
||||
|
||||
downloaded_files = download_mock.call_args[0][2]
|
||||
assert {f["id"] for f in downloaded_files} == {"mod1", "new1"}
|
||||
|
||||
|
||||
async def test_delta_sync_returns_new_cursor(monkeypatch):
|
||||
"""E4: the new cursor from the API response is returned."""
|
||||
import app.tasks.connector_indexers.dropbox_indexer as _mod
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_changes = AsyncMock(return_value=([], "brand-new-cursor-xyz", None))
|
||||
|
||||
monkeypatch.setattr(_mod, "_download_and_index", AsyncMock(return_value=(0, 0)))
|
||||
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
||||
indexed, skipped, _unsupported, cursor = await _index_with_delta_sync(
|
||||
mock_client,
|
||||
AsyncMock(),
|
||||
_CONNECTOR_ID,
|
||||
_SEARCH_SPACE_ID,
|
||||
_USER_ID,
|
||||
"old-cursor",
|
||||
mock_task_logger,
|
||||
MagicMock(),
|
||||
max_files=500,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert cursor == "brand-new-cursor-xyz"
|
||||
assert indexed == 0
|
||||
assert skipped == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# F1-F3: index_dropbox_files orchestrator tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def orchestrator_mocks(monkeypatch):
|
||||
"""Wire up mocks for index_dropbox_files orchestrator tests."""
|
||||
import app.tasks.connector_indexers.dropbox_indexer as _mod
|
||||
|
||||
mock_connector = MagicMock()
|
||||
mock_connector.config = {"_token_encrypted": False}
|
||||
mock_connector.last_indexed_at = None
|
||||
mock_connector.enable_summary = True
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod,
|
||||
"get_connector_by_id",
|
||||
AsyncMock(return_value=mock_connector),
|
||||
)
|
||||
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock())
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
mock_task_logger.log_task_success = AsyncMock()
|
||||
mock_task_logger.log_task_failure = AsyncMock()
|
||||
monkeypatch.setattr(
|
||||
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger)
|
||||
)
|
||||
|
||||
monkeypatch.setattr(_mod, "update_connector_last_indexed", AsyncMock())
|
||||
|
||||
full_scan_mock = AsyncMock(return_value=(5, 2, 0))
|
||||
monkeypatch.setattr(_mod, "_index_full_scan", full_scan_mock)
|
||||
|
||||
delta_sync_mock = AsyncMock(return_value=(3, 1, 0, "delta-cursor-new"))
|
||||
monkeypatch.setattr(_mod, "_index_with_delta_sync", delta_sync_mock)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_latest_cursor = AsyncMock(return_value=("latest-cursor-abc", None))
|
||||
monkeypatch.setattr(_mod, "DropboxClient", MagicMock(return_value=mock_client))
|
||||
|
||||
return {
|
||||
"connector": mock_connector,
|
||||
"full_scan_mock": full_scan_mock,
|
||||
"delta_sync_mock": delta_sync_mock,
|
||||
"mock_client": mock_client,
|
||||
}
|
||||
|
||||
|
||||
async def test_orchestrator_uses_delta_sync_when_cursor_and_last_indexed(
|
||||
orchestrator_mocks,
|
||||
):
|
||||
"""F1: with cursor + last_indexed_at + use_delta_sync, calls delta sync."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
connector = orchestrator_mocks["connector"]
|
||||
connector.config = {
|
||||
"_token_encrypted": False,
|
||||
"folder_cursors": {"/docs": "saved-cursor-123"},
|
||||
}
|
||||
connector.last_indexed_at = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
_indexed, _skipped, error, _unsupported = await index_dropbox_files(
|
||||
mock_session,
|
||||
_CONNECTOR_ID,
|
||||
_SEARCH_SPACE_ID,
|
||||
_USER_ID,
|
||||
{
|
||||
"folders": [{"path": "/docs", "name": "Docs"}],
|
||||
"files": [],
|
||||
"indexing_options": {"use_delta_sync": True},
|
||||
},
|
||||
)
|
||||
|
||||
assert error is None
|
||||
orchestrator_mocks["delta_sync_mock"].assert_called_once()
|
||||
orchestrator_mocks["full_scan_mock"].assert_not_called()
|
||||
|
||||
|
||||
async def test_orchestrator_falls_back_to_full_scan_without_cursor(
|
||||
orchestrator_mocks,
|
||||
):
|
||||
"""F2: without cursor, falls back to full scan."""
|
||||
connector = orchestrator_mocks["connector"]
|
||||
connector.config = {"_token_encrypted": False}
|
||||
connector.last_indexed_at = None
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
_indexed, _skipped, error, _unsupported = await index_dropbox_files(
|
||||
mock_session,
|
||||
_CONNECTOR_ID,
|
||||
_SEARCH_SPACE_ID,
|
||||
_USER_ID,
|
||||
{
|
||||
"folders": [{"path": "/docs", "name": "Docs"}],
|
||||
"files": [],
|
||||
"indexing_options": {"use_delta_sync": True},
|
||||
},
|
||||
)
|
||||
|
||||
assert error is None
|
||||
orchestrator_mocks["full_scan_mock"].assert_called_once()
|
||||
orchestrator_mocks["delta_sync_mock"].assert_not_called()
|
||||
|
||||
|
||||
async def test_orchestrator_persists_cursor_after_sync(orchestrator_mocks):
|
||||
"""F3: after sync, persists new cursor to connector config."""
|
||||
connector = orchestrator_mocks["connector"]
|
||||
connector.config = {"_token_encrypted": False}
|
||||
connector.last_indexed_at = None
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
await index_dropbox_files(
|
||||
mock_session,
|
||||
_CONNECTOR_ID,
|
||||
_SEARCH_SPACE_ID,
|
||||
_USER_ID,
|
||||
{
|
||||
"folders": [{"path": "/docs", "name": "Docs"}],
|
||||
"files": [],
|
||||
},
|
||||
)
|
||||
|
||||
assert "folder_cursors" in connector.config
|
||||
assert connector.config["folder_cursors"]["/docs"] == "latest-cursor-abc"
|
||||
|
|
|
|||
|
|
@ -248,12 +248,33 @@ def _folder_dict(file_id: str, name: str) -> dict:
|
|||
}
|
||||
|
||||
|
||||
def _make_page_limit_session(pages_used=0, pages_limit=999_999):
|
||||
"""Build a mock DB session that real PageLimitService can operate against."""
|
||||
|
||||
class _FakeUser:
|
||||
def __init__(self, pu, pl):
|
||||
self.pages_used = pu
|
||||
self.pages_limit = pl
|
||||
|
||||
fake_user = _FakeUser(pages_used, pages_limit)
|
||||
session = AsyncMock()
|
||||
|
||||
def _make_result(*_a, **_kw):
|
||||
r = MagicMock()
|
||||
r.first.return_value = (fake_user.pages_used, fake_user.pages_limit)
|
||||
r.unique.return_value.scalar_one_or_none.return_value = fake_user
|
||||
return r
|
||||
|
||||
session.execute = AsyncMock(side_effect=_make_result)
|
||||
return session, fake_user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def full_scan_mocks(mock_drive_client, monkeypatch):
|
||||
"""Wire up all mocks needed to call _index_full_scan in isolation."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session, _ = _make_page_limit_session()
|
||||
mock_connector = MagicMock()
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
|
@ -345,7 +366,7 @@ async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
|
|||
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
|
||||
full_scan_mocks["batch_mock"].return_value = ([], 2, 0)
|
||||
|
||||
indexed, skipped = await _run_full_scan(full_scan_mocks)
|
||||
indexed, skipped, _unsupported = await _run_full_scan(full_scan_mocks)
|
||||
|
||||
assert indexed == 3 # 1 renamed + 2 from batch
|
||||
assert skipped == 1 # 1 unchanged
|
||||
|
|
@ -472,11 +493,11 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
|||
AsyncMock(return_value=MagicMock()),
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session, _ = _make_page_limit_session()
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
||||
indexed, skipped = await _index_with_delta_sync(
|
||||
indexed, skipped, _unsupported = await _index_with_delta_sync(
|
||||
MagicMock(),
|
||||
mock_session,
|
||||
MagicMock(),
|
||||
|
|
@ -512,7 +533,7 @@ def selected_files_mocks(mock_drive_client, monkeypatch):
|
|||
"""Wire up mocks for _index_selected_files tests."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session, _ = _make_page_limit_session()
|
||||
|
||||
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
|
||||
|
||||
|
|
@ -568,7 +589,7 @@ async def test_selected_files_single_file_indexed(selected_files_mocks):
|
|||
)
|
||||
selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
|
||||
|
||||
indexed, skipped, errors = await _run_selected(
|
||||
indexed, skipped, _unsup, errors = await _run_selected(
|
||||
selected_files_mocks,
|
||||
[("f1", "report.pdf")],
|
||||
)
|
||||
|
|
@ -592,7 +613,7 @@ async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
|
|||
)
|
||||
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
indexed, skipped, errors = await _run_selected(
|
||||
indexed, skipped, _unsup, errors = await _run_selected(
|
||||
selected_files_mocks,
|
||||
[("f1", "first.txt"), ("f2", "mid.txt"), ("f3", "third.txt")],
|
||||
)
|
||||
|
|
@ -626,7 +647,7 @@ async def test_selected_files_skip_rename_counting(selected_files_mocks):
|
|||
|
||||
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
indexed, skipped, errors = await _run_selected(
|
||||
indexed, skipped, _unsup, errors = await _run_selected(
|
||||
selected_files_mocks,
|
||||
[
|
||||
("s1", "unchanged.txt"),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,78 @@
|
|||
"""Unit tests for scan_folder() pure logic — Tier 2 TDD slices (S1-S4)."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class TestScanFolder:
|
||||
"""S1-S4: scan_folder() with real tmp_path filesystem."""
|
||||
|
||||
def test_s1_single_md_file(self, tmp_path: Path):
|
||||
"""S1: scan_folder on a dir with one .md file returns correct entry."""
|
||||
from app.tasks.connector_indexers.local_folder_indexer import scan_folder
|
||||
|
||||
md = tmp_path / "note.md"
|
||||
md.write_text("# Hello")
|
||||
|
||||
results = scan_folder(str(tmp_path))
|
||||
|
||||
assert len(results) == 1
|
||||
entry = results[0]
|
||||
assert entry["relative_path"] == "note.md"
|
||||
assert entry["size"] > 0
|
||||
assert "modified_at" in entry
|
||||
assert entry["path"] == str(md)
|
||||
|
||||
def test_s2_extension_filter(self, tmp_path: Path):
|
||||
"""S2: file_extensions filter returns only matching files."""
|
||||
from app.tasks.connector_indexers.local_folder_indexer import scan_folder
|
||||
|
||||
(tmp_path / "a.md").write_text("md")
|
||||
(tmp_path / "b.txt").write_text("txt")
|
||||
(tmp_path / "c.pdf").write_bytes(b"%PDF")
|
||||
|
||||
results = scan_folder(str(tmp_path), file_extensions=[".md"])
|
||||
names = {r["relative_path"] for r in results}
|
||||
|
||||
assert names == {"a.md"}
|
||||
|
||||
def test_s3_exclude_patterns(self, tmp_path: Path):
|
||||
"""S3: exclude_patterns skips files inside excluded directories."""
|
||||
from app.tasks.connector_indexers.local_folder_indexer import scan_folder
|
||||
|
||||
(tmp_path / "good.md").write_text("good")
|
||||
nm = tmp_path / "node_modules"
|
||||
nm.mkdir()
|
||||
(nm / "dep.js").write_text("module")
|
||||
git = tmp_path / ".git"
|
||||
git.mkdir()
|
||||
(git / "config").write_text("gitconfig")
|
||||
|
||||
results = scan_folder(str(tmp_path), exclude_patterns=["node_modules", ".git"])
|
||||
names = {r["relative_path"] for r in results}
|
||||
|
||||
assert "good.md" in names
|
||||
assert not any("node_modules" in n for n in names)
|
||||
assert not any(".git" in n for n in names)
|
||||
|
||||
def test_s4_nested_dirs(self, tmp_path: Path):
|
||||
"""S4: nested subdirectories produce correct relative paths."""
|
||||
from app.tasks.connector_indexers.local_folder_indexer import scan_folder
|
||||
|
||||
daily = tmp_path / "notes" / "daily"
|
||||
daily.mkdir(parents=True)
|
||||
weekly = tmp_path / "notes" / "weekly"
|
||||
weekly.mkdir(parents=True)
|
||||
(daily / "today.md").write_text("today")
|
||||
(weekly / "review.md").write_text("review")
|
||||
(tmp_path / "root.txt").write_text("root")
|
||||
|
||||
results = scan_folder(str(tmp_path))
|
||||
paths = {r["relative_path"] for r in results}
|
||||
|
||||
assert "notes/daily/today.md" in paths or "notes\\daily\\today.md" in paths
|
||||
assert "notes/weekly/review.md" in paths or "notes\\weekly\\review.md" in paths
|
||||
assert "root.txt" in paths
|
||||
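Putting S1-S4 together, a hedged sketch of the scan_folder call these tests exercise; the folder path is illustrative.

# Hedged sketch of the scan_folder API pinned down by the tests above.
from app.tasks.connector_indexers.local_folder_indexer import scan_folder

entries = scan_folder(
    "/home/alice/notes",
    file_extensions=[".md"],
    exclude_patterns=["node_modules", ".git"],
)
for entry in entries:
    # Each entry carries path, relative_path, size and modified_at (per test S1).
    print(entry["relative_path"], entry["size"], entry["modified_at"])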
Some files were not shown because too many files have changed in this diff.