mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
Merge pull request #1336 from MODSetter/dev
0.0.20: platform-wide chat, billing, connector, and desktop upgrades
This commit is contained in:
commit
2bd068dec6
622 changed files with 76757 additions and 11965 deletions
13
.github/workflows/desktop-release.yml
vendored
13
.github/workflows/desktop-release.yml
vendored
|
|
@ -136,6 +136,19 @@ jobs:
|
||||||
AZURE_CODESIGN_ENDPOINT: ${{ vars.AZURE_CODESIGN_ENDPOINT }}
|
AZURE_CODESIGN_ENDPOINT: ${{ vars.AZURE_CODESIGN_ENDPOINT }}
|
||||||
AZURE_CODESIGN_ACCOUNT: ${{ vars.AZURE_CODESIGN_ACCOUNT }}
|
AZURE_CODESIGN_ACCOUNT: ${{ vars.AZURE_CODESIGN_ACCOUNT }}
|
||||||
AZURE_CODESIGN_PROFILE: ${{ vars.AZURE_CODESIGN_PROFILE }}
|
AZURE_CODESIGN_PROFILE: ${{ vars.AZURE_CODESIGN_PROFILE }}
|
||||||
|
# macOS Developer ID signing + notarization. Only the macos-latest runner
|
||||||
|
# consumes these; Windows/Linux runners ignore them. CSC_LINK accepts either
|
||||||
|
# a file path or a base64-encoded .p12 blob — electron-builder auto-detects.
|
||||||
|
CSC_LINK: ${{ secrets.MAC_CERT_P12_BASE64 }}
|
||||||
|
CSC_KEY_PASSWORD: ${{ secrets.MAC_CERT_PASSWORD }}
|
||||||
|
APPLE_ID: ${{ secrets.APPLE_ID }}
|
||||||
|
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
|
||||||
|
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
|
||||||
|
# TEMP DEBUG — remove once the codesign hang on macos-latest is diagnosed.
|
||||||
|
# Surfaces the exact codesign / notarize commands electron-builder spawns,
|
||||||
|
# so we can see which subprocess hangs.
|
||||||
|
DEBUG: electron-builder,electron-osx-sign*,@electron/notarize*
|
||||||
|
ELECTRON_BUILDER_ALLOW_UNRESOLVED_DEPENDENCIES: "true"
|
||||||
# Service principal credentials for Azure.Identity EnvironmentCredential used by the
|
# Service principal credentials for Azure.Identity EnvironmentCredential used by the
|
||||||
# TrustedSigning PowerShell module. Only populated when signing is enabled.
|
# TrustedSigning PowerShell module. Only populated when signing is enabled.
|
||||||
# electron-builder 26 does not yet support OIDC federated tokens for Azure signing,
|
# electron-builder 26 does not yet support OIDC federated tokens for Azure signing,
|
||||||
|
|
|
||||||
39
.github/workflows/obsidian-plugin-lint.yml
vendored
Normal file
39
.github/workflows/obsidian-plugin-lint.yml
vendored
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
name: Obsidian Plugin Lint

# Lints + type-checks + builds the Obsidian plugin on every push/PR that
# touches its sources. The official obsidian-sample-plugin template ships
# its own ESLint+esbuild setup; we run that here instead of folding the
# plugin into the monorepo's Biome-based code-quality.yml so the tooling
# stays aligned with what `obsidianmd/eslint-plugin-obsidianmd` checks
# against.

on:
  push:
    # Keep the explicit branch filter even though "**" matches every branch:
    # with a `branches` filter present and no `tags` filter, tag pushes do not
    # trigger this workflow (path filters are not evaluated for tag pushes).
    branches: ["**"]
    paths:
      - "surfsense_obsidian/**"
      - ".github/workflows/obsidian-plugin-lint.yml"
  pull_request:
    branches: ["**"]
    paths:
      - "surfsense_obsidian/**"
      - ".github/workflows/obsidian-plugin-lint.yml"

# This workflow only reads the repository (checkout + npm scripts), so pin the
# GITHUB_TOKEN to read-only instead of inheriting the repository default.
permissions:
  contents: read

jobs:
  lint:
    runs-on: ubuntu-latest
    defaults:
      run:
        # Every `run` step executes inside the plugin package.
        working-directory: surfsense_obsidian
    steps:
      - uses: actions/checkout@v6

      - uses: actions/setup-node@v6
        with:
          node-version: 22.x
          cache: npm
          # `with:` inputs are not affected by `defaults.run.working-directory`,
          # so the lockfile path is given relative to the repository root.
          cache-dependency-path: surfsense_obsidian/package-lock.json

      - run: npm ci
      - run: npm run lint
      - run: npm run build
|
||||||
119
.github/workflows/release-obsidian-plugin.yml
vendored
Normal file
119
.github/workflows/release-obsidian-plugin.yml
vendored
Normal file
|
|
@ -0,0 +1,119 @@
|
||||||
|
name: Release Obsidian Plugin

# Tag format: `obsidian-v<version>` and `<version>` must match `surfsense_obsidian/manifest.json` exactly.
on:
  push:
    tags:
      - "obsidian-v*"
  workflow_dispatch:
    inputs:
      publish:
        description: "Publish to GitHub Releases"
        required: true
        type: choice
        options:
          - never
          - always
        default: "never"

# `contents: write` is needed both to create the GitHub release and to push
# the mirrored manifest/versions commit to the default branch.
permissions:
  contents: write

jobs:
  build-and-release:
    runs-on: ubuntu-latest
    defaults:
      run:
        # Every `run` step executes inside the plugin package unless a step
        # overrides `working-directory` itself.
        working-directory: surfsense_obsidian

    steps:
      - uses: actions/checkout@v6
        with:
          # The checkout token is what later authenticates the `git push` of
          # the manifest/versions.json mirror commit, so it must carry the
          # write permission declared above.
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - uses: actions/setup-node@v6
        with:
          node-version: 22.x
          cache: npm
          cache-dependency-path: surfsense_obsidian/package-lock.json

      # Outputs: `tag` (obsidian-v<version>) and `version` (bare version).
      # Context values are passed through `env` rather than interpolated into
      # the script so they are always treated as data, never as shell code.
      - name: Resolve plugin version
        id: version
        env:
          EVENT_NAME: ${{ github.event_name }}
        run: |
          manifest_version=$(node -p "require('./manifest.json').version")
          if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
            # Manual runs derive the release version from manifest.json.
            version="$manifest_version"
            tag="obsidian-v$version"
          else
            tag="${GITHUB_REF_NAME}"
            if [ -z "$tag" ] || [[ "$tag" != obsidian-v* ]]; then
              echo "::error::Invalid tag '$tag'. Expected format: obsidian-v<version>"
              exit 1
            fi
            version="${tag#obsidian-v}"
            if [ "$version" != "$manifest_version" ]; then
              echo "::error::Tag version '$version' does not match manifest version '$manifest_version'"
              exit 1
            fi
          fi
          echo "tag=$tag" >> "$GITHUB_OUTPUT"
          echo "version=$version" >> "$GITHUB_OUTPUT"

      # Tag pushes always publish; manual runs publish only when the
      # `publish` input is set to "always".
      - name: Resolve publish mode
        id: release_mode
        env:
          EVENT_NAME: ${{ github.event_name }}
          PUBLISH_INPUT: ${{ inputs.publish }}
        run: |
          if [ "$EVENT_NAME" = "push" ] || [ "$PUBLISH_INPUT" = "always" ]; then
            echo "should_publish=true" >> "$GITHUB_OUTPUT"
          else
            echo "should_publish=false" >> "$GITHUB_OUTPUT"
          fi

      - run: npm ci

      - run: npm run lint

      - run: npm run build

      - name: Verify build artifacts
        run: |
          for f in main.js manifest.json styles.css; do
            if [ ! -f "$f" ]; then
              echo "::error::Missing release artifact: $f"
              exit 1
            fi
          done

      - name: Mirror manifest.json + versions.json to repo root
        if: steps.release_mode.outputs.should_publish == 'true'
        working-directory: ${{ github.workspace }}
        env:
          RELEASE_TAG: ${{ steps.version.outputs.tag }}
          DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
        run: |
          cp surfsense_obsidian/manifest.json manifest.json
          cp surfsense_obsidian/versions.json versions.json
          # Stage first and compare the index against HEAD: a plain
          # `git diff` ignores untracked files, so the very first mirror
          # (when the root copies are not tracked yet) would otherwise be
          # reported as "already up to date" and never committed.
          git add manifest.json versions.json
          if git diff --cached --quiet -- manifest.json versions.json; then
            echo "Root manifest/versions already up to date."
            exit 0
          fi
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"
          git commit -m "chore(obsidian-plugin): mirror manifest+versions for $RELEASE_TAG"
          # Push to the default branch so Obsidian can fetch raw files from HEAD.
          # A failed push is deliberately non-fatal: the release itself does
          # not depend on the mirror commit landing.
          if ! git push origin "HEAD:$DEFAULT_BRANCH"; then
            echo "::warning::Failed to push mirrored manifest/versions to default branch (likely branch protection). Continuing release."
          fi

      # Publish release under bare `manifest.json` version (no `obsidian-v` prefix) for BRAT/store compatibility.
      # `make_latest: "false"` keeps the desktop app's `v*` release headlined since Obsidian and BRAT resolve plugins via getReleaseByTag, not the latest flag.
      - name: Create GitHub release
        if: steps.release_mode.outputs.should_publish == 'true'
        uses: softprops/action-gh-release@v3
        with:
          tag_name: ${{ steps.version.outputs.version }}
          name: SurfSense Obsidian Plugin ${{ steps.version.outputs.version }}
          generate_release_notes: true
          make_latest: "false"
          files: |
            surfsense_obsidian/main.js
            surfsense_obsidian/manifest.json
            surfsense_obsidian/styles.css
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -7,4 +7,5 @@ node_modules/
|
||||||
.pnpm-store
|
.pnpm-store
|
||||||
.DS_Store
|
.DS_Store
|
||||||
deepagents/
|
deepagents/
|
||||||
debug.log
|
debug.log
|
||||||
|
opencode/
|
||||||
12
README.es.md
12
README.es.md
|
|
@ -41,7 +41,7 @@ NotebookLM es una de las mejores y más útiles plataformas de IA que existen, p
|
||||||
- **Sin Dependencia de Proveedores** - Configura cualquier modelo LLM, de imagen, TTS y STT.
|
- **Sin Dependencia de Proveedores** - Configura cualquier modelo LLM, de imagen, TTS y STT.
|
||||||
- **25+ Fuentes de Datos Externas** - Agrega tus fuentes desde Google Drive, OneDrive, Dropbox, Notion y muchos otros servicios externos.
|
- **25+ Fuentes de Datos Externas** - Agrega tus fuentes desde Google Drive, OneDrive, Dropbox, Notion y muchos otros servicios externos.
|
||||||
- **Soporte Multijugador en Tiempo Real** - Trabaja fácilmente con los miembros de tu equipo en un notebook compartido.
|
- **Soporte Multijugador en Tiempo Real** - Trabaja fácilmente con los miembros de tu equipo en un notebook compartido.
|
||||||
- **Aplicación de Escritorio** - Obtén asistencia de IA en cualquier aplicación con Quick Assist, General Assist, Extreme Assist y sincronización de carpetas locales.
|
- **Aplicación de Escritorio** - Obtén asistencia de IA en cualquier aplicación con Quick Assist, General Assist, Screenshot Assist y sincronización de carpetas locales.
|
||||||
|
|
||||||
...y más por venir.
|
...y más por venir.
|
||||||
|
|
||||||
|
|
@ -84,9 +84,9 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
||||||
|
|
||||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>
|
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>
|
||||||
|
|
||||||
- Aplicación de Escritorio — Extreme Assist
|
- Aplicación de Escritorio — Screenshot Assist
|
||||||
|
|
||||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/extreme_assist.gif" alt="Extreme Assist" /></p>
|
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/screenshot_assist.gif" alt="Screenshot Assist" /></p>
|
||||||
|
|
||||||
- Aplicación de Escritorio — Watch Local Folder
|
- Aplicación de Escritorio — Watch Local Folder
|
||||||
|
|
||||||
|
|
@ -150,7 +150,7 @@ La aplicación de escritorio incluye estas potentes funciones:
|
||||||
|
|
||||||
- **General Assist** — Lanza SurfSense al instante desde cualquier aplicación con un atajo global.
|
- **General Assist** — Lanza SurfSense al instante desde cualquier aplicación con un atajo global.
|
||||||
- **Quick Assist** — Selecciona texto en cualquier lugar, luego pide a la IA que lo explique, reescriba o actúe sobre él.
|
- **Quick Assist** — Selecciona texto en cualquier lugar, luego pide a la IA que lo explique, reescriba o actúe sobre él.
|
||||||
- **Extreme Assist** — Obtén sugerencias de escritura en línea impulsadas por tu base de conocimiento mientras escribes en cualquier aplicación.
|
- **Screenshot Assist** — Selecciona una región de tu pantalla y adjúntala al chat para que las respuestas se basen en tu base de conocimiento.
|
||||||
- **Watch Local Folder** — Vigila una carpeta local y sincroniza automáticamente los cambios de archivos con tu base de conocimiento. **Pro tip:** Apúntalo a tu bóveda de Obsidian para mantener tus notas buscables en SurfSense.
|
- **Watch Local Folder** — Vigila una carpeta local y sincroniza automáticamente los cambios de archivos con tu base de conocimiento. **Pro tip:** Apúntalo a tu bóveda de Obsidian para mantener tus notas buscables en SurfSense.
|
||||||
|
|
||||||
Todas las funciones operan contra tu espacio de búsqueda elegido, por lo que tus respuestas siempre están basadas en tus propios datos.
|
Todas las funciones operan contra tu espacio de búsqueda elegido, por lo que tus respuestas siempre están basadas en tus propios datos.
|
||||||
|
|
@ -199,14 +199,14 @@ Todas las funciones operan contra tu espacio de búsqueda elegido, por lo que tu
|
||||||
| **Generación de Videos** | Resúmenes en video cinemáticos vía Veo 3 (solo Ultra) | Disponible (NotebookLM es mejor aquí, mejorando activamente) |
|
| **Generación de Videos** | Resúmenes en video cinemáticos vía Veo 3 (solo Ultra) | Disponible (NotebookLM es mejor aquí, mejorando activamente) |
|
||||||
| **Generación de Presentaciones** | Diapositivas más atractivas pero no editables | Crea presentaciones editables basadas en diapositivas |
|
| **Generación de Presentaciones** | Diapositivas más atractivas pero no editables | Crea presentaciones editables basadas en diapositivas |
|
||||||
| **Generación de Podcasts** | Resúmenes de audio con hosts e idiomas personalizables | Disponible con múltiples proveedores TTS (NotebookLM es mejor aquí, mejorando activamente) |
|
| **Generación de Podcasts** | Resúmenes de audio con hosts e idiomas personalizables | Disponible con múltiples proveedores TTS (NotebookLM es mejor aquí, mejorando activamente) |
|
||||||
| **Aplicación de Escritorio** | No | Aplicación nativa con General Assist, Quick Assist, Extreme Assist y sincronización de carpetas locales |
|
| **Aplicación de Escritorio** | No | Aplicación nativa con General Assist, Quick Assist, Screenshot Assist y sincronización de carpetas locales |
|
||||||
| **Extensión de Navegador** | No | Extensión multi-navegador para guardar cualquier página web, incluyendo páginas protegidas por autenticación |
|
| **Extensión de Navegador** | No | Extensión multi-navegador para guardar cualquier página web, incluyendo páginas protegidas por autenticación |
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>Lista completa de Fuentes Externas</b></summary>
|
<summary><b>Lista completa de Fuentes Externas</b></summary>
|
||||||
<a id="fuentes-externas"></a>
|
<a id="fuentes-externas"></a>
|
||||||
|
|
||||||
Motores de Búsqueda (Tavily, LinkUp) · SearxNG · Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · Videos de YouTube · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, y más por venir.
|
Motores de Búsqueda (SearXNG, Tavily, LinkUp, Baidu Search) · Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · Videos de YouTube · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, y más por venir.
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
|
||||||
12
README.hi.md
12
README.hi.md
|
|
@ -41,7 +41,7 @@ NotebookLM वहाँ उपलब्ध सबसे अच्छे और
|
||||||
- **कोई विक्रेता लॉक-इन नहीं** - किसी भी LLM, इमेज, TTS और STT मॉडल को कॉन्फ़िगर करें।
|
- **कोई विक्रेता लॉक-इन नहीं** - किसी भी LLM, इमेज, TTS और STT मॉडल को कॉन्फ़िगर करें।
|
||||||
- **25+ बाहरी डेटा स्रोत** - Google Drive, OneDrive, Dropbox, Notion और कई अन्य बाहरी सेवाओं से अपने स्रोत जोड़ें।
|
- **25+ बाहरी डेटा स्रोत** - Google Drive, OneDrive, Dropbox, Notion और कई अन्य बाहरी सेवाओं से अपने स्रोत जोड़ें।
|
||||||
- **रीयल-टाइम मल्टीप्लेयर सपोर्ट** - एक साझा notebook में अपनी टीम के सदस्यों के साथ आसानी से काम करें।
|
- **रीयल-टाइम मल्टीप्लेयर सपोर्ट** - एक साझा notebook में अपनी टीम के सदस्यों के साथ आसानी से काम करें।
|
||||||
- **डेस्कटॉप ऐप** - Quick Assist, General Assist, Extreme Assist और लोकल फ़ोल्डर सिंक के साथ किसी भी एप्लिकेशन में AI सहायता प्राप्त करें।
|
- **डेस्कटॉप ऐप** - Quick Assist, General Assist, Screenshot Assist और लोकल फ़ोल्डर सिंक के साथ किसी भी एप्लिकेशन में AI सहायता प्राप्त करें।
|
||||||
|
|
||||||
...और भी बहुत कुछ आने वाला है।
|
...और भी बहुत कुछ आने वाला है।
|
||||||
|
|
||||||
|
|
@ -84,9 +84,9 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
||||||
|
|
||||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>
|
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>
|
||||||
|
|
||||||
- डेस्कटॉप ऐप — Extreme Assist
|
- डेस्कटॉप ऐप — Screenshot Assist
|
||||||
|
|
||||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/extreme_assist.gif" alt="Extreme Assist" /></p>
|
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/screenshot_assist.gif" alt="Screenshot Assist" /></p>
|
||||||
|
|
||||||
- डेस्कटॉप ऐप — Watch Local Folder
|
- डेस्कटॉप ऐप — Watch Local Folder
|
||||||
|
|
||||||
|
|
@ -150,7 +150,7 @@ SurfSense एक डेस्कटॉप ऐप भी प्रदान क
|
||||||
|
|
||||||
- **General Assist** — एक ग्लोबल शॉर्टकट से किसी भी एप्लिकेशन से तुरंत SurfSense लॉन्च करें।
|
- **General Assist** — एक ग्लोबल शॉर्टकट से किसी भी एप्लिकेशन से तुरंत SurfSense लॉन्च करें।
|
||||||
- **Quick Assist** — कहीं भी टेक्स्ट चुनें, फिर AI से समझाने, फिर से लिखने या उस पर कार्रवाई करने को कहें।
|
- **Quick Assist** — कहीं भी टेक्स्ट चुनें, फिर AI से समझाने, फिर से लिखने या उस पर कार्रवाई करने को कहें।
|
||||||
- **Extreme Assist** — किसी भी ऐप में टाइप करते समय अपनी नॉलेज बेस से संचालित इनलाइन लेखन सुझाव प्राप्त करें।
|
- **Screenshot Assist** — स्क्रीन पर एक क्षेत्र चुनें और उसे चैट में जोड़ें, ताकि उत्तर आपकी नॉलेज बेस पर आधारित रहें।
|
||||||
- **Watch Local Folder** — एक लोकल फ़ोल्डर को वॉच करें और फ़ाइल परिवर्तनों को स्वचालित रूप से अपनी नॉलेज बेस में सिंक करें। **Pro tip:** इसे अपने Obsidian vault पर पॉइंट करें ताकि आपके नोट्स SurfSense में सर्च करने योग्य रहें।
|
- **Watch Local Folder** — एक लोकल फ़ोल्डर को वॉच करें और फ़ाइल परिवर्तनों को स्वचालित रूप से अपनी नॉलेज बेस में सिंक करें। **Pro tip:** इसे अपने Obsidian vault पर पॉइंट करें ताकि आपके नोट्स SurfSense में सर्च करने योग्य रहें।
|
||||||
|
|
||||||
सभी सुविधाएं आपके चुने हुए सर्च स्पेस पर काम करती हैं, ताकि आपके उत्तर हमेशा आपके अपने डेटा पर आधारित हों।
|
सभी सुविधाएं आपके चुने हुए सर्च स्पेस पर काम करती हैं, ताकि आपके उत्तर हमेशा आपके अपने डेटा पर आधारित हों।
|
||||||
|
|
@ -199,14 +199,14 @@ SurfSense एक डेस्कटॉप ऐप भी प्रदान क
|
||||||
| **वीडियो जनरेशन** | Veo 3 के माध्यम से सिनेमैटिक वीडियो ओवरव्यू (केवल Ultra) | उपलब्ध (NotebookLM यहाँ बेहतर है, सक्रिय रूप से सुधार हो रहा है) |
|
| **वीडियो जनरेशन** | Veo 3 के माध्यम से सिनेमैटिक वीडियो ओवरव्यू (केवल Ultra) | उपलब्ध (NotebookLM यहाँ बेहतर है, सक्रिय रूप से सुधार हो रहा है) |
|
||||||
| **प्रेजेंटेशन जनरेशन** | बेहतर दिखने वाली स्लाइड्स लेकिन संपादन योग्य नहीं | संपादन योग्य, स्लाइड आधारित प्रेजेंटेशन बनाएं |
|
| **प्रेजेंटेशन जनरेशन** | बेहतर दिखने वाली स्लाइड्स लेकिन संपादन योग्य नहीं | संपादन योग्य, स्लाइड आधारित प्रेजेंटेशन बनाएं |
|
||||||
| **पॉडकास्ट जनरेशन** | कस्टमाइज़ेबल होस्ट और भाषाओं के साथ ऑडियो ओवरव्यू | कई TTS प्रदाताओं के साथ उपलब्ध (NotebookLM यहाँ बेहतर है, सक्रिय रूप से सुधार हो रहा है) |
|
| **पॉडकास्ट जनरेशन** | कस्टमाइज़ेबल होस्ट और भाषाओं के साथ ऑडियो ओवरव्यू | कई TTS प्रदाताओं के साथ उपलब्ध (NotebookLM यहाँ बेहतर है, सक्रिय रूप से सुधार हो रहा है) |
|
||||||
| **डेस्कटॉप ऐप** | नहीं | General Assist, Quick Assist, Extreme Assist और लोकल फ़ोल्डर सिंक के साथ नेटिव ऐप |
|
| **डेस्कटॉप ऐप** | नहीं | General Assist, Quick Assist, Screenshot Assist और लोकल फ़ोल्डर सिंक के साथ नेटिव ऐप |
|
||||||
| **ब्राउज़र एक्सटेंशन** | नहीं | किसी भी वेबपेज को सहेजने के लिए क्रॉस-ब्राउज़र एक्सटेंशन, प्रमाणीकरण सुरक्षित पेज सहित |
|
| **ब्राउज़र एक्सटेंशन** | नहीं | किसी भी वेबपेज को सहेजने के लिए क्रॉस-ब्राउज़र एक्सटेंशन, प्रमाणीकरण सुरक्षित पेज सहित |
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>बाहरी स्रोतों की पूरी सूची</b></summary>
|
<summary><b>बाहरी स्रोतों की पूरी सूची</b></summary>
|
||||||
<a id="बाहरी-स्रोत"></a>
|
<a id="बाहरी-स्रोत"></a>
|
||||||
|
|
||||||
सर्च इंजन (Tavily, LinkUp) · SearxNG · Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · YouTube वीडियो · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, और भी बहुत कुछ आने वाला है।
|
सर्च इंजन (SearXNG, Tavily, LinkUp, Baidu Search) · Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · YouTube वीडियो · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, और भी बहुत कुछ आने वाला है।
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
|
||||||
12
README.md
12
README.md
|
|
@ -42,7 +42,7 @@ NotebookLM is one of the best and most useful AI platforms out there, but once y
|
||||||
- **25+ External Data Sources** - Add your sources from Google Drive, OneDrive, Dropbox, Notion, and many other external services.
|
- **25+ External Data Sources** - Add your sources from Google Drive, OneDrive, Dropbox, Notion, and many other external services.
|
||||||
- **Real-Time Multiplayer Support** - Work easily with your team members in a shared notebook.
|
- **Real-Time Multiplayer Support** - Work easily with your team members in a shared notebook.
|
||||||
- **AI File Sorting** - Automatically organize your documents into a smart folder hierarchy using AI-powered categorization by source, date, and topic.
|
- **AI File Sorting** - Automatically organize your documents into a smart folder hierarchy using AI-powered categorization by source, date, and topic.
|
||||||
- **Desktop App** - Get AI assistance in any application with Quick Assist, General Assist, Extreme Assist, and local folder sync.
|
- **Desktop App** - Get AI assistance in any application with Quick Assist, General Assist, Screenshot Assist, and local folder sync.
|
||||||
|
|
||||||
...and more to come.
|
...and more to come.
|
||||||
|
|
||||||
|
|
@ -85,9 +85,9 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
||||||
|
|
||||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>
|
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>
|
||||||
|
|
||||||
- Desktop App — Extreme Assist
|
- Desktop App — Screenshot Assist
|
||||||
|
|
||||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/extreme_assist.gif" alt="Extreme Assist" /></p>
|
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/screenshot_assist.gif" alt="Screenshot Assist" /></p>
|
||||||
|
|
||||||
- Desktop App — Watch Local Folder
|
- Desktop App — Watch Local Folder
|
||||||
|
|
||||||
|
|
@ -151,7 +151,7 @@ The desktop app includes these powerful features:
|
||||||
|
|
||||||
- **General Assist** — Launch SurfSense instantly from any application with a global shortcut.
|
- **General Assist** — Launch SurfSense instantly from any application with a global shortcut.
|
||||||
- **Quick Assist** — Select text anywhere, then ask AI to explain, rewrite, or act on it.
|
- **Quick Assist** — Select text anywhere, then ask AI to explain, rewrite, or act on it.
|
||||||
- **Extreme Assist** — Get inline writing suggestions powered by your knowledge base as you type in any app.
|
- **Screenshot Assist** — Select a region on your screen and attach it to chat so answers stay grounded in your knowledge base.
|
||||||
- **Watch Local Folder** — Watch a local folder and automatically sync file changes to your knowledge base. **Pro tip:** Point it at your Obsidian vault to keep your notes searchable in SurfSense.
|
- **Watch Local Folder** — Watch a local folder and automatically sync file changes to your knowledge base. **Pro tip:** Point it at your Obsidian vault to keep your notes searchable in SurfSense.
|
||||||
|
|
||||||
All features operate against your chosen search space, so your answers are always grounded in your own data.
|
All features operate against your chosen search space, so your answers are always grounded in your own data.
|
||||||
|
|
@ -201,14 +201,14 @@ All features operate against your chosen search space, so your answers are alway
|
||||||
| **Presentation Generation** | Better looking slides but not editable | Create editable, slide-based presentations |
|
| **Presentation Generation** | Better looking slides but not editable | Create editable, slide-based presentations |
|
||||||
| **Podcast Generation** | Audio Overviews with customizable hosts and languages | Available with multiple TTS providers (NotebookLM is better here, actively improving) |
|
| **Podcast Generation** | Audio Overviews with customizable hosts and languages | Available with multiple TTS providers (NotebookLM is better here, actively improving) |
|
||||||
| **AI File Sorting** | No | LLM-powered auto-categorization into source, date, category, and subcategory folders |
|
| **AI File Sorting** | No | LLM-powered auto-categorization into source, date, category, and subcategory folders |
|
||||||
| **Desktop App** | No | Native app with General Assist, Quick Assist, Extreme Assist, and local folder sync |
|
| **Desktop App** | No | Native app with General Assist, Quick Assist, Screenshot Assist, and local folder sync |
|
||||||
| **Browser Extension** | No | Cross-browser extension to save any webpage, including auth-protected pages |
|
| **Browser Extension** | No | Cross-browser extension to save any webpage, including auth-protected pages |
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>Full list of External Sources</b></summary>
|
<summary><b>Full list of External Sources</b></summary>
|
||||||
<a id="external-sources"></a>
|
<a id="external-sources"></a>
|
||||||
|
|
||||||
Search Engines (Tavily, LinkUp) · SearxNG · Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · YouTube Videos · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, and more to come.
|
Search Engines (SearXNG, Tavily, LinkUp, Baidu Search) · Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · YouTube Videos · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, and more to come.
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ O NotebookLM é uma das melhores e mais úteis plataformas de IA disponíveis, m
|
||||||
- **Sem Dependência de Fornecedor** - Configure qualquer modelo LLM, de imagem, TTS e STT.
|
- **Sem Dependência de Fornecedor** - Configure qualquer modelo LLM, de imagem, TTS e STT.
|
||||||
- **25+ Fontes de Dados Externas** - Adicione suas fontes do Google Drive, OneDrive, Dropbox, Notion e muitos outros serviços externos.
|
- **25+ Fontes de Dados Externas** - Adicione suas fontes do Google Drive, OneDrive, Dropbox, Notion e muitos outros serviços externos.
|
||||||
- **Suporte Multiplayer em Tempo Real** - Trabalhe facilmente com os membros da sua equipe em um notebook compartilhado.
|
- **Suporte Multiplayer em Tempo Real** - Trabalhe facilmente com os membros da sua equipe em um notebook compartilhado.
|
||||||
- **Aplicativo Desktop** - Obtenha assistência de IA em qualquer aplicativo com Quick Assist, General Assist, Extreme Assist e sincronização de pastas locais.
|
- **Aplicativo Desktop** - Obtenha assistência de IA em qualquer aplicativo com Quick Assist, General Assist, Screenshot Assist e sincronização de pastas locais.
|
||||||
|
|
||||||
...e mais por vir.
|
...e mais por vir.
|
||||||
|
|
||||||
|
|
@ -84,9 +84,9 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
||||||
|
|
||||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>
|
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>
|
||||||
|
|
||||||
- Aplicativo Desktop — Extreme Assist
|
- Aplicativo Desktop — Screenshot Assist
|
||||||
|
|
||||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/extreme_assist.gif" alt="Extreme Assist" /></p>
|
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/screenshot_assist.gif" alt="Screenshot Assist" /></p>
|
||||||
|
|
||||||
- Aplicativo Desktop — Watch Local Folder
|
- Aplicativo Desktop — Watch Local Folder
|
||||||
|
|
||||||
|
|
@ -150,7 +150,7 @@ O aplicativo desktop inclui estes recursos poderosos:
|
||||||
|
|
||||||
- **General Assist** — Abra o SurfSense instantaneamente de qualquer aplicativo com um atalho global.
|
- **General Assist** — Abra o SurfSense instantaneamente de qualquer aplicativo com um atalho global.
|
||||||
- **Quick Assist** — Selecione texto em qualquer lugar, depois peça à IA para explicar, reescrever ou agir sobre ele.
|
- **Quick Assist** — Selecione texto em qualquer lugar, depois peça à IA para explicar, reescrever ou agir sobre ele.
|
||||||
- **Extreme Assist** — Receba sugestões de escrita em linha alimentadas pela sua base de conhecimento enquanto digita em qualquer aplicativo.
|
- **Screenshot Assist** — Selecione uma região da tela e anexe ao chat para respostas fundamentadas na sua base de conhecimento.
|
||||||
- **Watch Local Folder** — Monitore uma pasta local e sincronize automaticamente as alterações de arquivos com sua base de conhecimento. **Pro tip:** Aponte para seu cofre do Obsidian para manter suas notas pesquisáveis no SurfSense.
|
- **Watch Local Folder** — Monitore uma pasta local e sincronize automaticamente as alterações de arquivos com sua base de conhecimento. **Pro tip:** Aponte para seu cofre do Obsidian para manter suas notas pesquisáveis no SurfSense.
|
||||||
|
|
||||||
Todos os recursos operam no espaço de busca escolhido, para que suas respostas sejam sempre baseadas nos seus próprios dados.
|
Todos os recursos operam no espaço de busca escolhido, para que suas respostas sejam sempre baseadas nos seus próprios dados.
|
||||||
|
|
@ -199,14 +199,14 @@ Todos os recursos operam no espaço de busca escolhido, para que suas respostas
|
||||||
| **Geração de Vídeos** | Visões gerais cinemáticas via Veo 3 (apenas Ultra) | Disponível (NotebookLM é melhor aqui, melhorando ativamente) |
|
| **Geração de Vídeos** | Visões gerais cinemáticas via Veo 3 (apenas Ultra) | Disponível (NotebookLM é melhor aqui, melhorando ativamente) |
|
||||||
| **Geração de Apresentações** | Slides mais bonitos mas não editáveis | Cria apresentações editáveis baseadas em slides |
|
| **Geração de Apresentações** | Slides mais bonitos mas não editáveis | Cria apresentações editáveis baseadas em slides |
|
||||||
| **Geração de Podcasts** | Visões gerais em áudio com hosts e idiomas personalizáveis | Disponível com múltiplos provedores TTS (NotebookLM é melhor aqui, melhorando ativamente) |
|
| **Geração de Podcasts** | Visões gerais em áudio com hosts e idiomas personalizáveis | Disponível com múltiplos provedores TTS (NotebookLM é melhor aqui, melhorando ativamente) |
|
||||||
| **Aplicativo Desktop** | Não | Aplicativo nativo com General Assist, Quick Assist, Extreme Assist e sincronização de pastas locais |
|
| **Aplicativo Desktop** | Não | Aplicativo nativo com General Assist, Quick Assist, Screenshot Assist e sincronização de pastas locais |
|
||||||
| **Extensão de Navegador** | Não | Extensão multi-navegador para salvar qualquer página web, incluindo páginas protegidas por autenticação |
|
| **Extensão de Navegador** | Não | Extensão multi-navegador para salvar qualquer página web, incluindo páginas protegidas por autenticação |
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>Lista completa de Fontes Externas</b></summary>
|
<summary><b>Lista completa de Fontes Externas</b></summary>
|
||||||
<a id="fontes-externas"></a>
|
<a id="fontes-externas"></a>
|
||||||
|
|
||||||
Mecanismos de Busca (Tavily, LinkUp) · SearxNG · Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · Vídeos do YouTube · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, e mais por vir.
|
Mecanismos de Busca (SearXNG, Tavily, LinkUp, Baidu Search) · Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · Vídeos do YouTube · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, e mais por vir.
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ NotebookLM 是目前最好、最实用的 AI 平台之一,但当你开始经
|
||||||
- **无供应商锁定** - 配置任何 LLM、图像、TTS 和 STT 模型。
|
- **无供应商锁定** - 配置任何 LLM、图像、TTS 和 STT 模型。
|
||||||
- **25+ 外部数据源** - 从 Google Drive、OneDrive、Dropbox、Notion 和许多其他外部服务添加你的来源。
|
- **25+ 外部数据源** - 从 Google Drive、OneDrive、Dropbox、Notion 和许多其他外部服务添加你的来源。
|
||||||
- **实时多人协作支持** - 在共享笔记本中轻松与团队成员协作。
|
- **实时多人协作支持** - 在共享笔记本中轻松与团队成员协作。
|
||||||
- **桌面应用** - 通过 Quick Assist、General Assist、Extreme Assist 和本地文件夹同步在任何应用程序中获得 AI 助手。
|
- **桌面应用** - 通过 Quick Assist、General Assist、Screenshot Assist 和本地文件夹同步在任何应用程序中获得 AI 助手。
|
||||||
|
|
||||||
...更多功能即将推出。
|
...更多功能即将推出。
|
||||||
|
|
||||||
|
|
@ -84,9 +84,9 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
||||||
|
|
||||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>
|
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/quick_assist.gif" alt="Quick Assist" /></p>
|
||||||
|
|
||||||
- 桌面应用 — Extreme Assist
|
- 桌面应用 — Screenshot Assist
|
||||||
|
|
||||||
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/extreme_assist.gif" alt="Extreme Assist" /></p>
|
<p align="center"><img src="surfsense_web/public/homepage/hero_tutorial/screenshot_assist.gif" alt="Screenshot Assist" /></p>
|
||||||
|
|
||||||
- 桌面应用 — Watch Local Folder
|
- 桌面应用 — Watch Local Folder
|
||||||
|
|
||||||
|
|
@ -150,7 +150,7 @@ SurfSense 还提供桌面应用,将 AI 助手带到您计算机上的每个应
|
||||||
|
|
||||||
- **General Assist** — 通过全局快捷键从任何应用程序即时启动 SurfSense。
|
- **General Assist** — 通过全局快捷键从任何应用程序即时启动 SurfSense。
|
||||||
- **Quick Assist** — 在任何位置选中文本,然后让 AI 解释、改写或对其执行操作。
|
- **Quick Assist** — 在任何位置选中文本,然后让 AI 解释、改写或对其执行操作。
|
||||||
- **Extreme Assist** — 在任何应用中输入时,获得基于您知识库的内联写作建议。
|
- **Screenshot Assist** — 在屏幕上框选区域并附加到聊天,让回复基于您的知识库。
|
||||||
- **Watch Local Folder** — 监视本地文件夹,自动将文件更改同步到您的知识库。**Pro tip:** 将其指向您的 Obsidian vault,让笔记在 SurfSense 中随时可搜索。
|
- **Watch Local Folder** — 监视本地文件夹,自动将文件更改同步到您的知识库。**Pro tip:** 将其指向您的 Obsidian vault,让笔记在 SurfSense 中随时可搜索。
|
||||||
|
|
||||||
所有功能均基于您选择的搜索空间运行,确保回答始终以您自己的数据为依据。
|
所有功能均基于您选择的搜索空间运行,确保回答始终以您自己的数据为依据。
|
||||||
|
|
@ -199,14 +199,14 @@ SurfSense 还提供桌面应用,将 AI 助手带到您计算机上的每个应
|
||||||
| **视频生成** | 通过 Veo 3 的电影级视频概览(仅 Ultra) | 可用(NotebookLM 在此方面更好,正在积极改进) |
|
| **视频生成** | 通过 Veo 3 的电影级视频概览(仅 Ultra) | 可用(NotebookLM 在此方面更好,正在积极改进) |
|
||||||
| **演示文稿生成** | 更美观的幻灯片但不可编辑 | 创建可编辑的幻灯片式演示文稿 |
|
| **演示文稿生成** | 更美观的幻灯片但不可编辑 | 创建可编辑的幻灯片式演示文稿 |
|
||||||
| **播客生成** | 可自定义主持人和语言的音频概览 | 可用,支持多种 TTS 提供商(NotebookLM 在此方面更好,正在积极改进) |
|
| **播客生成** | 可自定义主持人和语言的音频概览 | 可用,支持多种 TTS 提供商(NotebookLM 在此方面更好,正在积极改进) |
|
||||||
| **桌面应用** | 否 | 原生应用,包含 General Assist、Quick Assist、Extreme Assist 和本地文件夹同步 |
|
| **桌面应用** | 否 | 原生应用,包含 General Assist、Quick Assist、Screenshot Assist 和本地文件夹同步 |
|
||||||
| **浏览器扩展** | 否 | 跨浏览器扩展,保存任何网页,包括需要身份验证的页面 |
|
| **浏览器扩展** | 否 | 跨浏览器扩展,保存任何网页,包括需要身份验证的页面 |
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>外部数据源完整列表</b></summary>
|
<summary><b>外部数据源完整列表</b></summary>
|
||||||
<a id="外部数据源"></a>
|
<a id="外部数据源"></a>
|
||||||
|
|
||||||
搜索引擎(Tavily、LinkUp)· SearxNG · Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · YouTube 视频 · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian,更多即将推出。
|
搜索引擎(SearXNG、Tavily、LinkUp、Baidu Search)· Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · YouTube 视频 · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian,更多即将推出。
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
|
||||||
2
VERSION
2
VERSION
|
|
@ -1 +1 @@
|
||||||
0.0.19
|
0.0.20
|
||||||
|
|
|
||||||
|
|
@ -159,10 +159,13 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE
|
||||||
# STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
# STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
||||||
# STRIPE_RECONCILIATION_BATCH_SIZE=100
|
# STRIPE_RECONCILIATION_BATCH_SIZE=100
|
||||||
|
|
||||||
# Premium token purchases ($1 per 1M tokens for premium-tier models)
|
# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of
|
||||||
|
# credit; premium turns debit the actual per-call provider cost
|
||||||
|
# reported by LiteLLM, so cheap and expensive models bill proportionally)
|
||||||
# STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
# STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
||||||
# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
||||||
# STRIPE_TOKENS_PER_UNIT=1000000
|
# STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||||
|
# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping):
# STRIPE_TOKENS_PER_UNIT=1000000
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------
|
||||||
# TTS & STT (Text-to-Speech / Speech-to-Text)
|
# TTS & STT (Text-to-Speech / Speech-to-Text)
|
||||||
|
|
@ -305,6 +308,24 @@ STT_SERVICE=local/base
|
||||||
# Advanced (optional)
|
# Advanced (optional)
|
||||||
# ------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# New-chat agent feature flags
|
||||||
|
SURFSENSE_ENABLE_CONTEXT_EDITING=true
|
||||||
|
SURFSENSE_ENABLE_COMPACTION_V2=true
|
||||||
|
SURFSENSE_ENABLE_RETRY_AFTER=true
|
||||||
|
SURFSENSE_ENABLE_MODEL_FALLBACK=false
|
||||||
|
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
|
||||||
|
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
|
||||||
|
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
|
||||||
|
SURFSENSE_ENABLE_BUSY_MUTEX=true
|
||||||
|
SURFSENSE_ENABLE_SKILLS=true
|
||||||
|
SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=true
|
||||||
|
SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=true
|
||||||
|
SURFSENSE_ENABLE_ACTION_LOG=true
|
||||||
|
SURFSENSE_ENABLE_REVERT_ROUTE=true
|
||||||
|
SURFSENSE_ENABLE_PERMISSION=true
|
||||||
|
SURFSENSE_ENABLE_DOOM_LOOP=true
|
||||||
|
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
|
||||||
|
|
||||||
# Periodic connector sync interval (default: 5m)
|
# Periodic connector sync interval (default: 5m)
|
||||||
# SCHEDULE_CHECKER_INTERVAL=5m
|
# SCHEDULE_CHECKER_INTERVAL=5m
|
||||||
|
|
||||||
|
|
@ -315,9 +336,24 @@ STT_SERVICE=local/base
|
||||||
# Pages limit per user for ETL (default: unlimited)
|
# Pages limit per user for ETL (default: unlimited)
|
||||||
# PAGES_LIMIT=500
|
# PAGES_LIMIT=500
|
||||||
|
|
||||||
# Premium token quota per registered user (default: 5M)
|
# Premium credit quota per registered user, in micro-USD (default: $5).
|
||||||
# Only applies to models with billing_tier=premium in global_llm_config.yaml
|
# Premium turns are debited at the actual per-call provider cost reported
|
||||||
# PREMIUM_TOKEN_LIMIT=5000000
|
# by LiteLLM. Only applies to models with billing_tier=premium.
|
||||||
|
# PREMIUM_CREDIT_MICROS_LIMIT=5000000
|
||||||
|
# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping):
# PREMIUM_TOKEN_LIMIT=5000000
|
||||||
|
|
||||||
|
# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default).
|
||||||
|
# QUOTA_MAX_RESERVE_MICROS=1000000
|
||||||
|
|
||||||
|
# Per-image reservation for POST /image-generations, in micro-USD ($0.05 default).
|
||||||
|
# QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
|
||||||
|
|
||||||
|
# Per-podcast reservation for the podcast Celery task, in micro-USD ($0.20 default).
|
||||||
|
# QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
|
||||||
|
|
||||||
|
# Per-video-presentation reservation for the video Celery task, in micro-USD ($1.00 default).
|
||||||
|
# NOTE: tasks using the override path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
|
||||||
|
# QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
|
||||||
|
|
||||||
# No-login (anonymous) mode — public users can chat without an account
|
# No-login (anonymous) mode — public users can chat without an account
|
||||||
# Set TRUE to enable /free pages and anonymous chat API
|
# Set TRUE to enable /free pages and anonymous chat API
|
||||||
|
|
|
||||||
123
docker/docker-compose.deps-only.yml
Normal file
123
docker/docker-compose.deps-only.yml
Normal file
|
|
@ -0,0 +1,123 @@
|
||||||
|
# =============================================================================
|
||||||
|
# SurfSense — Dependencies only (no backend / frontend / Celery images)
|
||||||
|
# =============================================================================
|
||||||
|
# Postgres, Redis, SearXNG, pgAdmin, Zero — run API + Next + Celery on the host.
|
||||||
|
# Celery is not Dockerized here: use `uv run` from surfsense_backend/ (no extra
|
||||||
|
# backend image build just for workers).
|
||||||
|
#
|
||||||
|
# From repo root (SurfSense/):
|
||||||
|
# docker compose -f docker/docker-compose.deps-only.yml up -d
|
||||||
|
#
|
||||||
|
# Compose variable substitution uses `docker/.env` (copy from .env.example).
|
||||||
|
# Bind mounts use ./postgresql.conf and ./searxng in this directory.
|
||||||
|
#
|
||||||
|
# Local Celery (from surfsense_backend/, after Redis is up):
|
||||||
|
# uv run celery -A celery_worker.celery_app worker --loglevel=info --concurrency=1 --pool=solo --queues=surfsense,surfsense.connectors
|
||||||
|
# uv run celery -A celery_worker.celery_app beat --loglevel=info
|
||||||
|
#
|
||||||
|
# Host setup:
|
||||||
|
# - Backend .env: DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense
|
||||||
|
# - Backend .env: SEARXNG_DEFAULT_HOST=http://localhost:${SEARXNG_PORT:-8888}
|
||||||
|
# - Backend .env: CELERY_BROKER_URL / REDIS_APP_URL → redis://localhost:6379/0
|
||||||
|
# - Web .env: NEXT_PUBLIC_ZERO_CACHE_URL=http://localhost:${ZERO_CACHE_PORT:-4848}
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
name: surfsense-deps
|
||||||
|
|
||||||
|
services:
|
||||||
|
db:
|
||||||
|
image: pgvector/pgvector:pg17
|
||||||
|
ports:
|
||||||
|
- "${POSTGRES_PORT:-5432}:5432"
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
- ./postgresql.conf:/etc/postgresql/postgresql.conf:ro
|
||||||
|
environment:
|
||||||
|
- POSTGRES_USER=${DB_USER:-postgres}
|
||||||
|
- POSTGRES_PASSWORD=${DB_PASSWORD:-postgres}
|
||||||
|
- POSTGRES_DB=${DB_NAME:-surfsense}
|
||||||
|
command: postgres -c config_file=/etc/postgresql/postgresql.conf
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U ${DB_USER:-postgres} -d ${DB_NAME:-surfsense}"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
|
||||||
|
pgadmin:
|
||||||
|
image: dpage/pgadmin4
|
||||||
|
ports:
|
||||||
|
- "${PGADMIN_PORT:-5050}:80"
|
||||||
|
environment:
|
||||||
|
- PGADMIN_DEFAULT_EMAIL=${PGADMIN_DEFAULT_EMAIL:-admin@surfsense.com}
|
||||||
|
- PGADMIN_DEFAULT_PASSWORD=${PGADMIN_DEFAULT_PASSWORD:-surfsense}
|
||||||
|
volumes:
|
||||||
|
- pgadmin_data:/var/lib/pgadmin
|
||||||
|
depends_on:
|
||||||
|
- db
|
||||||
|
|
||||||
|
redis:
|
||||||
|
image: redis:8-alpine
|
||||||
|
ports:
|
||||||
|
- "${REDIS_PORT:-6379}:6379"
|
||||||
|
volumes:
|
||||||
|
- redis_data:/data
|
||||||
|
command: redis-server --appendonly yes
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "redis-cli", "ping"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
|
||||||
|
searxng:
|
||||||
|
image: searxng/searxng:2026.3.13-3c1f68c59
|
||||||
|
ports:
|
||||||
|
- "${SEARXNG_PORT:-8888}:8080"
|
||||||
|
volumes:
|
||||||
|
- ./searxng:/etc/searxng
|
||||||
|
environment:
|
||||||
|
- SEARXNG_SECRET=${SEARXNG_SECRET:-surfsense-searxng-secret}
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "wget", "--spider", "-q", "http://localhost:8080/healthz"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
|
||||||
|
zero-cache:
|
||||||
|
image: rocicorp/zero:0.26.2
|
||||||
|
ports:
|
||||||
|
- "${ZERO_CACHE_PORT:-4848}:4848"
|
||||||
|
extra_hosts:
|
||||||
|
- "host.docker.internal:host-gateway"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
environment:
|
||||||
|
- ZERO_UPSTREAM_DB=postgresql://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@db:5432/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}
|
||||||
|
- ZERO_CVR_DB=postgresql://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@db:5432/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}
|
||||||
|
- ZERO_CHANGE_DB=postgresql://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@db:5432/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}
|
||||||
|
- ZERO_REPLICA_FILE=/data/zero.db
|
||||||
|
- ZERO_ADMIN_PASSWORD=${ZERO_ADMIN_PASSWORD:-surfsense-zero-admin}
|
||||||
|
- ZERO_APP_PUBLICATIONS=${ZERO_APP_PUBLICATIONS:-zero_publication}
|
||||||
|
- ZERO_NUM_SYNC_WORKERS=${ZERO_NUM_SYNC_WORKERS:-4}
|
||||||
|
- ZERO_UPSTREAM_MAX_CONNS=${ZERO_UPSTREAM_MAX_CONNS:-20}
|
||||||
|
- ZERO_CVR_MAX_CONNS=${ZERO_CVR_MAX_CONNS:-30}
|
||||||
|
- ZERO_QUERY_URL=${ZERO_QUERY_URL:-http://host.docker.internal:3000/api/zero/query}
|
||||||
|
- ZERO_MUTATE_URL=${ZERO_MUTATE_URL:-http://host.docker.internal:3000/api/zero/mutate}
|
||||||
|
volumes:
|
||||||
|
- zero_cache_data:/data
|
||||||
|
restart: unless-stopped
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:4848/keepalive"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
name: surfsense-deps-postgres
|
||||||
|
pgadmin_data:
|
||||||
|
name: surfsense-deps-pgadmin
|
||||||
|
redis_data:
|
||||||
|
name: surfsense-deps-redis
|
||||||
|
zero_cache_data:
|
||||||
|
name: surfsense-deps-zero-cache
|
||||||
10
manifest.json
Normal file
10
manifest.json
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
{
|
||||||
|
"id": "surfsense-obsidian",
|
||||||
|
"name": "SurfSense",
|
||||||
|
"version": "0.1.0",
|
||||||
|
"minAppVersion": "1.5.4",
|
||||||
|
"description": "Turn your vault into a searchable second brain with SurfSense.",
|
||||||
|
"author": "SurfSense",
|
||||||
|
"authorUrl": "https://www.surfsense.com",
|
||||||
|
"isDesktopOnly": false
|
||||||
|
}
|
||||||
|
|
@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000
|
||||||
# Set FALSE to disable new checkout session creation temporarily
|
# Set FALSE to disable new checkout session creation temporarily
|
||||||
STRIPE_PAGE_BUYING_ENABLED=TRUE
|
STRIPE_PAGE_BUYING_ENABLED=TRUE
|
||||||
|
|
||||||
# Premium token purchases via Stripe (for premium-tier model usage)
|
# Premium credit purchases via Stripe (for premium-tier model usage).
|
||||||
# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens)
|
# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
|
||||||
|
# (default 1_000_000 = $1.00). Premium turns are billed at the actual
|
||||||
|
# per-call provider cost reported by LiteLLM.
|
||||||
STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
||||||
STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
||||||
STRIPE_TOKENS_PER_UNIT=1000000
|
STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||||
|
# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping):
|
||||||
|
# STRIPE_TOKENS_PER_UNIT=1000000
|
||||||
|
|
||||||
# Periodic Stripe safety net for purchases left in PENDING (minutes old)
|
# Periodic Stripe safety net for purchases left in PENDING (minutes old)
|
||||||
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
||||||
|
|
@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
|
||||||
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
|
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
|
||||||
PAGES_LIMIT=500
|
PAGES_LIMIT=500
|
||||||
|
|
||||||
# Premium token quota per registered user (default: 3,000,000)
|
# Premium credit quota per registered user, in micro-USD
|
||||||
# Applies only to models with billing_tier=premium in global_llm_config.yaml
|
# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the
|
||||||
PREMIUM_TOKEN_LIMIT=3000000
|
# actual per-call provider cost reported by LiteLLM, so cheap and expensive
|
||||||
|
# models bill proportionally. Applies only to models with
|
||||||
|
# billing_tier=premium in global_llm_config.yaml.
|
||||||
|
PREMIUM_CREDIT_MICROS_LIMIT=5000000
|
||||||
|
# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping):
|
||||||
|
# PREMIUM_TOKEN_LIMIT=5000000
|
||||||
|
|
||||||
|
# Safety ceiling on per-call premium reservation, in micro-USD.
|
||||||
|
# stream_new_chat estimates an upper-bound cost from the model's
|
||||||
|
# litellm-published per-token rates × the config's quota_reserve_tokens
|
||||||
|
# and clamps to this value so a misconfigured model can't lock the
|
||||||
|
# user's whole balance on one call. Default $1.00.
|
||||||
|
QUOTA_MAX_RESERVE_MICROS=1000000
|
||||||
|
|
||||||
|
# Per-image reservation (in micro-USD) for the POST /image-generations
|
||||||
|
# endpoint. Bypassed for free configs. Default $0.05.
|
||||||
|
QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
|
||||||
|
|
||||||
|
# Per-podcast reservation (in micro-USD) used by the podcast Celery task.
|
||||||
|
# Single envelope covers one transcript-generation LLM call. Default $0.20.
|
||||||
|
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
|
||||||
|
|
||||||
|
# Per-video-presentation reservation (in micro-USD) used by the video
|
||||||
|
# presentation Celery task. Covers worst-case fan-out of N slide-scene
|
||||||
|
# generations + refines. Default $1.00. NOTE: tasks using the override
|
||||||
|
# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
|
||||||
|
QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
|
||||||
|
|
||||||
# No-login (anonymous) mode — allows public users to chat without an account
|
# No-login (anonymous) mode — allows public users to chat without an account
|
||||||
# Set TRUE to enable /free pages and anonymous chat API
|
# Set TRUE to enable /free pages and anonymous chat API
|
||||||
|
|
@ -239,8 +269,58 @@ LLAMA_CLOUD_API_KEY=llx-nnn
|
||||||
# DAYTONA_TARGET=us
|
# DAYTONA_TARGET=us
|
||||||
# DAYTONA_SNAPSHOT_ID=
|
# DAYTONA_SNAPSHOT_ID=
|
||||||
|
|
||||||
|
# Desktop local filesystem mode (chat file tools run against a local folder root)
|
||||||
|
# ENABLE_DESKTOP_LOCAL_FILESYSTEM=FALSE
|
||||||
|
|
||||||
# OPTIONAL: Add these for LangSmith Observability
|
# OPTIONAL: Add these for LangSmith Observability
|
||||||
LANGSMITH_TRACING=true
|
LANGSMITH_TRACING=true
|
||||||
LANGSMITH_ENDPOINT=https://api.smith.langchain.com
|
LANGSMITH_ENDPOINT=https://api.smith.langchain.com
|
||||||
LANGSMITH_API_KEY=lsv2_pt_.....
|
LANGSMITH_API_KEY=lsv2_pt_.....
|
||||||
LANGSMITH_PROJECT=surfsense
|
LANGSMITH_PROJECT=surfsense
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# OPTIONAL: New-chat agent feature flags
|
||||||
|
# =============================================================================
|
||||||
|
# Master kill-switch — when true, every flag below is forced OFF.
|
||||||
|
# SURFSENSE_DISABLE_NEW_AGENT_STACK=false
|
||||||
|
|
||||||
|
# Agent quality
|
||||||
|
# SURFSENSE_ENABLE_CONTEXT_EDITING=false
|
||||||
|
# SURFSENSE_ENABLE_COMPACTION_V2=false
|
||||||
|
# SURFSENSE_ENABLE_RETRY_AFTER=false
|
||||||
|
# SURFSENSE_ENABLE_MODEL_FALLBACK=false
|
||||||
|
# SURFSENSE_ENABLE_MODEL_CALL_LIMIT=false
|
||||||
|
# SURFSENSE_ENABLE_TOOL_CALL_LIMIT=false
|
||||||
|
# SURFSENSE_ENABLE_TOOL_CALL_REPAIR=false
|
||||||
|
# SURFSENSE_ENABLE_DOOM_LOOP=false # leave OFF until UI handles permission='doom_loop'
|
||||||
|
|
||||||
|
# Safety
|
||||||
|
# SURFSENSE_ENABLE_PERMISSION=false
|
||||||
|
# SURFSENSE_ENABLE_BUSY_MUTEX=false
|
||||||
|
# SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
|
||||||
|
|
||||||
|
# Observability — OTel (also requires OTEL_EXPORTER_OTLP_ENDPOINT)
|
||||||
|
# SURFSENSE_ENABLE_OTEL=false
|
||||||
|
|
||||||
|
# Skills + subagents
|
||||||
|
# SURFSENSE_ENABLE_SKILLS=false
|
||||||
|
# SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=false
|
||||||
|
# SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=false
|
||||||
|
|
||||||
|
# Snapshot / revert
|
||||||
|
# SURFSENSE_ENABLE_ACTION_LOG=false
|
||||||
|
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
|
||||||
|
|
||||||
|
# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk
|
||||||
|
# content (typed reasoning blocks, tool-input deltas) and propagate the
|
||||||
|
# real tool_call_id to the SSE layer. When OFF, the stream falls back to
|
||||||
|
# the str-only text path and synthetic "call_<run_id>" tool-call ids.
|
||||||
|
# Schema migrations 135/136 ship unconditionally because they are
|
||||||
|
# forward-compatible.
|
||||||
|
# SURFSENSE_ENABLE_STREAM_PARITY_V2=false
|
||||||
|
|
||||||
|
# Plugins
|
||||||
|
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
|
||||||
|
# Comma-separated allowlist of plugin entry-point names
|
||||||
|
# SURFSENSE_ALLOWED_PLUGINS=year_substituter
|
||||||
|
|
|
||||||
|
|
@ -79,40 +79,44 @@ def _terminate_blocked_pids(conn, table: str) -> None:
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
conn = op.get_bind()
|
conn = op.get_bind()
|
||||||
|
# asyncpg requires LOCK TABLE inside a transaction block. Alembic already
|
||||||
|
# opened one via context.begin_transaction(), but the driver still errors
|
||||||
|
# unless we use an explicit SAVEPOINT (nested transaction) for this block.
|
||||||
|
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
|
||||||
|
with tx:
|
||||||
|
conn.execute(sa.text("SET lock_timeout = '10s'"))
|
||||||
|
|
||||||
conn.execute(sa.text("SET lock_timeout = '10s'"))
|
for tbl in sorted(TABLES_WITH_FULL_IDENTITY):
|
||||||
|
_terminate_blocked_pids(conn, tbl)
|
||||||
|
conn.execute(sa.text(f'LOCK TABLE "{tbl}" IN ACCESS EXCLUSIVE MODE'))
|
||||||
|
|
||||||
for tbl in sorted(TABLES_WITH_FULL_IDENTITY):
|
for tbl in TABLES_WITH_FULL_IDENTITY:
|
||||||
_terminate_blocked_pids(conn, tbl)
|
conn.execute(sa.text(f'ALTER TABLE "{tbl}" REPLICA IDENTITY DEFAULT'))
|
||||||
conn.execute(sa.text(f'LOCK TABLE "{tbl}" IN ACCESS EXCLUSIVE MODE'))
|
|
||||||
|
|
||||||
for tbl in TABLES_WITH_FULL_IDENTITY:
|
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
|
||||||
conn.execute(sa.text(f'ALTER TABLE "{tbl}" REPLICA IDENTITY DEFAULT'))
|
|
||||||
|
|
||||||
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
|
has_zero_ver = conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"SELECT 1 FROM information_schema.columns "
|
||||||
|
"WHERE table_name = 'documents' AND column_name = '_0_version'"
|
||||||
|
)
|
||||||
|
).fetchone()
|
||||||
|
|
||||||
has_zero_ver = conn.execute(
|
cols = DOCUMENT_COLS + (['"_0_version"'] if has_zero_ver else [])
|
||||||
sa.text(
|
col_list = ", ".join(cols)
|
||||||
"SELECT 1 FROM information_schema.columns "
|
|
||||||
"WHERE table_name = 'documents' AND column_name = '_0_version'"
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
|
||||||
|
f"notifications, "
|
||||||
|
f"documents ({col_list}), "
|
||||||
|
f"folders, "
|
||||||
|
f"search_source_connectors, "
|
||||||
|
f"new_chat_messages, "
|
||||||
|
f"chat_comments, "
|
||||||
|
f"chat_session_state"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
).fetchone()
|
|
||||||
|
|
||||||
cols = DOCUMENT_COLS + (['"_0_version"'] if has_zero_ver else [])
|
|
||||||
col_list = ", ".join(cols)
|
|
||||||
|
|
||||||
conn.execute(
|
|
||||||
sa.text(
|
|
||||||
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
|
|
||||||
f"notifications, "
|
|
||||||
f"documents ({col_list}), "
|
|
||||||
f"folders, "
|
|
||||||
f"search_source_connectors, "
|
|
||||||
f"new_chat_messages, "
|
|
||||||
f"chat_comments, "
|
|
||||||
f"chat_session_state"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
revision: str = "121"
|
revision: str = "121"
|
||||||
|
|
@ -23,16 +21,30 @@ depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
op.add_column(
|
# Idempotent: column(s) may already exist after a failed run or manual DDL.
|
||||||
"user",
|
op.execute(
|
||||||
sa.Column("memory_md", sa.Text(), nullable=True, server_default=""),
|
"""
|
||||||
)
|
DO $$
|
||||||
op.add_column(
|
BEGIN
|
||||||
"searchspaces",
|
IF NOT EXISTS (
|
||||||
sa.Column("shared_memory_md", sa.Text(), nullable=True, server_default=""),
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public' AND table_name = 'user'
|
||||||
|
AND column_name = 'memory_md'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE "user" ADD COLUMN memory_md TEXT DEFAULT '';
|
||||||
|
END IF;
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public' AND table_name = 'searchspaces'
|
||||||
|
AND column_name = 'shared_memory_md'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE searchspaces ADD COLUMN shared_memory_md TEXT DEFAULT '';
|
||||||
|
END IF;
|
||||||
|
END$$;
|
||||||
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
op.drop_column("searchspaces", "shared_memory_md")
|
op.execute("ALTER TABLE searchspaces DROP COLUMN IF EXISTS shared_memory_md")
|
||||||
op.drop_column("user", "memory_md")
|
op.execute('ALTER TABLE "user" DROP COLUMN IF EXISTS memory_md')
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,106 @@
|
||||||
|
"""129_obsidian_plugin_vault_identity
|
||||||
|
|
||||||
|
Revision ID: 129
|
||||||
|
Revises: 128
|
||||||
|
Create Date: 2026-04-21
|
||||||
|
|
||||||
|
Locks down vault identity for the Obsidian plugin connector:
|
||||||
|
|
||||||
|
- Deactivates pre-plugin OBSIDIAN_CONNECTOR rows.
|
||||||
|
- Partial unique index on ``(user_id, (config->>'vault_id'))`` for the
|
||||||
|
``/obsidian/connect`` upsert fast path.
|
||||||
|
- Partial unique index on ``(user_id, (config->>'vault_fingerprint'))``
|
||||||
|
so two devices observing the same vault content can never produce
|
||||||
|
two connector rows. Collisions are caught by the route handler and
|
||||||
|
routed through the merge path.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "129"
|
||||||
|
down_revision: str | None = "128"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
conn = op.get_bind()
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"""
|
||||||
|
UPDATE search_source_connectors
|
||||||
|
SET
|
||||||
|
is_indexable = false,
|
||||||
|
periodic_indexing_enabled = false,
|
||||||
|
next_scheduled_at = NULL,
|
||||||
|
config = COALESCE(config, '{}'::json)::jsonb
|
||||||
|
|| jsonb_build_object(
|
||||||
|
'legacy', true,
|
||||||
|
'deactivated_at', to_char(
|
||||||
|
now() AT TIME ZONE 'UTC',
|
||||||
|
'YYYY-MM-DD"T"HH24:MI:SS"Z"'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
WHERE connector_type = 'OBSIDIAN_CONNECTOR'
|
||||||
|
AND COALESCE((config::jsonb)->>'source', '') <> 'plugin'
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"""
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS
|
||||||
|
search_source_connectors_obsidian_plugin_vault_uniq
|
||||||
|
ON search_source_connectors (user_id, ((config->>'vault_id')))
|
||||||
|
WHERE connector_type = 'OBSIDIAN_CONNECTOR'
|
||||||
|
AND config->>'source' = 'plugin'
|
||||||
|
AND config->>'vault_id' IS NOT NULL
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"""
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS
|
||||||
|
search_source_connectors_obsidian_plugin_fingerprint_uniq
|
||||||
|
ON search_source_connectors (user_id, ((config->>'vault_fingerprint')))
|
||||||
|
WHERE connector_type = 'OBSIDIAN_CONNECTOR'
|
||||||
|
AND config->>'source' = 'plugin'
|
||||||
|
AND config->>'vault_fingerprint' IS NOT NULL
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
conn = op.get_bind()
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"DROP INDEX IF EXISTS "
|
||||||
|
"search_source_connectors_obsidian_plugin_fingerprint_uniq"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"DROP INDEX IF EXISTS search_source_connectors_obsidian_plugin_vault_uniq"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"""
|
||||||
|
UPDATE search_source_connectors
|
||||||
|
SET config = (config::jsonb - 'legacy' - 'deactivated_at')::json
|
||||||
|
WHERE connector_type = 'OBSIDIAN_CONNECTOR'
|
||||||
|
AND (config::jsonb) ? 'legacy'
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,94 @@
|
||||||
|
"""130_add_agent_action_log
|
||||||
|
|
||||||
|
Revision ID: 130
|
||||||
|
Revises: 129
|
||||||
|
Create Date: 2026-04-28
|
||||||
|
|
||||||
|
Adds the append-only ``agent_action_log`` table that
|
||||||
|
:class:`ActionLogMiddleware` writes to after every tool call. Each row
|
||||||
|
optionally carries a ``reverse_descriptor`` payload used by
|
||||||
|
``POST /api/threads/{thread_id}/revert/{action_id}`` to undo the action.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "130"
|
||||||
|
down_revision: str | None = "129"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Create the ``agent_action_log`` table plus its (thread, time) index.

    Column order, types, nullability, foreign keys and per-column index
    flags are declared through small local factories so the table layout
    reads as a flat list.
    """

    def fk_column(name, type_, target, *, ondelete, nullable):
        # Every foreign-key column of this table is individually indexed.
        return sa.Column(
            name,
            type_,
            sa.ForeignKey(target, ondelete=ondelete),
            nullable=nullable,
            index=True,
        )

    def jsonb_column(name):
        # Optional JSONB payload column.
        return sa.Column(
            name, postgresql.JSONB(astext_type=sa.Text()), nullable=True
        )

    columns = [
        sa.Column("id", sa.Integer(), primary_key=True, index=True),
        fk_column(
            "thread_id",
            sa.Integer(),
            "new_chat_threads.id",
            ondelete="CASCADE",
            nullable=False,
        ),
        fk_column(
            "user_id",
            postgresql.UUID(as_uuid=True),
            "user.id",
            ondelete="SET NULL",
            nullable=True,
        ),
        fk_column(
            "search_space_id",
            sa.Integer(),
            "searchspaces.id",
            ondelete="CASCADE",
            nullable=False,
        ),
        sa.Column("turn_id", sa.String(length=64), nullable=True, index=True),
        sa.Column("message_id", sa.String(length=128), nullable=True, index=True),
        sa.Column("tool_name", sa.String(length=255), nullable=False, index=True),
        jsonb_column("args"),
        sa.Column("result_id", sa.String(length=255), nullable=True),
        sa.Column(
            "reversible",
            sa.Boolean(),
            nullable=False,
            server_default=sa.text("false"),
        ),
        jsonb_column("reverse_descriptor"),
        jsonb_column("error"),
        # Self-reference: a revert row points at the action it reverses.
        fk_column(
            "reverse_of",
            sa.Integer(),
            "agent_action_log.id",
            ondelete="SET NULL",
            nullable=True,
        ),
        # NOTE(review): ``now() AT TIME ZONE 'utc'`` yields a timestamp
        # WITHOUT time zone; storing it in a timestamptz column re-interprets
        # it in the session TimeZone. Confirm this matches the ORM model's
        # default before changing it.
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            nullable=False,
            server_default=sa.text("(now() AT TIME ZONE 'utc')"),
            index=True,
        ),
    ]
    op.create_table("agent_action_log", *columns)

    # Composite index over (thread_id, created_at).
    op.create_index(
        "ix_agent_action_log_thread_created",
        "agent_action_log",
        ["thread_id", "created_at"],
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Remove the composite index, then the ``agent_action_log`` table itself."""
    table = "agent_action_log"
    op.drop_index("ix_agent_action_log_thread_created", table_name=table)
    op.drop_table(table)
|
||||||
119
surfsense_backend/alembic/versions/131_add_document_revisions.py
Normal file
119
surfsense_backend/alembic/versions/131_add_document_revisions.py
Normal file
|
|
@ -0,0 +1,119 @@
|
||||||
|
"""131_add_document_revisions
|
||||||
|
|
||||||
|
Revision ID: 131
|
||||||
|
Revises: 130
|
||||||
|
Create Date: 2026-04-28
|
||||||
|
|
||||||
|
Adds two snapshot tables that back the per-action revert flow:
|
||||||
|
|
||||||
|
* ``document_revisions``: pre-mutation snapshot of NOTE/FILE/EXTENSION docs.
|
||||||
|
* ``folder_revisions``: pre-mutation snapshot of folder mkdir/move/delete.
|
||||||
|
|
||||||
|
Both are written by :class:`KnowledgeBasePersistenceMiddleware` ahead of
|
||||||
|
state-changing tool calls and consumed by ``revert_service.revert_action``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "131"
|
||||||
|
down_revision: str | None = "130"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Create the ``document_revisions`` and ``folder_revisions`` tables.

    Both tables share the same shape: an integer primary key, a cascading
    FK to the snapshotted row, a cascading FK to the search space, the
    ``*_before`` snapshot fields, and three trailing provenance columns
    (turn id, originating agent action, creation timestamp).
    """

    def id_column():
        return sa.Column("id", sa.Integer(), primary_key=True, index=True)

    def int_fk(name, target, *, ondelete, nullable):
        # Indexed integer foreign-key column.
        return sa.Column(
            name,
            sa.Integer(),
            sa.ForeignKey(target, ondelete=ondelete),
            nullable=nullable,
            index=True,
        )

    def jsonb_column(name):
        return sa.Column(
            name, postgresql.JSONB(astext_type=sa.Text()), nullable=True
        )

    def provenance_columns():
        # Fresh Column objects per table; a Column cannot be attached twice.
        return [
            sa.Column(
                "created_by_turn_id",
                sa.String(length=64),
                nullable=True,
                index=True,
            ),
            int_fk(
                "agent_action_id",
                "agent_action_log.id",
                ondelete="SET NULL",
                nullable=True,
            ),
            sa.Column(
                "created_at",
                sa.TIMESTAMP(timezone=True),
                nullable=False,
                server_default=sa.text("(now() AT TIME ZONE 'utc')"),
                index=True,
            ),
        ]

    op.create_table(
        "document_revisions",
        id_column(),
        int_fk("document_id", "documents.id", ondelete="CASCADE", nullable=False),
        int_fk(
            "search_space_id",
            "searchspaces.id",
            ondelete="CASCADE",
            nullable=False,
        ),
        sa.Column("content_before", sa.Text(), nullable=True),
        sa.Column("title_before", sa.String(), nullable=True),
        sa.Column("folder_id_before", sa.Integer(), nullable=True),
        jsonb_column("chunks_before"),
        jsonb_column("metadata_before"),
        *provenance_columns(),
    )

    op.create_table(
        "folder_revisions",
        id_column(),
        int_fk("folder_id", "folders.id", ondelete="CASCADE", nullable=False),
        int_fk(
            "search_space_id",
            "searchspaces.id",
            ondelete="CASCADE",
            nullable=False,
        ),
        sa.Column("name_before", sa.String(length=255), nullable=True),
        sa.Column("parent_id_before", sa.Integer(), nullable=True),
        sa.Column("position_before", sa.String(length=50), nullable=True),
        *provenance_columns(),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop both snapshot tables, ``folder_revisions`` first."""
    for table in ("folder_revisions", "document_revisions"):
        op.drop_table(table)
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
"""132_add_agent_permission_rules
|
||||||
|
|
||||||
|
Revision ID: 132
|
||||||
|
Revises: 131
|
||||||
|
Create Date: 2026-04-28
|
||||||
|
|
||||||
|
Adds the persistent ``agent_permission_rules`` table consumed by
|
||||||
|
:class:`PermissionMiddleware` at agent build time. Rules can be scoped
|
||||||
|
at search-space (``user_id`` / ``thread_id`` NULL), user-wide
|
||||||
|
(``user_id`` set, ``thread_id`` NULL), or per-thread (``thread_id`` set).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "132"
|
||||||
|
down_revision: str | None = "131"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Create the ``agent_permission_rules`` table.

    A rule always belongs to a search space; ``user_id`` and ``thread_id``
    are optional narrowing scopes. All three scope FKs cascade on delete.
    """

    def scope_fk(name, type_, target, *, nullable):
        # Indexed scope column whose rows vanish with the referenced parent.
        return sa.Column(
            name,
            type_,
            sa.ForeignKey(target, ondelete="CASCADE"),
            nullable=nullable,
            index=True,
        )

    # NOTE(review): ``user_id`` / ``thread_id`` are nullable and PostgreSQL
    # treats NULLs as distinct in unique constraints by default, so this
    # constraint presumably does not stop duplicate rules whose scope
    # columns are NULL -- confirm whether that is intended.
    scope_uniqueness = sa.UniqueConstraint(
        "search_space_id",
        "user_id",
        "thread_id",
        "permission",
        "pattern",
        "action",
        name="uq_agent_permission_rules_scope",
    )

    op.create_table(
        "agent_permission_rules",
        sa.Column("id", sa.Integer(), primary_key=True, index=True),
        scope_fk("search_space_id", sa.Integer(), "searchspaces.id", nullable=False),
        scope_fk("user_id", postgresql.UUID(as_uuid=True), "user.id", nullable=True),
        scope_fk("thread_id", sa.Integer(), "new_chat_threads.id", nullable=True),
        sa.Column("permission", sa.String(length=255), nullable=False),
        # A rule without an explicit pattern defaults to "*".
        sa.Column(
            "pattern",
            sa.String(length=255),
            nullable=False,
            server_default="*",
        ),
        sa.Column("action", sa.String(length=16), nullable=False),
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            nullable=False,
            server_default=sa.text("(now() AT TIME ZONE 'utc')"),
            index=True,
        ),
        scope_uniqueness,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop the ``agent_permission_rules`` table (and, with it, every stored rule)."""
    rules_table = "agent_permission_rules"
    op.drop_table(rules_table)
|
||||||
|
|
@ -0,0 +1,105 @@
|
||||||
|
"""133_drop_documents_content_hash_unique
|
||||||
|
|
||||||
|
Revision ID: 133
|
||||||
|
Revises: 132
|
||||||
|
Create Date: 2026-04-29
|
||||||
|
|
||||||
|
Drop the global UNIQUE constraint on ``documents.content_hash`` so the
|
||||||
|
new-chat agent's ``write_file`` flow can persist legitimate file copies
|
||||||
|
(two paths, identical content) without hitting a constraint that mirrors
|
||||||
|
no real filesystem semantic.
|
||||||
|
|
||||||
|
Path uniqueness still lives on ``documents.unique_identifier_hash`` (per
|
||||||
|
search space), which is the right invariant — exactly like an inode at a
|
||||||
|
given path on a POSIX filesystem.
|
||||||
|
|
||||||
|
The non-unique INDEX on ``content_hash`` is preserved so connector
|
||||||
|
indexers' "have we seen this content before?" lookup
|
||||||
|
(:func:`app.tasks.document_processors.base.check_duplicate_document`,
|
||||||
|
which already uses ``.scalars().first()`` and is therefore tolerant of
|
||||||
|
duplicates) stays cheap.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from sqlalchemy import inspect
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "133"
|
||||||
|
down_revision: str | None = "132"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _existing_constraint_names(bind, table: str) -> set[str]:
    """Return the names of the unique constraints reflected on ``table``."""
    names: set[str] = set()
    for constraint in inspect(bind).get_unique_constraints(table):
        names.add(constraint["name"])
    return names
|
||||||
|
|
||||||
|
|
||||||
|
def _existing_index_names(bind, table: str) -> set[str]:
    """Return the names of the indexes reflected on ``table``."""
    names: set[str] = set()
    for index in inspect(bind).get_indexes(table):
        names.add(index["name"])
    return names
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Drop uniqueness on ``documents.content_hash`` but keep a plain index.

    Each step is guarded by a reflection lookup, so it is skipped when the
    object it targets is already in the desired state.
    """
    bind = op.get_bind()
    unique_name = "uq_documents_content_hash"
    lookup_index = "ix_documents_content_hash"

    # The named UniqueConstraint (per the original note, added in revision 8).
    if unique_name in _existing_constraint_names(bind, "documents"):
        op.drop_constraint(unique_name, "documents", type_="unique")

    # Reflected after the constraint drop. The uniqueness may also surface
    # as a unique index carrying the same name, so drain that variant too.
    index_names = _existing_index_names(bind, "documents")
    if unique_name in index_names:
        op.drop_index(unique_name, table_name="documents")

    # Make sure the non-unique lookup index exists for fast hash lookups.
    if lookup_index in index_names:
        return
    op.create_index(
        lookup_index,
        "documents",
        ["content_hash"],
        unique=False,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Reinstate the global UNIQUE constraint on ``documents.content_hash``.

    WARNING: this downgrade is destructive. Once the constraint is gone
    there may be legitimate duplicates (e.g. two NOTE documents that share
    content because the user copied one to a new path). The constraint
    cannot be re-created while they exist, so for every non-NULL
    ``content_hash`` only the lowest-id row is kept and the other rows are
    DELETED. Per the original author's note this is the same strategy
    revision 8 used when it first introduced the constraint.

    Rows whose ``content_hash`` is NULL are left alone: a UNIQUE constraint
    allows any number of NULLs, so deleting them would lose data for no
    reason (``GROUP BY`` would otherwise fold all NULLs into one group and
    keep a single row).
    """
    bind = op.get_bind()

    op.execute(
        """
        DELETE FROM documents
        WHERE content_hash IS NOT NULL
          AND id NOT IN (
              SELECT MIN(id)
              FROM documents
              WHERE content_hash IS NOT NULL
              GROUP BY content_hash
          )
        """
    )

    # Rebuild the lookup index from scratch so it is non-unique whatever
    # shape it had before.
    indexes = _existing_index_names(bind, "documents")
    if "ix_documents_content_hash" in indexes:
        op.drop_index("ix_documents_content_hash", table_name="documents")

    op.create_index(
        "ix_documents_content_hash",
        "documents",
        ["content_hash"],
        unique=False,
    )
    op.create_unique_constraint(
        "uq_documents_content_hash", "documents", ["content_hash"]
    )
|
||||||
139
surfsense_backend/alembic/versions/134_relax_revision_fks.py
Normal file
139
surfsense_backend/alembic/versions/134_relax_revision_fks.py
Normal file
|
|
@ -0,0 +1,139 @@
|
||||||
|
"""134_relax_revision_fks
|
||||||
|
|
||||||
|
Revision ID: 134
|
||||||
|
Revises: 133
|
||||||
|
Create Date: 2026-04-29
|
||||||
|
|
||||||
|
Relax the parent FKs on ``document_revisions`` and ``folder_revisions`` so
|
||||||
|
revisions survive the deletes they describe.
|
||||||
|
|
||||||
|
Why: the snapshot/revert pipeline writes a ``DocumentRevision`` BEFORE
|
||||||
|
hard-deleting a document via the ``rm`` tool (and likewise a
|
||||||
|
``FolderRevision`` before ``rmdir``). If the FK is ``ON DELETE CASCADE``
|
||||||
|
the snapshot row is wiped at the exact moment we need it most — revert
|
||||||
|
then has nothing to read and the operation becomes irreversible.
|
||||||
|
|
||||||
|
Migration:
|
||||||
|
|
||||||
|
* ``document_revisions.document_id``: ``NOT NULL`` -> nullable; FK
|
||||||
|
``ON DELETE CASCADE`` -> ``ON DELETE SET NULL``.
|
||||||
|
* ``folder_revisions.folder_id``: same treatment.
|
||||||
|
|
||||||
|
The ``search_space_id`` FK on both tables is left unchanged (still
|
||||||
|
``ON DELETE CASCADE``). When a search space is deleted, all documents,
|
||||||
|
folders, AND their revisions go together — that's the correct teardown
|
||||||
|
story.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import inspect
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "134"
|
||||||
|
down_revision: str | None = "133"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _fk_name(bind, table: str, column: str) -> str | None:
    """Look up the FK constraint that constrains exactly ``table.column``.

    Returns the reflected name of the first foreign key whose constrained
    columns are ``[column]`` and nothing else, or ``None`` when no such
    foreign key is reflected.
    """
    matching_names = (
        fk.get("name")
        for fk in inspect(bind).get_foreign_keys(table)
        if (fk.get("constrained_columns") or []) == [column]
    )
    return next(matching_names, None)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Make the parent FKs of both revision tables nullable with SET NULL.

    For ``document_revisions.document_id`` and then
    ``folder_revisions.folder_id``: drop the existing FK (if one is
    reflected), allow NULL in the column, and re-create the FK with
    ``ON DELETE SET NULL``.
    """
    bind = op.get_bind()

    # (revision table, FK column, parent table, name of the re-created FK)
    targets = (
        (
            "document_revisions",
            "document_id",
            "documents",
            "document_revisions_document_id_fkey",
        ),
        (
            "folder_revisions",
            "folder_id",
            "folders",
            "folder_revisions_folder_id_fkey",
        ),
    )

    for table, column, parent, new_fk_name in targets:
        current_fk = _fk_name(bind, table, column)
        if current_fk:
            op.drop_constraint(current_fk, table, type_="foreignkey")
        op.alter_column(
            table,
            column,
            existing_type=sa.Integer(),
            nullable=True,
        )
        op.create_foreign_key(
            new_fk_name,
            table,
            parent,
            [column],
            ["id"],
            ondelete="SET NULL",
        )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Restore NOT NULL + ``ON DELETE CASCADE`` on both revision parent FKs.

    Destructive: revisions whose parent document/folder is already gone
    (FK column is NULL) are deleted first, because NOT NULL cannot be
    reinstated while they exist.
    """
    bind = op.get_bind()

    # Drain orphan rows from both tables before touching any constraint.
    op.execute("DELETE FROM document_revisions WHERE document_id IS NULL")
    op.execute("DELETE FROM folder_revisions WHERE folder_id IS NULL")

    # (revision table, FK column, parent table, name of the re-created FK)
    targets = (
        (
            "document_revisions",
            "document_id",
            "documents",
            "document_revisions_document_id_fkey",
        ),
        (
            "folder_revisions",
            "folder_id",
            "folders",
            "folder_revisions_folder_id_fkey",
        ),
    )

    for table, column, parent, new_fk_name in targets:
        current_fk = _fk_name(bind, table, column)
        if current_fk:
            op.drop_constraint(current_fk, table, type_="foreignkey")
        op.alter_column(
            table,
            column,
            existing_type=sa.Integer(),
            nullable=False,
        )
        op.create_foreign_key(
            new_fk_name,
            table,
            parent,
            [column],
            ["id"],
            ondelete="CASCADE",
        )
|
||||||
|
|
@ -0,0 +1,82 @@
|
||||||
|
"""135_action_log_correlation_ids
|
||||||
|
|
||||||
|
Revision ID: 135
|
||||||
|
Revises: 134
|
||||||
|
Create Date: 2026-04-29
|
||||||
|
|
||||||
|
Action-log correlation-id cleanup.
|
||||||
|
|
||||||
|
Background
|
||||||
|
----------
|
||||||
|
``agent_action_log.turn_id`` is misnamed. ``ActionLogMiddleware`` writes
|
||||||
|
the LangChain ``tool_call.id`` into that column today (see
|
||||||
|
``action_log.py:_resolve_turn_id``), and ``kb_persistence._find_action_ids_batch``
|
||||||
|
joins on it as such. The real chat-turn id (``f"{chat_id}:{ms}"`` from
|
||||||
|
``stream_new_chat.py``) lives in ``config.configurable.turn_id`` and was
|
||||||
|
never persisted.
|
||||||
|
|
||||||
|
This migration introduces two new, correctly-named columns:
|
||||||
|
|
||||||
|
* ``tool_call_id`` (LangChain tool-call id, what ``turn_id`` actually held)
|
||||||
|
* ``chat_turn_id`` (the per-turn correlation id from
|
||||||
|
``configurable.turn_id`` — used by the per-turn ``revert-turn`` route).
|
||||||
|
|
||||||
|
Backfill copies the current ``turn_id`` values into ``tool_call_id`` so
|
||||||
|
existing joins keep working. The old ``turn_id`` column is left in place
|
||||||
|
for one release as a deprecated alias to give safe rollback. ``ActionLogMiddleware``
|
||||||
|
keeps writing it (= ``tool_call_id``) for the same reason.
|
||||||
|
|
||||||
|
Indexes
|
||||||
|
-------
|
||||||
|
|
||||||
|
* ``ix_agent_action_log_tool_call_id`` — required by
|
||||||
|
``_find_action_ids_batch`` (was on ``turn_id``).
|
||||||
|
* ``ix_agent_action_log_chat_turn_id`` — required by the
|
||||||
|
``revert-turn/{chat_turn_id}`` query.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "135"
|
||||||
|
down_revision: str | None = "134"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Add ``tool_call_id`` / ``chat_turn_id`` to ``agent_action_log``.

    Both columns are nullable 64-character strings, each gets its own
    index, and ``tool_call_id`` is backfilled from the existing ``turn_id``
    column for rows where it is still NULL.
    """
    table = "agent_action_log"
    new_columns = ("tool_call_id", "chat_turn_id")

    # Add both columns first, then both indexes (same DDL order as before).
    for column in new_columns:
        op.add_column(table, sa.Column(column, sa.String(length=64), nullable=True))
    for column in new_columns:
        op.create_index(f"ix_agent_action_log_{column}", table, [column])

    backfill_sql = (
        "UPDATE agent_action_log SET tool_call_id = turn_id WHERE tool_call_id IS NULL"
    )
    op.execute(backfill_sql)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop the two correlation-id indexes, then the two columns."""
    table = "agent_action_log"
    # ``chat_turn_id`` before ``tool_call_id`` in both passes.
    doomed_columns = ("chat_turn_id", "tool_call_id")

    for column in doomed_columns:
        op.drop_index(f"ix_agent_action_log_{column}", table_name=table)
    for column in doomed_columns:
        op.drop_column(table, column)
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
"""136_new_chat_message_turn_id
|
||||||
|
|
||||||
|
Revision ID: 136
|
||||||
|
Revises: 135
|
||||||
|
Create Date: 2026-04-29
|
||||||
|
|
||||||
|
Persist the per-turn correlation id on each chat message.
|
||||||
|
|
||||||
|
Background
|
||||||
|
----------
|
||||||
|
LangGraph's checkpointer stores user-provided ``configurable.turn_id``
|
||||||
|
in checkpoint metadata (see
|
||||||
|
``langgraph/checkpoint/base/__init__.py:get_checkpoint_metadata``). To
|
||||||
|
support edit-from-arbitrary-position, the regenerate route needs to map
|
||||||
|
a ``message_id`` -> ``turn_id`` -> checkpoint at request time. Without
|
||||||
|
this column the mapping doesn't exist anywhere, so regenerate would
|
||||||
|
have to hardcode the "last 2 messages" rewind heuristic.
|
||||||
|
|
||||||
|
This migration adds a nullable ``turn_id`` column to ``new_chat_messages``
|
||||||
|
plus an index. Legacy rows have NULL — the regenerate route degrades
|
||||||
|
gracefully to the reload-last-two heuristic for those.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "136"
|
||||||
|
down_revision: str | None = "135"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Add a nullable, indexed ``turn_id`` column to ``new_chat_messages``."""
    table, column = "new_chat_messages", "turn_id"
    op.add_column(table, sa.Column(column, sa.String(length=64), nullable=True))
    op.create_index("ix_new_chat_messages_turn_id", table, [column])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Remove the ``turn_id`` index and then the column itself."""
    table = "new_chat_messages"
    op.drop_index("ix_new_chat_messages_turn_id", table_name=table)
    op.drop_column(table, "turn_id")
|
||||||
|
|
@ -0,0 +1,74 @@
|
||||||
|
"""137_unique_reverse_of_in_action_log
|
||||||
|
|
||||||
|
Revision ID: 137
|
||||||
|
Revises: 136
|
||||||
|
Create Date: 2026-04-29
|
||||||
|
|
||||||
|
Protect ``agent_action_log.reverse_of`` against double inserts. Two
|
||||||
|
concurrent revert calls (single-action route + the per-turn batch
|
||||||
|
route, or two batch routes racing) both pass the
|
||||||
|
``_was_already_reverted`` SELECT and both insert their own
|
||||||
|
``_revert:*`` rows. The application-level idempotency check is racy
|
||||||
|
because there's no DB constraint backing it.
|
||||||
|
|
||||||
|
This migration adds a partial unique index on ``reverse_of`` (PostgreSQL
|
||||||
|
``WHERE reverse_of IS NOT NULL``) so the second concurrent insert raises
|
||||||
|
``IntegrityError`` and the route can translate it to ``"already_reverted"``
|
||||||
|
deterministically.
|
||||||
|
|
||||||
|
A plain ``UniqueConstraint`` would technically also work: most existing
rows have ``reverse_of = NULL`` (only revert rows fill it), and Postgres
treats NULLs as distinct in unique indexes by default. A partial index is
still preferred because it is the cleanest expression of intent and does
not depend on how a given Postgres release handles NULLs in unique indexes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "137"
|
||||||
|
down_revision: str | None = "136"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
_INDEX_NAME = "ux_agent_action_log_reverse_of"
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """De-duplicate ``reverse_of`` and add the partial unique index on it.

    Any pre-existing double-revert rows are neutralised first: for each
    ``reverse_of`` value the oldest row (smallest id) keeps its value and
    the later rows get ``reverse_of = NULL``. Nothing is deleted, so the
    duplicates stay in the table but stop claiming to be the canonical
    revert.
    """
    neutralise_duplicates = """
        WITH dups AS (
            SELECT id,
                   reverse_of,
                   ROW_NUMBER() OVER (
                       PARTITION BY reverse_of ORDER BY id ASC
                   ) AS rn
            FROM agent_action_log
            WHERE reverse_of IS NOT NULL
        )
        UPDATE agent_action_log
        SET reverse_of = NULL
        WHERE id IN (SELECT id FROM dups WHERE rn > 1)
        """
    op.execute(neutralise_duplicates)

    # Unique only across rows that actually reference another action.
    op.create_index(
        _INDEX_NAME,
        "agent_action_log",
        ["reverse_of"],
        unique=True,
        postgresql_where="reverse_of IS NOT NULL",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop the partial unique index on ``agent_action_log.reverse_of``."""
    table = "agent_action_log"
    op.drop_index(_INDEX_NAME, table_name=table)
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
"""138_add_thread_auto_model_pinning_fields
|
||||||
|
|
||||||
|
Revision ID: 138
|
||||||
|
Revises: 137
|
||||||
|
Create Date: 2026-04-30
|
||||||
|
|
||||||
|
Add a single thread-level column to persist the Auto (Fastest) model pin:
|
||||||
|
- pinned_llm_config_id: concrete resolved global LLM config id used for this
|
||||||
|
thread. NULL means "no pin; Auto will resolve on next turn".
|
||||||
|
|
||||||
|
The column is unindexed: all reads are by new_chat_threads.id (primary key),
|
||||||
|
so a secondary index would be dead write amplification.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "138"
|
||||||
|
down_revision: str | None = "137"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Add the ``pinned_llm_config_id`` integer column to ``new_chat_threads``.

    ``IF NOT EXISTS`` makes the statement a no-op when the column is
    already present.
    """
    add_pin_column = (
        "ALTER TABLE new_chat_threads "
        "ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER"
    )
    op.execute(add_pin_column)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Remove every pinning artefact from ``new_chat_threads``.

    Besides ``pinned_llm_config_id`` this also targets extra columns and
    indexes that, per the original note, only exist on dev databases that
    ran an earlier draft of revision 138. Every statement uses
    ``IF EXISTS`` and is therefore a safe no-op on the lean shape.
    """
    cleanup_statements = (
        "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode",
        "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_llm_config_id",
    )
    for statement in cleanup_statements:
        op.execute(statement)
|
||||||
|
|
@ -0,0 +1,160 @@
|
||||||
|
"""add user table to zero_publication with column list
|
||||||
|
|
||||||
|
Adds the "user" table to zero_publication with a column-list publication
|
||||||
|
so that only the 5 fields driving the live usage meters are replicated
|
||||||
|
through WAL -> zero-cache -> browser IndexedDB:
|
||||||
|
|
||||||
|
id, pages_limit, pages_used,
|
||||||
|
premium_tokens_limit, premium_tokens_used
|
||||||
|
|
||||||
|
Sensitive columns (hashed_password, email, oauth_account, display_name,
|
||||||
|
avatar_url, memory_md, refresh_tokens, last_login, etc.) are NOT
|
||||||
|
included in the publication, so they never enter WAL replication.
|
||||||
|
|
||||||
|
Also re-asserts REPLICA IDENTITY DEFAULT on "user" for idempotency
|
||||||
|
(it is already DEFAULT today since "user" was never in the
|
||||||
|
TABLES_WITH_FULL_IDENTITY list of migration 117).
|
||||||
|
|
||||||
|
IMPORTANT - before AND after running this migration:
|
||||||
|
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
|
||||||
|
2. Run: alembic upgrade head
|
||||||
|
3. Delete / reset the zero-cache data volume
|
||||||
|
4. Restart zero-cache (it will do a fresh initial sync)
|
||||||
|
|
||||||
|
Revision ID: 139
|
||||||
|
Revises: 138
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "139"
|
||||||
|
down_revision: str | None = "138"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
PUBLICATION_NAME = "zero_publication"
|
||||||
|
|
||||||
|
# Document column list as left by migration 117. Must match exactly.
|
||||||
|
DOCUMENT_COLS = [
|
||||||
|
"id",
|
||||||
|
"title",
|
||||||
|
"document_type",
|
||||||
|
"search_space_id",
|
||||||
|
"folder_id",
|
||||||
|
"created_by_id",
|
||||||
|
"status",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Five fields needed by the live usage meters (sidebar Tokens/Pages,
|
||||||
|
# Buy Tokens content). Keep this list narrow on purpose: anything added
|
||||||
|
# here flows into WAL and IndexedDB for every connected browser.
|
||||||
|
USER_COLS = [
|
||||||
|
"id",
|
||||||
|
"pages_limit",
|
||||||
|
"pages_used",
|
||||||
|
"premium_tokens_limit",
|
||||||
|
"premium_tokens_used",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _terminate_blocked_pids(conn, table: str) -> None:
    """Terminate other backends that hold locks on *table* in this database.

    Despite the name, the targets are the backends that would block our
    upcoming AccessExclusiveLock, not backends that are themselves blocked.

    Args:
        conn: Connection the termination query is executed on.
        table: Unquoted relation name, matched against ``pg_class.relname``.

    ``pg_locks`` lists locks for every database in the cluster, while
    ``pg_class`` only describes the current database. Relation OIDs are not
    unique across databases, so the lock rows are restricted to the current
    database; otherwise the OID join could match a lock taken in another
    database and terminate an unrelated backend.
    """
    conn.execute(
        sa.text(
            "SELECT pg_terminate_backend(l.pid) "
            "FROM pg_locks l "
            "JOIN pg_class c ON c.oid = l.relation "
            "WHERE c.relname = :tbl "
            "  AND l.database = ("
            "SELECT d.oid FROM pg_database d "
            "WHERE d.datname = current_database()"
            ") "
            "  AND l.pid != pg_backend_pid()"
        ),
        {"tbl": table},
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _has_zero_version(conn, table: str) -> bool:
    """Report whether *table* has a ``_0_version`` column.

    The lookup goes through ``information_schema.columns`` and matches on the
    table name only.
    """
    lookup = sa.text(
        "SELECT 1 FROM information_schema.columns "
        "WHERE table_name = :tbl AND column_name = '_0_version'"
    )
    first_row = conn.execute(lookup, {"tbl": table}).fetchone()
    return first_row is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _build_publication_ddl(
    documents_has_zero_ver: bool, user_has_zero_ver: bool
) -> str:
    """Return the CREATE PUBLICATION statement that includes ``"user"``.

    ``documents`` and ``"user"`` are published with explicit column lists
    (``DOCUMENT_COLS`` / ``USER_COLS``); ``"_0_version"`` is appended to a
    list only when the matching flag is true. All other tables are published
    without a column list.
    """
    document_columns = list(DOCUMENT_COLS)
    if documents_has_zero_ver:
        document_columns.append('"_0_version"')

    user_columns = list(USER_COLS)
    if user_has_zero_ver:
        user_columns.append('"_0_version"')

    document_spec = "documents (" + ", ".join(document_columns) + ")"
    user_spec = '"user" (' + ", ".join(user_columns) + ")"

    published_tables = [
        "notifications",
        document_spec,
        "folders",
        "search_source_connectors",
        "new_chat_messages",
        "chat_comments",
        "chat_session_state",
        user_spec,
    ]
    return f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + ", ".join(
        published_tables
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _build_publication_ddl_without_user(documents_has_zero_ver: bool) -> str:
    """Return the CREATE PUBLICATION statement that omits ``"user"``.

    Used by ``downgrade``. ``documents`` keeps its explicit column list, with
    ``"_0_version"`` appended only when *documents_has_zero_ver* is true.
    """
    document_columns = list(DOCUMENT_COLS)
    if documents_has_zero_ver:
        document_columns.append('"_0_version"')

    document_spec = "documents (" + ", ".join(document_columns) + ")"

    published_tables = [
        "notifications",
        document_spec,
        "folders",
        "search_source_connectors",
        "new_chat_messages",
        "chat_comments",
        "chat_session_state",
    ]
    return f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + ", ".join(
        published_tables
    )
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Recreate the publication so that it also publishes ``"user"``.

    Within one transaction scope this sets a 10s lock timeout, terminates
    other backends holding locks on ``"user"``, takes an ACCESS EXCLUSIVE lock
    on it, re-asserts REPLICA IDENTITY DEFAULT, then drops and recreates the
    publication with the ``"user"`` column list.
    """
    bind = op.get_bind()

    # asyncpg requires LOCK TABLE inside a transaction block. Alembic already
    # opened one via context.begin_transaction(), but the driver still errors
    # unless an explicit SAVEPOINT (nested transaction) wraps this block.
    if bind.in_transaction():
        scope = bind.begin_nested()
    else:
        scope = bind.begin()

    with scope:
        bind.execute(sa.text("SET lock_timeout = '10s'"))

        _terminate_blocked_pids(bind, "user")
        bind.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))

        # Idempotent: "user" was never in TABLES_WITH_FULL_IDENTITY of
        # migration 117, so this is already DEFAULT. Re-assert it anyway so
        # the column-list publication stays valid (DEFAULT identity only
        # requires the PK to be in the column list).
        bind.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))

        bind.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))

        # Arguments are evaluated left to right: documents first, then user.
        create_ddl = _build_publication_ddl(
            _has_zero_version(bind, "documents"),
            _has_zero_version(bind, "user"),
        )
        bind.execute(sa.text(create_ddl))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop the publication and recreate it without the ``"user"`` table."""
    bind = op.get_bind()
    bind.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))

    create_ddl = _build_publication_ddl_without_user(
        _has_zero_version(bind, "documents")
    )
    bind.execute(sa.text(create_ddl))
|
||||||
|
|
@ -0,0 +1,291 @@
|
||||||
|
"""rename premium token columns to credit micros and add cost_micros to token_usage
|
||||||
|
|
||||||
|
Migrates the premium quota system from a flat token counter to a USD-cost
|
||||||
|
based credit system, where 1 credit = 1 micro-USD ($0.000001).
|
||||||
|
|
||||||
|
Column renames (1:1 numerical mapping — the prior $1 per 1M tokens Stripe
|
||||||
|
price means every existing value is already correct in the new unit, no
|
||||||
|
data transformation needed):
|
||||||
|
|
||||||
|
user.premium_tokens_limit -> premium_credit_micros_limit
|
||||||
|
user.premium_tokens_used -> premium_credit_micros_used
|
||||||
|
user.premium_tokens_reserved -> premium_credit_micros_reserved
|
||||||
|
|
||||||
|
premium_token_purchases.tokens_granted -> credit_micros_granted
|
||||||
|
|
||||||
|
New column for cost auditing per turn:
|
||||||
|
|
||||||
|
token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0)
|
||||||
|
|
||||||
|
The "user" table is in zero_publication's column list (added in 139), so
|
||||||
|
this migration must drop and recreate the publication with the renamed
|
||||||
|
column names, otherwise zero-cache will replicate stale column names and
|
||||||
|
the FE Zero schema will fail to bind.
|
||||||
|
|
||||||
|
IMPORTANT - before AND after running this migration:
|
||||||
|
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
|
||||||
|
2. Run: alembic upgrade head
|
||||||
|
3. Delete / reset the zero-cache data volume
|
||||||
|
4. Restart zero-cache (it will do a fresh initial sync)
|
||||||
|
|
||||||
|
Skipping the zero-cache stop will deadlock at the ACCESS EXCLUSIVE LOCK on
|
||||||
|
"user". Skipping the data-volume reset will leave IndexedDB clients seeing
|
||||||
|
column-not-found errors from a stale catalog snapshot.
|
||||||
|
|
||||||
|
Revision ID: 140
|
||||||
|
Revises: 139
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "140"
|
||||||
|
down_revision: str | None = "139"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
PUBLICATION_NAME = "zero_publication"
|
||||||
|
|
||||||
|
# Replicates 139's document column list verbatim — must stay in sync.
|
||||||
|
DOCUMENT_COLS = [
|
||||||
|
"id",
|
||||||
|
"title",
|
||||||
|
"document_type",
|
||||||
|
"search_space_id",
|
||||||
|
"folder_id",
|
||||||
|
"created_by_id",
|
||||||
|
"status",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Same five live-meter fields as 139, with the renamed column names.
|
||||||
|
USER_COLS = [
|
||||||
|
"id",
|
||||||
|
"pages_limit",
|
||||||
|
"pages_used",
|
||||||
|
"premium_credit_micros_limit",
|
||||||
|
"premium_credit_micros_used",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _terminate_blocked_pids(conn, table: str) -> None:
    """Terminate other backends that hold locks on *table* in this database.

    Despite the name, the targets are the backends that would block our
    upcoming AccessExclusiveLock, not backends that are themselves blocked.

    Args:
        conn: Connection the termination query is executed on.
        table: Unquoted relation name, matched against ``pg_class.relname``.

    ``pg_locks`` lists locks for every database in the cluster, while
    ``pg_class`` only describes the current database. Relation OIDs are not
    unique across databases, so the lock rows are restricted to the current
    database; otherwise the OID join could match a lock taken in another
    database and terminate an unrelated backend.
    """
    conn.execute(
        sa.text(
            "SELECT pg_terminate_backend(l.pid) "
            "FROM pg_locks l "
            "JOIN pg_class c ON c.oid = l.relation "
            "WHERE c.relname = :tbl "
            "  AND l.database = ("
            "SELECT d.oid FROM pg_database d "
            "WHERE d.datname = current_database()"
            ") "
            "  AND l.pid != pg_backend_pid()"
        ),
        {"tbl": table},
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _has_zero_version(conn, table: str) -> bool:
    """Report whether *table* has a ``_0_version`` column.

    The lookup goes through ``information_schema.columns`` and matches on the
    table name only.
    """
    lookup = sa.text(
        "SELECT 1 FROM information_schema.columns "
        "WHERE table_name = :tbl AND column_name = '_0_version'"
    )
    first_row = conn.execute(lookup, {"tbl": table}).fetchone()
    return first_row is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _column_exists(conn, table: str, column: str) -> bool:
    """Report whether *table* has a column named *column*.

    The lookup goes through ``information_schema.columns`` and matches on the
    table and column names only.
    """
    lookup = sa.text(
        "SELECT 1 FROM information_schema.columns "
        "WHERE table_name = :tbl AND column_name = :col"
    )
    first_row = conn.execute(lookup, {"tbl": table, "col": column}).fetchone()
    return first_row is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _build_publication_ddl(
    user_cols: list[str],
    *,
    documents_has_zero_ver: bool,
    user_has_zero_ver: bool,
) -> str:
    """Return the CREATE PUBLICATION statement for the given ``"user"`` columns.

    Args:
        user_cols: Column names to publish for ``"user"``.
        documents_has_zero_ver: Append ``"_0_version"`` to the documents list.
        user_has_zero_ver: Append ``"_0_version"`` to the user list.

    ``documents`` is published with ``DOCUMENT_COLS``; all other tables are
    published without a column list.
    """
    zero_version = ['"_0_version"']

    document_columns = DOCUMENT_COLS + (zero_version if documents_has_zero_ver else [])
    user_columns = user_cols + (zero_version if user_has_zero_ver else [])

    document_spec = "documents (" + ", ".join(document_columns) + ")"
    user_spec = '"user" (' + ", ".join(user_columns) + ")"

    published_tables = [
        "notifications",
        document_spec,
        "folders",
        "search_source_connectors",
        "new_chat_messages",
        "chat_comments",
        "chat_session_state",
        user_spec,
    ]
    return f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + ", ".join(
        published_tables
    )
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Add ``token_usage.cost_micros`` and rename the premium token columns.

    Steps, in order:
      1. Add ``token_usage.cost_micros`` if it is missing.
      2. Rename ``premium_token_purchases.tokens_granted`` if not yet renamed.
      3. Under an ACCESS EXCLUSIVE lock on ``"user"``: drop the publication,
         rename the three ``premium_tokens_*`` columns, set the server default
         of the renamed limit column, and recreate the publication with
         ``USER_COLS``.
    """
    bind = op.get_bind()

    # ------------------------------------------------------------------
    # 1. token_usage.cost_micros. Guarded so dev re-runs are safe.
    # ------------------------------------------------------------------
    if not _column_exists(bind, "token_usage", "cost_micros"):
        cost_column = sa.Column(
            "cost_micros",
            sa.BigInteger(),
            nullable=False,
            server_default="0",
        )
        op.add_column("token_usage", cost_column)

    # ------------------------------------------------------------------
    # 2. premium_token_purchases.tokens_granted -> credit_micros_granted.
    # ------------------------------------------------------------------
    purchases = "premium_token_purchases"
    if _column_exists(bind, purchases, "tokens_granted") and not _column_exists(
        bind, purchases, "credit_micros_granted"
    ):
        op.alter_column(
            purchases,
            "tokens_granted",
            new_column_name="credit_micros_granted",
        )

    # ------------------------------------------------------------------
    # 3. user.premium_tokens_* -> premium_credit_micros_*.
    #
    #    The publication references the old column names, so it is dropped
    #    first and recreated afterwards. asyncpg requires LOCK TABLE in a
    #    transaction block; alembic's outer transaction already holds one,
    #    but a SAVEPOINT keeps the LOCK + DDL atomic.
    # ------------------------------------------------------------------
    user_renames = {
        "premium_tokens_limit": "premium_credit_micros_limit",
        "premium_tokens_used": "premium_credit_micros_used",
        "premium_tokens_reserved": "premium_credit_micros_reserved",
    }

    if bind.in_transaction():
        scope = bind.begin_nested()
    else:
        scope = bind.begin()

    with scope:
        bind.execute(sa.text("SET lock_timeout = '10s'"))

        _terminate_blocked_pids(bind, "user")
        bind.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))

        # Re-assert REPLICA IDENTITY DEFAULT for safety; column-list
        # publications require at least the PK to be in the column list,
        # which holds for both the old and the new shape.
        bind.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))

        # The publication must go BEFORE the renames, otherwise Postgres
        # rejects them: "cannot drop column ... referenced by publication".
        bind.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))

        for legacy_name, credit_name in user_renames.items():
            if not _column_exists(bind, "user", legacy_name):
                continue
            if _column_exists(bind, "user", credit_name):
                continue
            op.alter_column("user", legacy_name, new_column_name=credit_name)

        # New users get $5 of credit (== 5_000_000 micros) by default.
        # Existing rows are unaffected.
        op.alter_column(
            "user",
            "premium_credit_micros_limit",
            server_default="5000000",
        )

        # Recreate the publication with the new column names.
        documents_flag = _has_zero_version(bind, "documents")
        user_flag = _has_zero_version(bind, "user")
        create_ddl = _build_publication_ddl(
            USER_COLS,
            documents_has_zero_ver=documents_flag,
            user_has_zero_ver=user_flag,
        )
        bind.execute(sa.text(create_ddl))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Undo the column renames and remove ``cost_micros``.

    Reverse of ``upgrade``: under an ACCESS EXCLUSIVE lock on ``"user"`` the
    publication is dropped, the three credit columns get their legacy names
    back, the legacy limit column's server default is set, and the
    publication is recreated with the legacy column list. Afterwards the
    purchases column is renamed back and ``token_usage.cost_micros`` is
    dropped. The same zero-cache stop/reset runbook applies in reverse.
    """
    bind = op.get_bind()

    legacy_names = {
        "premium_credit_micros_limit": "premium_tokens_limit",
        "premium_credit_micros_used": "premium_tokens_used",
        "premium_credit_micros_reserved": "premium_tokens_reserved",
    }

    if bind.in_transaction():
        scope = bind.begin_nested()
    else:
        scope = bind.begin()

    with scope:
        bind.execute(sa.text("SET lock_timeout = '10s'"))

        _terminate_blocked_pids(bind, "user")
        bind.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))

        bind.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))

        for credit_name, legacy_name in legacy_names.items():
            if not _column_exists(bind, "user", credit_name):
                continue
            if _column_exists(bind, "user", legacy_name):
                continue
            op.alter_column("user", credit_name, new_column_name=legacy_name)

        op.alter_column(
            "user",
            "premium_tokens_limit",
            server_default="5000000",
        )

        # Column list the publication carried before the rename.
        legacy_user_cols = [
            "id",
            "pages_limit",
            "pages_used",
            "premium_tokens_limit",
            "premium_tokens_used",
        ]
        documents_flag = _has_zero_version(bind, "documents")
        user_flag = _has_zero_version(bind, "user")
        create_ddl = _build_publication_ddl(
            legacy_user_cols,
            documents_has_zero_ver=documents_flag,
            user_has_zero_ver=user_flag,
        )
        bind.execute(sa.text(create_ddl))

    purchases = "premium_token_purchases"
    if _column_exists(bind, purchases, "credit_micros_granted") and not _column_exists(
        bind, purchases, "tokens_granted"
    ):
        op.alter_column(
            purchases,
            "credit_micros_granted",
            new_column_name="tokens_granted",
        )

    if _column_exists(bind, "token_usage", "cost_micros"):
        op.drop_column("token_usage", "cost_micros")
|
||||||
|
|
@ -1,11 +0,0 @@
|
||||||
"""Agent-based vision autocomplete with scoped filesystem exploration."""
|
|
||||||
|
|
||||||
from app.agents.autocomplete.autocomplete_agent import (
|
|
||||||
create_autocomplete_agent,
|
|
||||||
stream_autocomplete_agent,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"create_autocomplete_agent",
|
|
||||||
"stream_autocomplete_agent",
|
|
||||||
]
|
|
||||||
|
|
@ -28,13 +28,76 @@ from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.agents.new_chat.document_xml import build_document_xml
|
||||||
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
|
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
|
||||||
from app.agents.new_chat.middleware.knowledge_search import (
|
from app.agents.new_chat.middleware.knowledge_search import (
|
||||||
build_scoped_filesystem,
|
|
||||||
search_knowledge_base,
|
search_knowledge_base,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.path_resolver import (
|
||||||
|
DOCUMENTS_ROOT,
|
||||||
|
build_path_index,
|
||||||
|
doc_to_virtual_path,
|
||||||
|
)
|
||||||
|
from app.db import shielded_async_session
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
|
|
||||||
|
# Prefer the deepagents helper; if importing it fails for any reason, fall
# back to a minimal local implementation with the same name.
try:
    from deepagents.backends.utils import create_file_data
except Exception:  # pragma: no cover - defensive

    def create_file_data(content: str) -> dict[str, Any]:
        """Fallback: wrap *content* as ``{"content": <list of lines>}``.

        Lines are produced by splitting on ``"\\n"``.
        NOTE(review): the deepagents helper may return additional keys -
        confirm that consumers only rely on ``content``.
        """
        return {"content": content.split("\n")}
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_autocomplete_filesystem(
|
||||||
|
*,
|
||||||
|
documents: Any,
|
||||||
|
search_space_id: int,
|
||||||
|
) -> tuple[dict[str, Any], dict[int, str]]:
|
||||||
|
"""Build a ``state['files']``-shaped dict from KB search results.
|
||||||
|
|
||||||
|
This is the autocomplete-specific replacement for the previous
|
||||||
|
``build_scoped_filesystem`` helper. It uses the canonical path resolver
|
||||||
|
so paths line up with the rest of the system, including collision
|
||||||
|
suffixes for duplicate titles.
|
||||||
|
"""
|
||||||
|
files: dict[str, Any] = {}
|
||||||
|
doc_id_to_path: dict[int, str] = {}
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
return files, doc_id_to_path
|
||||||
|
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
index = await build_path_index(session, search_space_id)
|
||||||
|
|
||||||
|
for document in documents:
|
||||||
|
if not isinstance(document, dict):
|
||||||
|
continue
|
||||||
|
meta = document.get("document") or {}
|
||||||
|
doc_id = meta.get("id")
|
||||||
|
if not isinstance(doc_id, int):
|
||||||
|
continue
|
||||||
|
title = str(meta.get("title") or "untitled")
|
||||||
|
folder_id = meta.get("folder_id")
|
||||||
|
path = doc_to_virtual_path(
|
||||||
|
doc_id=doc_id, title=title, folder_id=folder_id, index=index
|
||||||
|
)
|
||||||
|
chunk_ids = document.get("matched_chunk_ids") or []
|
||||||
|
try:
|
||||||
|
matched_set = {int(c) for c in chunk_ids}
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
matched_set = set()
|
||||||
|
xml = build_document_xml(document, matched_chunk_ids=matched_set)
|
||||||
|
files[path] = create_file_data(xml)
|
||||||
|
doc_id_to_path[doc_id] = path
|
||||||
|
|
||||||
|
if not files:
|
||||||
|
# Ensure the synthetic /documents folder is visible even when empty.
|
||||||
|
files.setdefault(f"{DOCUMENTS_ROOT}/.placeholder", create_file_data(""))
|
||||||
|
|
||||||
|
return files, doc_id_to_path
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
KB_TOP_K = 10
|
KB_TOP_K = 10
|
||||||
|
|
@ -174,7 +237,7 @@ async def precompute_kb_filesystem(
|
||||||
if not search_results:
|
if not search_results:
|
||||||
return _KBResult()
|
return _KBResult()
|
||||||
|
|
||||||
new_files, _ = await build_scoped_filesystem(
|
new_files, _ = await _build_autocomplete_filesystem(
|
||||||
documents=search_results,
|
documents=search_results,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
)
|
)
|
||||||
|
|
@ -215,13 +278,12 @@ async def precompute_kb_filesystem(
|
||||||
class AutocompleteFilesystemMiddleware(SurfSenseFilesystemMiddleware):
|
class AutocompleteFilesystemMiddleware(SurfSenseFilesystemMiddleware):
|
||||||
"""Filesystem middleware for autocomplete — read-only exploration only.
|
"""Filesystem middleware for autocomplete — read-only exploration only.
|
||||||
|
|
||||||
Strips ``save_document`` (permanent KB persistence) and passes
|
Passes ``search_space_id=None`` so the new persistence pipeline is
|
||||||
``search_space_id=None`` so ``write_file`` / ``edit_file`` stay ephemeral.
|
bypassed; the autocomplete flow only reads, never commits to Postgres.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__(search_space_id=None, created_by_id=None)
|
super().__init__(search_space_id=None, created_by_id=None)
|
||||||
self.tools = [t for t in self.tools if t.name != "save_document"]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,9 @@ We use ``create_agent`` (from langchain) rather than ``create_deep_agent``
|
||||||
This lets us swap in ``SurfSenseFilesystemMiddleware`` — a customisable
|
This lets us swap in ``SurfSenseFilesystemMiddleware`` — a customisable
|
||||||
subclass of the default ``FilesystemMiddleware`` — while preserving every
|
subclass of the default ``FilesystemMiddleware`` — while preserving every
|
||||||
other behaviour that ``create_deep_agent`` provides (todo-list, subagents,
|
other behaviour that ``create_deep_agent`` provides (todo-list, subagents,
|
||||||
summarisation, prompt-caching, etc.).
|
summarisation, etc.). Prompt caching is configured at LLM-build time via
|
||||||
|
``apply_litellm_prompt_caching`` (LiteLLM-native, multi-provider) rather
|
||||||
|
than as a middleware.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
@ -23,37 +25,110 @@ from deepagents import SubAgent, SubAgentMiddleware, __version__ as deepagents_v
|
||||||
from deepagents.backends import StateBackend
|
from deepagents.backends import StateBackend
|
||||||
from deepagents.graph import BASE_AGENT_PROMPT
|
from deepagents.graph import BASE_AGENT_PROMPT
|
||||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||||
|
from deepagents.middleware.skills import SkillsMiddleware
|
||||||
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.agents.middleware import TodoListMiddleware
|
from langchain.agents.middleware import (
|
||||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
LLMToolSelectorMiddleware,
|
||||||
|
ModelCallLimitMiddleware,
|
||||||
|
ModelFallbackMiddleware,
|
||||||
|
TodoListMiddleware,
|
||||||
|
ToolCallLimitMiddleware,
|
||||||
|
)
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langgraph.types import Checkpointer
|
from langgraph.types import Checkpointer
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
||||||
|
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.new_chat.llm_config import AgentConfig
|
||||||
from app.agents.new_chat.middleware import (
|
from app.agents.new_chat.middleware import (
|
||||||
|
ActionLogMiddleware,
|
||||||
|
AnonymousDocumentMiddleware,
|
||||||
|
BusyMutexMiddleware,
|
||||||
|
ClearToolUsesEdit,
|
||||||
DedupHITLToolCallsMiddleware,
|
DedupHITLToolCallsMiddleware,
|
||||||
KnowledgeBaseSearchMiddleware,
|
DoomLoopMiddleware,
|
||||||
|
FileIntentMiddleware,
|
||||||
|
KnowledgeBasePersistenceMiddleware,
|
||||||
|
KnowledgePriorityMiddleware,
|
||||||
|
KnowledgeTreeMiddleware,
|
||||||
MemoryInjectionMiddleware,
|
MemoryInjectionMiddleware,
|
||||||
|
NoopInjectionMiddleware,
|
||||||
|
OtelSpanMiddleware,
|
||||||
|
PermissionMiddleware,
|
||||||
|
RetryAfterMiddleware,
|
||||||
|
SpillingContextEditingMiddleware,
|
||||||
|
SpillToBackendEdit,
|
||||||
SurfSenseFilesystemMiddleware,
|
SurfSenseFilesystemMiddleware,
|
||||||
|
ToolCallNameRepairMiddleware,
|
||||||
|
build_skills_backend_factory,
|
||||||
|
create_surfsense_compaction_middleware,
|
||||||
|
default_skills_sources,
|
||||||
)
|
)
|
||||||
from app.agents.new_chat.middleware.safe_summarization import (
|
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||||
create_safe_summarization_middleware,
|
from app.agents.new_chat.plugin_loader import (
|
||||||
|
PluginContext,
|
||||||
|
load_allowed_plugin_names_from_env,
|
||||||
|
load_plugin_middlewares,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
|
||||||
|
from app.agents.new_chat.subagents import build_specialized_subagents
|
||||||
from app.agents.new_chat.system_prompt import (
|
from app.agents.new_chat.system_prompt import (
|
||||||
build_configurable_system_prompt,
|
build_configurable_system_prompt,
|
||||||
build_surfsense_system_prompt,
|
build_surfsense_system_prompt,
|
||||||
)
|
)
|
||||||
from app.agents.new_chat.tools.registry import build_tools_async
|
from app.agents.new_chat.tools.invalid_tool import (
|
||||||
|
INVALID_TOOL_NAME,
|
||||||
|
invalid_tool,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.tools.registry import (
|
||||||
|
BUILTIN_TOOLS,
|
||||||
|
build_tools_async,
|
||||||
|
get_connector_gated_tools,
|
||||||
|
)
|
||||||
from app.db import ChatVisibility
|
from app.db import ChatVisibility
|
||||||
from app.services.connector_service import ConnectorService
|
from app.services.connector_service import ConnectorService
|
||||||
from app.utils.perf import get_perf_logger
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
_perf_log = get_perf_logger()
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_prompt_model_name(
|
||||||
|
agent_config: AgentConfig | None,
|
||||||
|
llm: BaseChatModel,
|
||||||
|
) -> str | None:
|
||||||
|
"""Resolve the model id to feed to provider-variant detection.
|
||||||
|
|
||||||
|
Preference order (matches the established idiom in
|
||||||
|
``llm_router_service.py`` — see ``params.get("base_model") or
|
||||||
|
params.get("model", "")`` usages there):
|
||||||
|
|
||||||
|
1. ``agent_config.litellm_params["base_model"]`` — required for Azure
|
||||||
|
deployments where ``model_name`` is the deployment slug, not the
|
||||||
|
underlying family. Without this, a deployment named e.g.
|
||||||
|
``"prod-chat-001"`` would silently miss every provider regex.
|
||||||
|
2. ``agent_config.model_name`` — the user's configured model id.
|
||||||
|
3. ``getattr(llm, "model", None)`` — fallback for direct callers that
|
||||||
|
don't supply an ``AgentConfig`` (currently a defensive path; all
|
||||||
|
production callers pass ``agent_config``).
|
||||||
|
|
||||||
|
Returns ``None`` when nothing is available; ``compose_system_prompt``
|
||||||
|
treats that as the ``"default"`` variant (no provider block emitted).
|
||||||
|
"""
|
||||||
|
if agent_config is not None:
|
||||||
|
params = agent_config.litellm_params or {}
|
||||||
|
base_model = params.get("base_model")
|
||||||
|
if isinstance(base_model, str) and base_model.strip():
|
||||||
|
return base_model
|
||||||
|
if agent_config.model_name:
|
||||||
|
return agent_config.model_name
|
||||||
|
return getattr(llm, "model", None)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Connector Type Mapping
|
# Connector Type Mapping
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -164,6 +239,7 @@ async def create_surfsense_deep_agent(
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
mentioned_document_ids: list[int] | None = None,
|
mentioned_document_ids: list[int] | None = None,
|
||||||
anon_session_id: str | None = None,
|
anon_session_id: str | None = None,
|
||||||
|
filesystem_selection: FilesystemSelection | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a SurfSense deep agent with configurable tools and prompts.
|
Create a SurfSense deep agent with configurable tools and prompts.
|
||||||
|
|
@ -239,6 +315,21 @@ async def create_surfsense_deep_agent(
|
||||||
"""
|
"""
|
||||||
_t_agent_total = time.perf_counter()
|
_t_agent_total = time.perf_counter()
|
||||||
|
|
||||||
|
# Layer thread-aware prompt caching onto the LLM. Idempotent with the
|
||||||
|
# build-time call in ``llm_config.py``; this run merely adds
|
||||||
|
# ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` for OpenAI-family
|
||||||
|
# configs now that ``thread_id`` is known. No-op when ``thread_id`` is
|
||||||
|
# None or the provider is non-OpenAI-family.
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
|
||||||
|
|
||||||
|
filesystem_selection = filesystem_selection or FilesystemSelection()
|
||||||
|
backend_resolver = build_backend_resolver(
|
||||||
|
filesystem_selection,
|
||||||
|
search_space_id=search_space_id
|
||||||
|
if filesystem_selection.mode == FilesystemMode.CLOUD
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
|
||||||
# Discover available connectors and document types for this search space
|
# Discover available connectors and document types for this search space
|
||||||
available_connectors: list[str] | None = None
|
available_connectors: list[str] | None = None
|
||||||
available_document_types: list[str] | None = None
|
available_document_types: list[str] | None = None
|
||||||
|
|
@ -287,107 +378,12 @@ async def create_surfsense_deep_agent(
|
||||||
"llm": llm,
|
"llm": llm,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Disable Notion action tools if no Notion connector is configured
|
|
||||||
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
|
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
|
||||||
has_notion_connector = (
|
modified_disabled_tools.extend(get_connector_gated_tools(available_connectors))
|
||||||
available_connectors is not None and "NOTION_CONNECTOR" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_notion_connector:
|
|
||||||
notion_tools = [
|
|
||||||
"create_notion_page",
|
|
||||||
"update_notion_page",
|
|
||||||
"delete_notion_page",
|
|
||||||
]
|
|
||||||
modified_disabled_tools.extend(notion_tools)
|
|
||||||
|
|
||||||
# Disable Linear action tools if no Linear connector is configured
|
# Remove direct KB search tool; KnowledgePriorityMiddleware now runs hybrid
|
||||||
has_linear_connector = (
|
# search per turn and surfaces hits as a <priority_documents> hint plus
|
||||||
available_connectors is not None and "LINEAR_CONNECTOR" in available_connectors
|
# `<chunk_index matched="true">` markers inside lazy-loaded XML.
|
||||||
)
|
|
||||||
if not has_linear_connector:
|
|
||||||
linear_tools = [
|
|
||||||
"create_linear_issue",
|
|
||||||
"update_linear_issue",
|
|
||||||
"delete_linear_issue",
|
|
||||||
]
|
|
||||||
modified_disabled_tools.extend(linear_tools)
|
|
||||||
|
|
||||||
# Disable Google Drive action tools if no Google Drive connector is configured
|
|
||||||
has_google_drive_connector = (
|
|
||||||
available_connectors is not None and "GOOGLE_DRIVE_FILE" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_google_drive_connector:
|
|
||||||
google_drive_tools = [
|
|
||||||
"create_google_drive_file",
|
|
||||||
"delete_google_drive_file",
|
|
||||||
]
|
|
||||||
modified_disabled_tools.extend(google_drive_tools)
|
|
||||||
|
|
||||||
has_dropbox_connector = (
|
|
||||||
available_connectors is not None and "DROPBOX_FILE" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_dropbox_connector:
|
|
||||||
modified_disabled_tools.extend(["create_dropbox_file", "delete_dropbox_file"])
|
|
||||||
|
|
||||||
has_onedrive_connector = (
|
|
||||||
available_connectors is not None and "ONEDRIVE_FILE" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_onedrive_connector:
|
|
||||||
modified_disabled_tools.extend(["create_onedrive_file", "delete_onedrive_file"])
|
|
||||||
|
|
||||||
# Disable Google Calendar action tools if no Google Calendar connector is configured
|
|
||||||
has_google_calendar_connector = (
|
|
||||||
available_connectors is not None
|
|
||||||
and "GOOGLE_CALENDAR_CONNECTOR" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_google_calendar_connector:
|
|
||||||
calendar_tools = [
|
|
||||||
"create_calendar_event",
|
|
||||||
"update_calendar_event",
|
|
||||||
"delete_calendar_event",
|
|
||||||
]
|
|
||||||
modified_disabled_tools.extend(calendar_tools)
|
|
||||||
|
|
||||||
# Disable Gmail action tools if no Gmail connector is configured
|
|
||||||
has_gmail_connector = (
|
|
||||||
available_connectors is not None
|
|
||||||
and "GOOGLE_GMAIL_CONNECTOR" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_gmail_connector:
|
|
||||||
gmail_tools = [
|
|
||||||
"create_gmail_draft",
|
|
||||||
"update_gmail_draft",
|
|
||||||
"send_gmail_email",
|
|
||||||
"trash_gmail_email",
|
|
||||||
]
|
|
||||||
modified_disabled_tools.extend(gmail_tools)
|
|
||||||
|
|
||||||
# Disable Jira action tools if no Jira connector is configured
|
|
||||||
has_jira_connector = (
|
|
||||||
available_connectors is not None and "JIRA_CONNECTOR" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_jira_connector:
|
|
||||||
jira_tools = [
|
|
||||||
"create_jira_issue",
|
|
||||||
"update_jira_issue",
|
|
||||||
"delete_jira_issue",
|
|
||||||
]
|
|
||||||
modified_disabled_tools.extend(jira_tools)
|
|
||||||
|
|
||||||
# Disable Confluence action tools if no Confluence connector is configured
|
|
||||||
has_confluence_connector = (
|
|
||||||
available_connectors is not None
|
|
||||||
and "CONFLUENCE_CONNECTOR" in available_connectors
|
|
||||||
)
|
|
||||||
if not has_confluence_connector:
|
|
||||||
confluence_tools = [
|
|
||||||
"create_confluence_page",
|
|
||||||
"update_confluence_page",
|
|
||||||
"delete_confluence_page",
|
|
||||||
]
|
|
||||||
modified_disabled_tools.extend(confluence_tools)
|
|
||||||
|
|
||||||
# Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware.
|
|
||||||
if "search_knowledge_base" not in modified_disabled_tools:
|
if "search_knowledge_base" not in modified_disabled_tools:
|
||||||
modified_disabled_tools.append("search_knowledge_base")
|
modified_disabled_tools.append("search_knowledge_base")
|
||||||
|
|
||||||
|
|
@ -399,6 +395,18 @@ async def create_surfsense_deep_agent(
|
||||||
disabled_tools=modified_disabled_tools,
|
disabled_tools=modified_disabled_tools,
|
||||||
additional_tools=list(additional_tools) if additional_tools else None,
|
additional_tools=list(additional_tools) if additional_tools else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Register the ``invalid`` tool only when tool-call repair is on. It
|
||||||
|
# is dispatched only when :class:`ToolCallNameRepairMiddleware`
|
||||||
|
# rewrites a malformed call. We intentionally append it AFTER
|
||||||
|
# ``build_tools_async`` so it never appears in the system-prompt
|
||||||
|
# tool list (which is built from the registry, not the bound tool
|
||||||
|
# list).
|
||||||
|
_flags: AgentFeatureFlags = get_flags()
|
||||||
|
if _flags.enable_tool_call_repair and INVALID_TOOL_NAME not in {
|
||||||
|
t.name for t in tools
|
||||||
|
}:
|
||||||
|
tools = [*list(tools), invalid_tool]
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[create_agent] build_tools_async in %.3fs (%d tools)",
|
"[create_agent] build_tools_async in %.3fs (%d tools)",
|
||||||
time.perf_counter() - _t0,
|
time.perf_counter() - _t0,
|
||||||
|
|
@ -409,6 +417,21 @@ async def create_surfsense_deep_agent(
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
_enabled_tool_names = {t.name for t in tools}
|
_enabled_tool_names = {t.name for t in tools}
|
||||||
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
|
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
|
||||||
|
|
||||||
|
# Collect generic MCP connector info so the system prompt can route queries
|
||||||
|
# to their tools instead of falling back to "not in knowledge base".
|
||||||
|
_mcp_connector_tools: dict[str, list[str]] = {}
|
||||||
|
for t in tools:
|
||||||
|
meta = getattr(t, "metadata", None) or {}
|
||||||
|
if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"):
|
||||||
|
_mcp_connector_tools.setdefault(
|
||||||
|
meta["mcp_connector_name"],
|
||||||
|
[],
|
||||||
|
).append(t.name)
|
||||||
|
|
||||||
|
if _mcp_connector_tools:
|
||||||
|
_perf_log.info("MCP connector tool routing: %s", _mcp_connector_tools)
|
||||||
|
|
||||||
if agent_config is not None:
|
if agent_config is not None:
|
||||||
system_prompt = build_configurable_system_prompt(
|
system_prompt = build_configurable_system_prompt(
|
||||||
custom_system_instructions=agent_config.system_instructions,
|
custom_system_instructions=agent_config.system_instructions,
|
||||||
|
|
@ -417,18 +440,154 @@ async def create_surfsense_deep_agent(
|
||||||
thread_visibility=thread_visibility,
|
thread_visibility=thread_visibility,
|
||||||
enabled_tool_names=_enabled_tool_names,
|
enabled_tool_names=_enabled_tool_names,
|
||||||
disabled_tool_names=_user_disabled_tool_names,
|
disabled_tool_names=_user_disabled_tool_names,
|
||||||
|
mcp_connector_tools=_mcp_connector_tools,
|
||||||
|
model_name=_resolve_prompt_model_name(agent_config, llm),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
system_prompt = build_surfsense_system_prompt(
|
system_prompt = build_surfsense_system_prompt(
|
||||||
thread_visibility=thread_visibility,
|
thread_visibility=thread_visibility,
|
||||||
enabled_tool_names=_enabled_tool_names,
|
enabled_tool_names=_enabled_tool_names,
|
||||||
disabled_tool_names=_user_disabled_tool_names,
|
disabled_tool_names=_user_disabled_tool_names,
|
||||||
|
mcp_connector_tools=_mcp_connector_tools,
|
||||||
|
model_name=_resolve_prompt_model_name(agent_config, llm),
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
||||||
)
|
)
|
||||||
|
|
||||||
# -- Build the middleware stack (mirrors create_deep_agent internals) ------
|
# Combine system_prompt with BASE_AGENT_PROMPT (same as create_deep_agent)
|
||||||
|
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
||||||
|
|
||||||
|
# The middleware stack — and especially ``SubAgentMiddleware`` — is *not*
|
||||||
|
# cheap to build. ``SubAgentMiddleware.__init__`` calls ``create_agent``
|
||||||
|
# synchronously to compile the general-purpose subagent's full state graph
|
||||||
|
# (every tool + every middleware → pydantic schemas + langgraph compile).
|
||||||
|
# On gpt-5.x agents that's roughly 1.5-2s of pure CPU work. If we run it
|
||||||
|
# directly here it blocks the asyncio event loop for the whole streaming
|
||||||
|
# task (and any other coroutine sharing this loop), which is why
|
||||||
|
# "agent creation" wall-clock time used to stretch to ~3-4s. Move the
|
||||||
|
# entire middleware build + main-graph compile into a single
|
||||||
|
# ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the
|
||||||
|
# event loop stays responsive.
|
||||||
|
_t0 = time.perf_counter()
|
||||||
|
agent = await asyncio.to_thread(
|
||||||
|
_build_compiled_agent_blocking,
|
||||||
|
llm=llm,
|
||||||
|
tools=tools,
|
||||||
|
final_system_prompt=final_system_prompt,
|
||||||
|
backend_resolver=backend_resolver,
|
||||||
|
filesystem_mode=filesystem_selection.mode,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
visibility=visibility,
|
||||||
|
anon_session_id=anon_session_id,
|
||||||
|
available_connectors=available_connectors,
|
||||||
|
available_document_types=available_document_types,
|
||||||
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
|
max_input_tokens=_max_input_tokens,
|
||||||
|
flags=_flags,
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
)
|
||||||
|
_perf_log.info(
|
||||||
|
"[create_agent] Middleware stack + graph compiled in %.3fs",
|
||||||
|
time.perf_counter() - _t0,
|
||||||
|
)
|
||||||
|
|
||||||
|
_perf_log.info(
|
||||||
|
"[create_agent] Total agent creation in %.3fs",
|
||||||
|
time.perf_counter() - _t_agent_total,
|
||||||
|
)
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
# Tools whose output is too costly / lossy to discard. Keep this
|
||||||
|
# conservative — anything listed here is *never* pruned by
|
||||||
|
# :class:`ContextEditingMiddleware`. The list is filtered against
|
||||||
|
# actually-bound tool names so disabled connectors don't show up here.
|
||||||
|
_PRUNE_PROTECTED_TOOL_NAMES: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"generate_report",
|
||||||
|
"generate_resume",
|
||||||
|
"generate_podcast",
|
||||||
|
"generate_video_presentation",
|
||||||
|
"generate_image",
|
||||||
|
# Read-heavy connector reads — recomputing them is expensive
|
||||||
|
"read_email",
|
||||||
|
"search_emails",
|
||||||
|
# The fallback for malformed tool calls — keep its replies visible
|
||||||
|
"invalid",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_exclude_tools(tools: Sequence[BaseTool]) -> tuple[str, ...]:
    """Return the ``exclude_tools`` tuple for the actually-bound tool list.

    Intersects :data:`_PRUNE_PROTECTED_TOOL_NAMES` with the names of the
    bound ``tools`` so the result never lists a tool that does not exist
    (listing one would be a silent no-op).

    Args:
        tools: The tools bound to the agent; only each tool's ``name`` is read.

    Returns:
        The protected tool names that are currently bound, in sorted order.
        Sorting keeps the result stable across processes: iterating the
        underlying ``frozenset`` directly yields an order that depends on
        string hash randomization.
    """
    bound_names = {tool.name for tool in tools}
    return tuple(sorted(_PRUNE_PROTECTED_TOOL_NAMES & bound_names))
|
||||||
|
|
||||||
|
|
||||||
|
# Connector gating. Every bound tool whose
# ``ToolDefinition.required_connector`` is not actually wired up receives a
# synthesized permission deny rule, so an execution attempt short-circuits
# with ``permission_denied`` rather than surfacing a provider-specific
# 401/404 error. This mirrors OpenCode's ``Permission.disabled``
# (declarative, per-tool gating) and replaces the legacy binary
# ``_CONNECTOR_TYPE_TO_SEARCHABLE`` substring heuristic.
def _synthesize_connector_deny_rules(
    *,
    available_connectors: list[str] | None,
    enabled_tool_names: set[str],
) -> list[Rule]:
    """Create deny rules for bound tools whose required connector is missing.

    ``ToolDefinition.required_connector`` on the entries of
    :data:`BUILTIN_TOOLS` is the source of truth. A deny rule is emitted for
    a tool only if all of the following hold:

    1. The tool's name is in ``enabled_tool_names`` (it is currently bound).
    2. The tool declares a truthy ``required_connector``.
    3. That connector is *not* present in ``available_connectors``.

    Args:
        available_connectors: Connectors available to the agent; ``None`` is
            treated the same as an empty list.
        enabled_tool_names: Names of the tools currently bound to the agent.

    Returns:
        One ``deny`` rule (pattern ``"*"``) per gated tool, in
        :data:`BUILTIN_TOOLS` order.
    """
    connected = frozenset(available_connectors or ())
    return [
        Rule(permission=definition.name, pattern="*", action="deny")
        for definition in BUILTIN_TOOLS
        if definition.name in enabled_tool_names
        and (needed := definition.required_connector)
        and needed not in connected
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_compiled_agent_blocking(
|
||||||
|
*,
|
||||||
|
llm: BaseChatModel,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
final_system_prompt: str,
|
||||||
|
backend_resolver: Any,
|
||||||
|
filesystem_mode: FilesystemMode,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str | None,
|
||||||
|
thread_id: int | None,
|
||||||
|
visibility: ChatVisibility,
|
||||||
|
anon_session_id: str | None,
|
||||||
|
available_connectors: list[str] | None,
|
||||||
|
available_document_types: list[str] | None,
|
||||||
|
mentioned_document_ids: list[int] | None,
|
||||||
|
max_input_tokens: int | None,
|
||||||
|
flags: AgentFeatureFlags,
|
||||||
|
checkpointer: Checkpointer,
|
||||||
|
):
|
||||||
|
"""Build the middleware stack and compile the agent graph synchronously.
|
||||||
|
|
||||||
|
Runs in a worker thread (see ``asyncio.to_thread`` call site) so the heavy
|
||||||
|
CPU work — most notably ``SubAgentMiddleware.__init__`` eagerly calling
|
||||||
|
``create_agent`` to compile the general-purpose subagent — does not block
|
||||||
|
the event loop.
|
||||||
|
"""
|
||||||
_memory_middleware = MemoryInjectionMiddleware(
|
_memory_middleware = MemoryInjectionMiddleware(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
|
|
@ -436,17 +595,24 @@ async def create_surfsense_deep_agent(
|
||||||
)
|
)
|
||||||
|
|
||||||
# General-purpose subagent middleware
|
# General-purpose subagent middleware
|
||||||
|
# Subagent omits AnonymousDocumentMiddleware, KnowledgeTreeMiddleware,
|
||||||
|
# KnowledgePriorityMiddleware, and KnowledgeBasePersistenceMiddleware - it
|
||||||
|
# inherits state and tools from the parent, but should not (a) re-load
|
||||||
|
# anon docs / re-render the tree / re-run hybrid search, or (b) commit at
|
||||||
|
# its own completion (only the top-level agent's aafter_agent commits).
|
||||||
gp_middleware = [
|
gp_middleware = [
|
||||||
TodoListMiddleware(),
|
TodoListMiddleware(),
|
||||||
_memory_middleware,
|
_memory_middleware,
|
||||||
|
FileIntentMiddleware(llm=llm),
|
||||||
SurfSenseFilesystemMiddleware(
|
SurfSenseFilesystemMiddleware(
|
||||||
|
backend=backend_resolver,
|
||||||
|
filesystem_mode=filesystem_mode,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
created_by_id=user_id,
|
created_by_id=user_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
),
|
),
|
||||||
create_safe_summarization_middleware(llm, StateBackend),
|
create_surfsense_compaction_middleware(llm, StateBackend),
|
||||||
PatchToolCallsMiddleware(),
|
PatchToolCallsMiddleware(),
|
||||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
|
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
|
||||||
|
|
@ -456,44 +622,452 @@ async def create_surfsense_deep_agent(
|
||||||
"middleware": gp_middleware,
|
"middleware": gp_middleware,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Specialized user-facing subagents (explore, report_writer,
|
||||||
|
# connector_negotiator). Registered through SubAgentMiddleware alongside
|
||||||
|
# the general-purpose spec so the parent's `task` tool can address them
|
||||||
|
# by name. Off by default until the flag flips so existing deployments
|
||||||
|
# don't see new agent types in the task tool description.
|
||||||
|
specialized_subagents: list[SubAgent] = []
|
||||||
|
if flags.enable_specialized_subagents and not flags.disable_new_agent_stack:
|
||||||
|
try:
|
||||||
|
# Specialized subagents share the parent's filesystem +
|
||||||
|
# todo view so their system prompts (which promise
|
||||||
|
# ``read_file``, ``ls``, ``grep``, ``glob``, ``write_todos``)
|
||||||
|
# actually match runtime behavior. Build *fresh* instances
|
||||||
|
# rather than aliasing the parent's GP middleware to avoid
|
||||||
|
# subtle state coupling across compiled graphs.
|
||||||
|
subagent_extra_middleware: list = [
|
||||||
|
TodoListMiddleware(),
|
||||||
|
SurfSenseFilesystemMiddleware(
|
||||||
|
backend=backend_resolver,
|
||||||
|
filesystem_mode=filesystem_mode,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
specialized_subagents = build_specialized_subagents(
|
||||||
|
tools=tools,
|
||||||
|
model=llm,
|
||||||
|
extra_middleware=subagent_extra_middleware,
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive
|
||||||
|
logging.warning(
|
||||||
|
"Specialized subagent build failed; running without them: %s",
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
specialized_subagents = []
|
||||||
|
|
||||||
|
subagent_specs: list[SubAgent] = [general_purpose_spec, *specialized_subagents]
|
||||||
|
|
||||||
# Main agent middleware
|
# Main agent middleware
|
||||||
|
# Order: AnonDoc -> Tree -> Priority -> FileIntent -> Filesystem -> Persistence -> ...
|
||||||
|
# before_agent hooks run in declared order; later injections sit closer to
|
||||||
|
# the latest human turn. Tree (large + cacheable) is injected earliest so
|
||||||
|
# provider-side prefix caching has more material to hit; FileIntent (most
|
||||||
|
# actionable per-turn contract) is injected closest to the user message.
|
||||||
|
#
|
||||||
|
# ``wrap_model_call`` ordering: the FIRST middleware in the list is the
|
||||||
|
# OUTERMOST wrapper. To ensure prune executes before summarization,
|
||||||
|
# place ``SpillingContextEditingMiddleware`` before
|
||||||
|
# ``SurfSenseCompactionMiddleware``. Compaction is the canonical
|
||||||
|
# token-budget defense; the Bedrock buffer-empty defense is folded
|
||||||
|
# into ``SurfSenseCompactionMiddleware``.
|
||||||
|
summarization_mw = create_surfsense_compaction_middleware(llm, StateBackend)
|
||||||
|
_ = flags.enable_compaction_v2 # historical flag; retained for telemetry parity
|
||||||
|
|
||||||
|
# ContextEditing prune. Trigger at 55% of ``max_input_tokens``,
|
||||||
|
# earlier than summarization (~85%). When disabled, no edit runs.
|
||||||
|
context_edit_mw = None
|
||||||
|
if (
|
||||||
|
flags.enable_context_editing
|
||||||
|
and not flags.disable_new_agent_stack
|
||||||
|
and max_input_tokens
|
||||||
|
):
|
||||||
|
spill_edit = SpillToBackendEdit(
|
||||||
|
trigger=int(max_input_tokens * 0.55),
|
||||||
|
clear_at_least=int(max_input_tokens * 0.15),
|
||||||
|
keep=5,
|
||||||
|
exclude_tools=_safe_exclude_tools(tools),
|
||||||
|
clear_tool_inputs=True,
|
||||||
|
)
|
||||||
|
clear_edit = ClearToolUsesEdit(
|
||||||
|
trigger=int(max_input_tokens * 0.55),
|
||||||
|
clear_at_least=int(max_input_tokens * 0.15),
|
||||||
|
keep=5,
|
||||||
|
exclude_tools=_safe_exclude_tools(tools),
|
||||||
|
clear_tool_inputs=True,
|
||||||
|
placeholder="[cleared - older tool output trimmed for context]",
|
||||||
|
)
|
||||||
|
context_edit_mw = SpillingContextEditingMiddleware(
|
||||||
|
edits=[spill_edit, clear_edit],
|
||||||
|
backend_resolver=backend_resolver,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resilience knobs: header-aware retry, model fallback, and
|
||||||
|
# per-thread / per-run call-count limits. The fallback / limit
|
||||||
|
# middlewares are vanilla LangChain primitives; ``RetryAfter`` is
|
||||||
|
# SurfSense's header-aware variant (see its module docstring).
|
||||||
|
retry_mw = (
|
||||||
|
RetryAfterMiddleware(max_retries=3)
|
||||||
|
if flags.enable_retry_after and not flags.disable_new_agent_stack
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
# Fallback chain — primary is the agent's own model; we add cheap
|
||||||
|
# alternatives. Off by default; only the first call site that
|
||||||
|
# configures the chain via env should enable it.
|
||||||
|
fallback_mw: ModelFallbackMiddleware | None = None
|
||||||
|
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
||||||
|
try:
|
||||||
|
fallback_mw = ModelFallbackMiddleware(
|
||||||
|
"openai:gpt-4o-mini",
|
||||||
|
"anthropic:claude-3-5-haiku-20241022",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.warning("ModelFallbackMiddleware init failed; skipping.")
|
||||||
|
fallback_mw = None
|
||||||
|
model_call_limit_mw = (
|
||||||
|
ModelCallLimitMiddleware(
|
||||||
|
thread_limit=120,
|
||||||
|
run_limit=80,
|
||||||
|
exit_behavior="end",
|
||||||
|
)
|
||||||
|
if flags.enable_model_call_limit and not flags.disable_new_agent_stack
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
tool_call_limit_mw = (
|
||||||
|
ToolCallLimitMiddleware(
|
||||||
|
thread_limit=300, run_limit=80, exit_behavior="continue"
|
||||||
|
)
|
||||||
|
if flags.enable_tool_call_limit and not flags.disable_new_agent_stack
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Provider-compat ``_noop`` injection (mirrors OpenCode's
|
||||||
|
# ``llm.ts`` workaround for providers that reject empty assistant
|
||||||
|
# turns or alternating-role constraints).
|
||||||
|
noop_mw = (
|
||||||
|
NoopInjectionMiddleware()
|
||||||
|
if flags.enable_compaction_v2 and not flags.disable_new_agent_stack
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tool-call name repair (lowercase + ``invalid`` fallback).
|
||||||
|
#
|
||||||
|
# ``registered_tool_names`` MUST cover every tool the model can legitimately
|
||||||
|
# call. That includes the bound ``tools`` list AND every tool provided by
|
||||||
|
# middleware in the stack — ``FilesystemMiddleware`` (read_file, ls, grep,
|
||||||
|
# glob, edit_file, write_file, execute), ``TodoListMiddleware``
|
||||||
|
# (write_todos), ``SubAgentMiddleware`` (task), ``SkillsMiddleware`` (skill
|
||||||
|
# loaders), etc. If we only inspect ``tools`` here, every call to
|
||||||
|
# ``read_file`` / ``ls`` / ``grep`` from the model will be rewritten to
|
||||||
|
# ``invalid`` because the repair middleware doesn't recognize them. The
|
||||||
|
# built-in deepagents middleware aren't in scope yet at this point of the
|
||||||
|
# function but they're added unconditionally below, so we hard-code their
|
||||||
|
# canonical names alongside the dynamic ``tools`` set.
|
||||||
|
repair_mw = None
|
||||||
|
if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
|
||||||
|
registered_names: set[str] = {t.name for t in tools}
|
||||||
|
# Tools owned by the standard deepagents middleware stack and the
|
||||||
|
# SurfSense filesystem extension.
|
||||||
|
registered_names |= {
|
||||||
|
"write_todos",
|
||||||
|
"ls",
|
||||||
|
"read_file",
|
||||||
|
"write_file",
|
||||||
|
"edit_file",
|
||||||
|
"glob",
|
||||||
|
"grep",
|
||||||
|
"execute",
|
||||||
|
"task",
|
||||||
|
"mkdir",
|
||||||
|
"cd",
|
||||||
|
"pwd",
|
||||||
|
"move_file",
|
||||||
|
"rm",
|
||||||
|
"rmdir",
|
||||||
|
"list_tree",
|
||||||
|
"execute_code",
|
||||||
|
}
|
||||||
|
repair_mw = ToolCallNameRepairMiddleware(
|
||||||
|
registered_tool_names=registered_names,
|
||||||
|
# Disable fuzzy matching to avoid silent rewrites; the
|
||||||
|
# lowercase + ``invalid`` fallback alone covers >95% of
|
||||||
|
# observed model errors.
|
||||||
|
fuzzy_match_threshold=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Doom-loop detector. Off by default until the frontend handles
|
||||||
|
# ``permission == "doom_loop"`` interrupts.
|
||||||
|
doom_loop_mw = (
|
||||||
|
DoomLoopMiddleware(threshold=3)
|
||||||
|
if flags.enable_doom_loop and not flags.disable_new_agent_stack
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# PermissionMiddleware. Layers, earliest -> latest (last match wins,
|
||||||
|
# same evaluation order as OpenCode's ``permission/index.ts``):
|
||||||
|
#
|
||||||
|
# 1. ``surfsense_defaults`` — single ``allow */*`` rule. SurfSense
|
||||||
|
# already runs per-tool HITL (see ``tools/hitl.py``) for mutating
|
||||||
|
# connector tools, so we only want PermissionMiddleware to *deny*
|
||||||
|
# things the user has gated off; the default fallback in
|
||||||
|
# ``permissions.evaluate`` is ``ask``, which would double-prompt
|
||||||
|
# on every safe read-only call (``ls``, ``read_file``, ``grep``,
|
||||||
|
# ``glob``, ``web_search`` …) and, on resume, replay the previous
|
||||||
|
# reject decision into innocent calls.
|
||||||
|
# 2. ``desktop_safety`` — ``ask`` for destructive filesystem ops when
|
||||||
|
# the agent is operating against the user's real disk. Cloud mode
|
||||||
|
# has full revision-based revert via ``revert_service``, but
|
||||||
|
# desktop mode hits disk immediately with no undo, so an
|
||||||
|
# accidental ``rm`` / ``rmdir`` / ``move_file`` / ``edit_file`` /
|
||||||
|
# ``write_file`` is unrecoverable. This layer is forced on in
|
||||||
|
# desktop mode regardless of ``enable_permission`` because the
|
||||||
|
# safety net is non-negotiable.
|
||||||
|
# 3. ``connector_synthesized`` — deny rules for tools whose required
|
||||||
|
# connector is not connected to this space. Overrides #1/#2.
|
||||||
|
# 4. (future) user-defined rules from ``agent_permission_rules`` table
|
||||||
|
# via the Agent Permissions UI. Loaded last so they override all.
|
||||||
|
permission_mw: PermissionMiddleware | None = None
|
||||||
|
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
|
||||||
|
permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack
|
||||||
|
# Build the middleware whenever it has work to do: either the user
|
||||||
|
# opted into the rule engine, OR we're in desktop mode and need the
|
||||||
|
# safety rules unconditionally.
|
||||||
|
if permission_enabled or is_desktop_fs:
|
||||||
|
rulesets: list[Ruleset] = [
|
||||||
|
Ruleset(
|
||||||
|
rules=[Rule(permission="*", pattern="*", action="allow")],
|
||||||
|
origin="surfsense_defaults",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
if is_desktop_fs:
|
||||||
|
rulesets.append(
|
||||||
|
Ruleset(
|
||||||
|
rules=[
|
||||||
|
Rule(permission="rm", pattern="*", action="ask"),
|
||||||
|
Rule(permission="rmdir", pattern="*", action="ask"),
|
||||||
|
Rule(permission="move_file", pattern="*", action="ask"),
|
||||||
|
Rule(permission="edit_file", pattern="*", action="ask"),
|
||||||
|
Rule(permission="write_file", pattern="*", action="ask"),
|
||||||
|
],
|
||||||
|
origin="desktop_safety",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if permission_enabled:
|
||||||
|
synthesized = _synthesize_connector_deny_rules(
|
||||||
|
available_connectors=available_connectors,
|
||||||
|
enabled_tool_names={t.name for t in tools},
|
||||||
|
)
|
||||||
|
rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized"))
|
||||||
|
permission_mw = PermissionMiddleware(rulesets=rulesets)
|
||||||
|
|
||||||
|
# ActionLogMiddleware. Off by default until the ``agent_action_log``
|
||||||
|
# table is migrated. When enabled, persists one row per tool call
|
||||||
|
# with optional reverse_descriptor for
|
||||||
|
# ``POST /api/threads/{thread_id}/revert/{action_id}``. Sits inside
|
||||||
|
# ``permission`` so denied calls aren't logged as completions.
|
||||||
|
action_log_mw: ActionLogMiddleware | None = None
|
||||||
|
if (
|
||||||
|
flags.enable_action_log
|
||||||
|
and not flags.disable_new_agent_stack
|
||||||
|
and thread_id is not None
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
|
||||||
|
action_log_mw = ActionLogMiddleware(
|
||||||
|
thread_id=thread_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
tool_definitions=tool_defs_by_name,
|
||||||
|
)
|
||||||
|
except Exception: # pragma: no cover - defensive
|
||||||
|
logging.warning(
|
||||||
|
"ActionLogMiddleware init failed; running without it.",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
action_log_mw = None
|
||||||
|
|
||||||
|
# Per-thread busy mutex (refuse a second concurrent turn on the same
|
||||||
|
# thread; see :class:`BusyMutexMiddleware` docstring).
|
||||||
|
busy_mutex_mw: BusyMutexMiddleware | None = (
|
||||||
|
BusyMutexMiddleware()
|
||||||
|
if flags.enable_busy_mutex and not flags.disable_new_agent_stack
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# OpenTelemetry spans (model.call + tool.call). Lives just inside
|
||||||
|
# BusyMutex so it spans every retry/fallback attempt of the current
|
||||||
|
# turn but never wraps a queued/blocked turn.
|
||||||
|
otel_mw: OtelSpanMiddleware | None = (
|
||||||
|
OtelSpanMiddleware()
|
||||||
|
if flags.enable_otel and not flags.disable_new_agent_stack
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plugin entry-point loader. Off by default; opt-in via the
|
||||||
|
# ``SURFSENSE_ENABLE_PLUGIN_LOADER`` flag. The allowlist is read from
|
||||||
|
# the ``SURFSENSE_ALLOWED_PLUGINS`` env var (comma-separated). A future
|
||||||
|
# PR can wire it through ``global_llm_config.yaml``.
|
||||||
|
plugin_middlewares: list[Any] = []
|
||||||
|
if flags.enable_plugin_loader and not flags.disable_new_agent_stack:
|
||||||
|
try:
|
||||||
|
allowed_names = load_allowed_plugin_names_from_env()
|
||||||
|
if allowed_names:
|
||||||
|
plugin_middlewares = load_plugin_middlewares(
|
||||||
|
PluginContext.build(
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
thread_visibility=visibility,
|
||||||
|
llm=llm,
|
||||||
|
),
|
||||||
|
allowed_plugin_names=allowed_names,
|
||||||
|
)
|
||||||
|
except Exception: # pragma: no cover - defensive
|
||||||
|
logging.warning(
|
||||||
|
"Plugin loader failed; continuing without plugins.",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
plugin_middlewares = []
|
||||||
|
|
||||||
|
# SkillsMiddleware (deepagents) loads built-in + space-authored
|
||||||
|
# skills via a CompositeBackend. Sources are layered: built-in first,
|
||||||
|
# space last, so a search-space-authored skill of the same name
|
||||||
|
# overrides the bundled one.
|
||||||
|
skills_mw: SkillsMiddleware | None = None
|
||||||
|
if flags.enable_skills and not flags.disable_new_agent_stack:
|
||||||
|
try:
|
||||||
|
skills_factory = build_skills_backend_factory(
|
||||||
|
search_space_id=search_space_id
|
||||||
|
if filesystem_mode == FilesystemMode.CLOUD
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
skills_mw = SkillsMiddleware(
|
||||||
|
backend=skills_factory,
|
||||||
|
sources=default_skills_sources(),
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive
|
||||||
|
logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
|
||||||
|
skills_mw = None
|
||||||
|
|
||||||
|
# LangChain's LLM-driven tool selection — only enabled for stacks
|
||||||
|
# large enough to need narrowing (>30 tools).
|
||||||
|
selector_mw: LLMToolSelectorMiddleware | None = None
|
||||||
|
if (
|
||||||
|
flags.enable_llm_tool_selector
|
||||||
|
and not flags.disable_new_agent_stack
|
||||||
|
and len(tools) > 30
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
selector_mw = LLMToolSelectorMiddleware(
|
||||||
|
model="openai:gpt-4o-mini",
|
||||||
|
max_tools=12,
|
||||||
|
always_include=[
|
||||||
|
name
|
||||||
|
for name in (
|
||||||
|
"update_memory",
|
||||||
|
"get_connected_accounts",
|
||||||
|
"scrape_webpage",
|
||||||
|
)
|
||||||
|
if name in {t.name for t in tools}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
|
||||||
|
selector_mw = None
|
||||||
|
|
||||||
deepagent_middleware = [
|
deepagent_middleware = [
|
||||||
|
# BusyMutex is OUTERMOST: it must wrap the entire stream so no
|
||||||
|
# other turn can sneak in while this one is mid-flight.
|
||||||
|
busy_mutex_mw,
|
||||||
|
# OTel spans sit just inside BusyMutex so each retry attempt
|
||||||
|
# gets its own model.call / tool.call span.
|
||||||
|
otel_mw,
|
||||||
TodoListMiddleware(),
|
TodoListMiddleware(),
|
||||||
_memory_middleware,
|
_memory_middleware,
|
||||||
KnowledgeBaseSearchMiddleware(
|
AnonymousDocumentMiddleware(
|
||||||
|
anon_session_id=anon_session_id,
|
||||||
|
)
|
||||||
|
if filesystem_mode == FilesystemMode.CLOUD
|
||||||
|
else None,
|
||||||
|
KnowledgeTreeMiddleware(
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
filesystem_mode=filesystem_mode,
|
||||||
|
llm=llm,
|
||||||
|
)
|
||||||
|
if filesystem_mode == FilesystemMode.CLOUD
|
||||||
|
else None,
|
||||||
|
KnowledgePriorityMiddleware(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
|
filesystem_mode=filesystem_mode,
|
||||||
available_connectors=available_connectors,
|
available_connectors=available_connectors,
|
||||||
available_document_types=available_document_types,
|
available_document_types=available_document_types,
|
||||||
mentioned_document_ids=mentioned_document_ids,
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
anon_session_id=anon_session_id,
|
|
||||||
),
|
),
|
||||||
|
FileIntentMiddleware(llm=llm),
|
||||||
SurfSenseFilesystemMiddleware(
|
SurfSenseFilesystemMiddleware(
|
||||||
|
backend=backend_resolver,
|
||||||
|
filesystem_mode=filesystem_mode,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
created_by_id=user_id,
|
created_by_id=user_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
),
|
),
|
||||||
SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]),
|
KnowledgeBasePersistenceMiddleware(
|
||||||
create_safe_summarization_middleware(llm, StateBackend),
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
filesystem_mode=filesystem_mode,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
if filesystem_mode == FilesystemMode.CLOUD
|
||||||
|
else None,
|
||||||
|
# Skill loader. Placed before SubAgentMiddleware so subagents
|
||||||
|
# inherit the same skill metadata (subagent specs reference the
|
||||||
|
# same source paths via ``default_skills_sources()``).
|
||||||
|
skills_mw,
|
||||||
|
SubAgentMiddleware(backend=StateBackend, subagents=subagent_specs),
|
||||||
|
# Tool selection (only when >30 tools and flag on).
|
||||||
|
selector_mw,
|
||||||
|
# Defensive caps, then prune, then summarize.
|
||||||
|
model_call_limit_mw,
|
||||||
|
tool_call_limit_mw,
|
||||||
|
context_edit_mw,
|
||||||
|
summarization_mw,
|
||||||
|
# Provider compatibility + retry chain — placed after prune/compact
|
||||||
|
# so retries happen on the already-trimmed payload.
|
||||||
|
noop_mw,
|
||||||
|
retry_mw,
|
||||||
|
fallback_mw,
|
||||||
|
# Tool-call repair must run after model emits but before
|
||||||
|
# permission / dedup / doom-loop interpret the calls.
|
||||||
|
repair_mw,
|
||||||
|
# Permission deny/ask BEFORE the calls are forwarded to tool nodes.
|
||||||
|
permission_mw,
|
||||||
|
doom_loop_mw,
|
||||||
|
# Action log sits inside permission so denied calls don't appear
|
||||||
|
# as completions, and outside dedup so each unique tool invocation
|
||||||
|
# gets its own row.
|
||||||
|
action_log_mw,
|
||||||
PatchToolCallsMiddleware(),
|
PatchToolCallsMiddleware(),
|
||||||
DedupHITLToolCallsMiddleware(agent_tools=tools),
|
DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
|
||||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
# Plugin slot — sits at the tail so plugin-side transforms see the
|
||||||
|
# final tool result. Prompt caching is now applied at LLM build time
|
||||||
|
# via ``apply_litellm_prompt_caching`` (see prompt_caching.py), so no
|
||||||
|
# caching middleware is needed here. Multiple plugins run in declared
|
||||||
|
# order; loader filtered by the admin allowlist already.
|
||||||
|
*plugin_middlewares,
|
||||||
]
|
]
|
||||||
|
deepagent_middleware = [m for m in deepagent_middleware if m is not None]
|
||||||
|
|
||||||
# Combine system_prompt with BASE_AGENT_PROMPT (same as create_deep_agent)
|
agent = create_agent(
|
||||||
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
|
||||||
agent = await asyncio.to_thread(
|
|
||||||
create_agent,
|
|
||||||
llm,
|
llm,
|
||||||
system_prompt=final_system_prompt,
|
system_prompt=final_system_prompt,
|
||||||
tools=tools,
|
tools=list(tools),
|
||||||
middleware=deepagent_middleware,
|
middleware=deepagent_middleware,
|
||||||
context_schema=SurfSenseContextSchema,
|
context_schema=SurfSenseContextSchema,
|
||||||
checkpointer=checkpointer,
|
checkpointer=checkpointer,
|
||||||
)
|
)
|
||||||
agent = agent.with_config(
|
return agent.with_config(
|
||||||
{
|
{
|
||||||
"recursion_limit": 10_000,
|
"recursion_limit": 10_000,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|
@ -502,13 +1076,3 @@ async def create_surfsense_deep_agent(
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
|
||||||
"[create_agent] Graph compiled (create_agent) in %.3fs",
|
|
||||||
time.perf_counter() - _t0,
|
|
||||||
)
|
|
||||||
|
|
||||||
_perf_log.info(
|
|
||||||
"[create_agent] Total agent creation in %.3fs",
|
|
||||||
time.perf_counter() - _t_agent_total,
|
|
||||||
)
|
|
||||||
return agent
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,15 @@ Context schema definitions for SurfSense agents.
|
||||||
This module defines the custom state schema used by the SurfSense deep agent.
|
This module defines the custom state schema used by the SurfSense deep agent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TypedDict
|
from typing import NotRequired, TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class FileOperationContractState(TypedDict):
    """Typed payload stored under ``file_operation_contract`` in the context schema.

    All five keys are required. The field meanings below are inferred from
    their names only -- NOTE(review): confirm against the producer of this
    payload.
    """

    # Presumably the classified file-operation intent label.
    intent: str
    # Presumably the classifier's confidence for ``intent``.
    confidence: float
    # Presumably the path proposed for the file operation.
    suggested_path: str
    # Timestamp carried as a string (format not visible from here).
    timestamp: str
    # Identifier of the turn this contract belongs to.
    turn_id: str
|
||||||
|
|
||||||
|
|
||||||
class SurfSenseContextSchema(TypedDict):
|
class SurfSenseContextSchema(TypedDict):
|
||||||
|
|
@ -24,5 +32,8 @@ class SurfSenseContextSchema(TypedDict):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
|
file_operation_contract: NotRequired[FileOperationContractState]
|
||||||
|
turn_id: NotRequired[str]
|
||||||
|
request_id: NotRequired[str]
|
||||||
# These are runtime-injected and won't be serialized
|
# These are runtime-injected and won't be serialized
|
||||||
# db_session and connector_service are passed when invoking the agent
|
# db_session and connector_service are passed when invoking the agent
|
||||||
|
|
|
||||||
103
surfsense_backend/app/agents/new_chat/document_xml.py
Normal file
103
surfsense_backend/app/agents/new_chat/document_xml.py
Normal file
|
|
@ -0,0 +1,103 @@
|
||||||
|
"""Shared XML builder for KB documents.
|
||||||
|
|
||||||
|
Produces the citation-friendly XML used by every read of a knowledge-base
|
||||||
|
document (lazy-loaded by :class:`KBPostgresBackend` and synthetic anonymous
|
||||||
|
files). The XML carries a ``<chunk_index>`` near the top so the LLM can jump
|
||||||
|
directly to matched-chunk line ranges via ``read_file(offset=…, limit=…)``.
|
||||||
|
|
||||||
|
Extracted from the original ``knowledge_search.py`` so the backend, the
|
||||||
|
priority middleware, and any future renderer share a single implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def build_document_xml(
    document: dict[str, Any],
    matched_chunk_ids: set[int] | None = None,
) -> str:
    """Build citation-friendly XML with a ``<chunk_index>`` for smart seeking.

    The output is laid out as: document metadata, a ``<chunk_index>`` whose
    entries give the 1-based ``lines="start-end"`` range of every chunk, and
    finally ``<document_content>`` holding the chunks themselves.

    Args:
        document: Dict shape produced by hybrid search / lazy-load helpers.
            Expected keys: ``document`` (with ``id``, ``title``,
            ``document_type``, ``metadata``) and ``chunks``
            (list of ``{chunk_id, content}``). A ``document`` or
            ``metadata`` value that is not a dict is treated as empty.
        matched_chunk_ids: Optional set of chunk IDs to flag as
            ``matched="true"`` in the chunk index.

    Returns:
        The XML text, lines joined with ``"\\n"``.
    """
    matched = matched_chunk_ids or set()

    # Normalise both metadata containers up front so every ``.get`` below is
    # safe; previously only ``metadata`` was guarded against a non-dict
    # ``document`` value and the following lookups raised AttributeError.
    doc_meta = document.get("document") or {}
    if not isinstance(doc_meta, dict):
        doc_meta = {}
    metadata = doc_meta.get("metadata") or {}
    if not isinstance(metadata, dict):
        metadata = {}

    document_id = doc_meta.get("id", document.get("document_id", "unknown"))
    document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN"))
    title = doc_meta.get("title") or metadata.get("title") or "Untitled Document"
    url = (
        metadata.get("url") or metadata.get("source") or metadata.get("page_url") or ""
    )
    metadata_json = json.dumps(metadata, ensure_ascii=False)

    # NOTE(review): values are embedded in CDATA verbatim; a literal "]]>" in
    # a title/URL/chunk would terminate the section early. Left unchanged.
    metadata_lines: list[str] = [
        "<document>",
        "<document_metadata>",
        f" <document_id>{document_id}</document_id>",
        f" <document_type>{document_type}</document_type>",
        f" <title><![CDATA[{title}]]></title>",
        f" <url><![CDATA[{url}]]></url>",
        f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
        "</document_metadata>",
        "",
    ]

    chunks = document.get("chunks") or []
    chunk_entries: list[tuple[int | None, str]] = []
    if isinstance(chunks, list):
        for chunk in chunks:
            if not isinstance(chunk, dict):
                continue
            chunk_id = chunk.get("chunk_id") or chunk.get("id")
            chunk_content = str(chunk.get("content", "")).strip()
            if not chunk_content:
                # Empty chunks are dropped from both the index and the content.
                continue
            if chunk_id is None:
                xml = f" <chunk><![CDATA[{chunk_content}]]></chunk>"
            else:
                xml = f" <chunk id='{chunk_id}'><![CDATA[{chunk_content}]]></chunk>"
            chunk_entries.append((chunk_id, xml))

    # Count *physical* lines in the header: a title, URL or id containing a
    # newline spans several lines in the joined output, and counting list
    # entries instead would shift every range in the chunk index.
    header_line_count = sum(entry.count("\n") + 1 for entry in metadata_lines)
    # <chunk_index> + one line per entry + </chunk_index> + blank line
    # + <document_content>
    index_overhead = 1 + len(chunk_entries) + 1 + 1 + 1
    first_chunk_line = header_line_count + index_overhead + 1

    current_line = first_chunk_line
    index_entry_lines: list[str] = []
    for cid, xml_str in chunk_entries:
        # Chunk content may itself be multi-line.
        num_lines = xml_str.count("\n") + 1
        end_line = current_line + num_lines - 1
        matched_attr = ' matched="true"' if cid is not None and cid in matched else ""
        if cid is not None:
            index_entry_lines.append(
                f' <entry chunk_id="{cid}" lines="{current_line}-{end_line}"{matched_attr}/>'
            )
        else:
            index_entry_lines.append(
                f' <entry lines="{current_line}-{end_line}"{matched_attr}/>'
            )
        current_line = end_line + 1

    lines = metadata_lines.copy()
    lines.append("<chunk_index>")
    lines.extend(index_entry_lines)
    lines.append("</chunk_index>")
    lines.append("")
    lines.append("<document_content>")
    for _, xml_str in chunk_entries:
        lines.append(xml_str)
    lines.extend(["</document_content>", "</document>"])
    return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["build_document_xml"]
|
||||||
95
surfsense_backend/app/agents/new_chat/errors.py
Normal file
95
surfsense_backend/app/agents/new_chat/errors.py
Normal file
|
|
@ -0,0 +1,95 @@
|
||||||
|
"""
|
||||||
|
Typed error taxonomy for the SurfSense agent stack.
|
||||||
|
|
||||||
|
Used by:
|
||||||
|
- :class:`RetryAfterMiddleware` — its ``retry_on`` callable consults
|
||||||
|
the error code to decide whether a retry is appropriate.
|
||||||
|
- :class:`PermissionMiddleware` — emits ``code="permission_denied"``
|
||||||
|
errors when a deny rule trips.
|
||||||
|
- All tools — return :class:`StreamingError` payloads in
|
||||||
|
``ToolMessage.additional_kwargs["error"]`` so the model and the
|
||||||
|
retry/permission layers share a contract.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
# Closed set of machine-readable error codes carried by ``StreamingError.code``.
# Order is preserved as declared; only the layout is grouped for readability.
ErrorCode = Literal[
    # Provider / credential side
    "rate_limit", "auth",
    # Tool side
    "tool_validation", "tool_runtime",
    # Model / provider limits
    "context_overflow", "provider",
    # Emitted by the permission layer when a deny rule trips
    "permission_denied",
    # Loop guard, concurrency, cancellation
    "doom_loop", "busy", "cancelled",
]
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingError(BaseModel):
    """Immutable, structured error attached to ``ToolMessage.additional_kwargs["error"]``.

    Emitted by tools and middleware so the retry, permission, and routing
    layers can branch on typed fields instead of parsing free-form strings.
    """

    class Config:
        # Instances are read-only after construction.
        frozen = True

    # Machine-readable category; one of the ``ErrorCode`` literals.
    code: ErrorCode
    # Whether a retry is considered appropriate for this error.
    retryable: bool = False
    # Optional hint text (audience not visible from here).
    suggestion: str | None = None
    # Optional id for correlating this error with other records.
    correlation_id: str | None = None
    detail: str | None = Field(
        default=None,
        description="Free-form additional context. Not surfaced to the model.",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class RejectedError(Exception):
    """Signals that the user declined a permission ask and gave no feedback.

    :class:`PermissionMiddleware` catches it; the agent then stops the
    current tool fan-out and surfaces a user-facing rejection.

    Attributes:
        tool: Name of the tool whose call was rejected, if known.
        pattern: The permission pattern involved, if known.
    """

    def __init__(self, *, tool: str | None = None, pattern: str | None = None) -> None:
        self.tool = tool
        self.pattern = pattern
        message = f"Permission rejected for tool {tool!r}, pattern {pattern!r}"
        super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class CorrectedError(Exception):
    """Signals that the user declined a permission ask *and* supplied feedback.

    :class:`PermissionMiddleware` turns the feedback into a synthetic
    ``ToolMessage`` so the model sees the correction and can retry the
    request differently.

    Attributes:
        feedback: The user's correction text (also the exception message).
        tool: Name of the tool whose call was rejected, if known.
    """

    def __init__(self, feedback: str, *, tool: str | None = None) -> None:
        self.tool = tool
        self.feedback = feedback
        super().__init__(feedback)
|
||||||
|
|
||||||
|
|
||||||
|
class BusyError(Exception):
    """Signals that a second prompt arrived while the same thread is mid-stream.

    Attributes:
        request_id: Optional request identifier supplied by the raiser.
    """

    def __init__(self, request_id: str | None = None) -> None:
        self.request_id = request_id
        super().__init__("Thread is busy with another request")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BusyError",
|
||||||
|
"CorrectedError",
|
||||||
|
"ErrorCode",
|
||||||
|
"RejectedError",
|
||||||
|
"StreamingError",
|
||||||
|
]
|
||||||
235
surfsense_backend/app/agents/new_chat/feature_flags.py
Normal file
235
surfsense_backend/app/agents/new_chat/feature_flags.py
Normal file
|
|
@ -0,0 +1,235 @@
|
||||||
|
"""
|
||||||
|
Feature flags for the SurfSense new_chat agent stack.
|
||||||
|
|
||||||
|
These flags gate the newer agent middleware (some ported from OpenCode,
|
||||||
|
some sourced from ``langchain.agents.middleware`` / ``deepagents``, some
|
||||||
|
SurfSense-native). Most shipped agent-stack upgrades default ON so Docker
|
||||||
|
image updates work even when older installs do not have newly introduced
|
||||||
|
environment variables. Risky/experimental integrations stay default OFF,
|
||||||
|
and the master kill-switch can still disable everything new.
|
||||||
|
|
||||||
|
All new middleware checks its flag at agent build time. If the master
|
||||||
|
kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
|
||||||
|
middleware is disabled regardless of its individual flag. This gives
|
||||||
|
operators a single switch to revert to pre-port behavior.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
|
||||||
|
Defaults:
|
||||||
|
|
||||||
|
SURFSENSE_ENABLE_CONTEXT_EDITING=true
|
||||||
|
SURFSENSE_ENABLE_COMPACTION_V2=true
|
||||||
|
SURFSENSE_ENABLE_RETRY_AFTER=true
|
||||||
|
SURFSENSE_ENABLE_MODEL_FALLBACK=false
|
||||||
|
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
|
||||||
|
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
|
||||||
|
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
|
||||||
|
SURFSENSE_ENABLE_PERMISSION=true
|
||||||
|
SURFSENSE_ENABLE_DOOM_LOOP=true
|
||||||
|
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
|
||||||
|
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
|
||||||
|
|
||||||
|
Master kill-switch (overrides everything else):
|
||||||
|
|
||||||
|
SURFSENSE_DISABLE_NEW_AGENT_STACK=true
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _env_bool(name: str, default: bool) -> bool:
|
||||||
|
"""Parse a boolean env var. Accepts ``1``/``true``/``yes``/``on`` (case-insensitive)."""
|
||||||
|
raw = os.environ.get(name)
|
||||||
|
if raw is None:
|
||||||
|
return default
|
||||||
|
return raw.strip().lower() in ("1", "true", "yes", "on")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class AgentFeatureFlags:
    """Immutable snapshot of the feature flags used for one agent build.

    Build instances with :meth:`from_env`. Because the dataclass is frozen
    it can be shared freely between coroutines.
    """

    # Master kill-switch: when ``from_env`` sees it set, every flag below is
    # forced to False regardless of its own env value (rapid rollback).
    disable_new_agent_stack: bool = False

    # --- Agent quality: context budget, retry/limits, name-repair, doom-loop
    enable_context_editing: bool = True
    enable_compaction_v2: bool = True
    enable_retry_after: bool = True
    enable_model_fallback: bool = False
    enable_model_call_limit: bool = True
    enable_tool_call_limit: bool = True
    enable_tool_call_repair: bool = True
    enable_doom_loop: bool = True

    # --- Safety: permissions, concurrency, tool-set narrowing
    enable_permission: bool = True
    enable_busy_mutex: bool = True
    enable_llm_tool_selector: bool = False  # Default OFF: adds per-turn LLM cost

    # --- Skills + subagents
    enable_skills: bool = True
    enable_specialized_subagents: bool = True
    enable_kb_planner_runnable: bool = True

    # --- Snapshot / revert
    enable_action_log: bool = True
    enable_revert_route: bool = True

    # --- Streaming parity v2: opt in to LangChain's structured
    # ``AIMessageChunk`` content (typed reasoning blocks, tool-input deltas)
    # and propagate the real ``tool_call_id`` to the SSE layer. When OFF the
    # ``stream_new_chat`` task falls back to the str-only text path and the
    # synthetic ``call_<run_id>`` tool-call id (no ``langchainToolCallId``
    # propagation). Schema migrations 135/136 ship unconditionally because
    # they're forward-compatible.
    enable_stream_parity_v2: bool = True

    # --- Plugins
    enable_plugin_loader: bool = False

    # --- Observability: OTel (orthogonal; also requires
    # OTEL_EXPORTER_OTLP_ENDPOINT)
    enable_otel: bool = False

    @classmethod
    def from_env(cls) -> AgentFeatureFlags:
        """Resolve every flag from the process environment.

        The master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is
        checked first; when it is set all other flags are forced to False
        without consulting their own variables.
        """
        # field name -> (environment variable, default when unset)
        toggles: dict[str, tuple[str, bool]] = {
            # Agent quality
            "enable_context_editing": ("SURFSENSE_ENABLE_CONTEXT_EDITING", True),
            "enable_compaction_v2": ("SURFSENSE_ENABLE_COMPACTION_V2", True),
            "enable_retry_after": ("SURFSENSE_ENABLE_RETRY_AFTER", True),
            "enable_model_fallback": ("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
            "enable_model_call_limit": ("SURFSENSE_ENABLE_MODEL_CALL_LIMIT", True),
            "enable_tool_call_limit": ("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", True),
            "enable_tool_call_repair": ("SURFSENSE_ENABLE_TOOL_CALL_REPAIR", True),
            "enable_doom_loop": ("SURFSENSE_ENABLE_DOOM_LOOP", True),
            # Safety
            "enable_permission": ("SURFSENSE_ENABLE_PERMISSION", True),
            "enable_busy_mutex": ("SURFSENSE_ENABLE_BUSY_MUTEX", True),
            "enable_llm_tool_selector": ("SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False),
            # Skills + subagents
            "enable_skills": ("SURFSENSE_ENABLE_SKILLS", True),
            "enable_specialized_subagents": (
                "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
                True,
            ),
            "enable_kb_planner_runnable": (
                "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
                True,
            ),
            # Snapshot / revert
            "enable_action_log": ("SURFSENSE_ENABLE_ACTION_LOG", True),
            "enable_revert_route": ("SURFSENSE_ENABLE_REVERT_ROUTE", True),
            # Streaming parity v2
            "enable_stream_parity_v2": ("SURFSENSE_ENABLE_STREAM_PARITY_V2", True),
            # Plugins
            "enable_plugin_loader": ("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
            # Observability
            "enable_otel": ("SURFSENSE_ENABLE_OTEL", False),
        }

        if _env_bool("SURFSENSE_DISABLE_NEW_AGENT_STACK", False):
            logger.info(
                "SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent "
                "middleware is forced OFF for this build."
            )
            return cls(disable_new_agent_stack=True, **dict.fromkeys(toggles, False))

        resolved = {
            field_name: _env_bool(env_name, fallback)
            for field_name, (env_name, fallback) in toggles.items()
        }
        return cls(disable_new_agent_stack=False, **resolved)

    def any_new_middleware_enabled(self) -> bool:
        """Report whether at least one new-middleware flag is on.

        Always False when the master kill-switch is set. The check covers
        the flags listed below; ``enable_stream_parity_v2`` and
        ``enable_otel`` are not part of it.
        """
        if self.disable_new_agent_stack:
            return False
        middleware_flags = (
            "enable_context_editing",
            "enable_compaction_v2",
            "enable_retry_after",
            "enable_model_fallback",
            "enable_model_call_limit",
            "enable_tool_call_limit",
            "enable_tool_call_repair",
            "enable_doom_loop",
            "enable_permission",
            "enable_busy_mutex",
            "enable_llm_tool_selector",
            "enable_skills",
            "enable_specialized_subagents",
            "enable_kb_planner_runnable",
            "enable_action_log",
            "enable_revert_route",
            "enable_plugin_loader",
        )
        return any(getattr(self, flag_name) for flag_name in middleware_flags)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level cache. Read once at import time so the values are consistent
|
||||||
|
# across the process lifetime. Use ``reload_for_tests`` to reset in tests.
|
||||||
|
_FLAGS: AgentFeatureFlags | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_flags() -> AgentFeatureFlags:
    """Return the process-wide flag snapshot, resolving it from env on first use.

    Later calls return the same cached instance until ``reload_for_tests``
    replaces it.
    """
    global _FLAGS
    cached = _FLAGS
    if cached is not None:
        return cached
    cached = AgentFeatureFlags.from_env()
    _FLAGS = cached
    return cached
|
||||||
|
|
||||||
|
|
||||||
|
def reload_for_tests() -> AgentFeatureFlags:
    """Discard the cached flags and re-read them from the environment.

    Intended for tests: call it after monkeypatching env vars so the new
    values take effect. Returns the freshly resolved snapshot.
    """
    global _FLAGS
    fresh = AgentFeatureFlags.from_env()
    _FLAGS = fresh
    return fresh
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AgentFeatureFlags",
|
||||||
|
"get_flags",
|
||||||
|
"reload_for_tests",
|
||||||
|
]
|
||||||
63
surfsense_backend/app/agents/new_chat/filesystem_backends.py
Normal file
63
surfsense_backend/app/agents/new_chat/filesystem_backends.py
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
"""Filesystem backend resolver for cloud and desktop-local modes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from deepagents.backends.protocol import BackendProtocol
|
||||||
|
from deepagents.backends.state import StateBackend
|
||||||
|
from langgraph.prebuilt.tool_node import ToolRuntime
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||||
|
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
|
||||||
|
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
||||||
|
MultiRootLocalFolderBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=64)
def _cached_multi_root_backend(
    mounts: tuple[tuple[str, str], ...],
) -> MultiRootLocalFolderBackend:
    """Return a memoised :class:`MultiRootLocalFolderBackend` for ``mounts``.

    ``mounts`` is a hashable tuple of ``(mount_id, root_path)`` pairs, so an
    identical mount set reuses the same backend instance (up to 64 distinct
    sets are kept by the LRU cache).
    """
    backend = MultiRootLocalFolderBackend(mounts)
    return backend
|
||||||
|
|
||||||
|
|
||||||
|
def build_backend_resolver(
    selection: FilesystemSelection,
    *,
    search_space_id: int | None = None,
) -> Callable[[ToolRuntime], BackendProtocol]:
    """Pick the deepagents backend factory matching the filesystem selection.

    Resolution order:

    1. Desktop-local mode with at least one mount: a cached
       :class:`MultiRootLocalFolderBackend` built from the mounts.
    2. Otherwise, when ``search_space_id`` is given (cloud mode): a fresh
       :class:`KBPostgresBackend` bound to the current ``runtime`` on every
       call, so it can read staging state (``staged_dirs``,
       ``pending_moves``, ``files`` cache, ``kb_anon_doc``,
       ``kb_matched_chunk_ids``) for each tool call.
    3. Otherwise: :class:`StateBackend` (used by sub-agents and tests that
       don't need DB-backed reads).

    Args:
        selection: Resolved filesystem selection for the request.
        search_space_id: Search space backing the cloud KB, if any.

    Returns:
        A callable mapping a ``ToolRuntime`` to the backend to use.
    """
    wants_local = selection.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
    if wants_local and selection.local_mounts:

        def _resolve_local(_runtime: ToolRuntime) -> MultiRootLocalFolderBackend:
            # Rebuilt per call; the tuple is the cache key for the backend.
            mount_pairs = tuple(
                (mount.mount_id, mount.root_path) for mount in selection.local_mounts
            )
            return _cached_multi_root_backend(mount_pairs)

        return _resolve_local

    if search_space_id is None:

        def _resolve_state(runtime: ToolRuntime) -> StateBackend:
            return StateBackend(runtime)

        return _resolve_state

    def _resolve_kb(runtime: ToolRuntime) -> BackendProtocol:
        return KBPostgresBackend(search_space_id, runtime)

    return _resolve_kb
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
"""Filesystem mode contracts and selection helpers for chat sessions."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
class FilesystemMode(StrEnum):
|
||||||
|
"""Supported filesystem backends for agent tool execution."""
|
||||||
|
|
||||||
|
CLOUD = "cloud"
|
||||||
|
DESKTOP_LOCAL_FOLDER = "desktop_local_folder"
|
||||||
|
|
||||||
|
|
||||||
|
class ClientPlatform(StrEnum):
|
||||||
|
"""Client runtime reported by the caller."""
|
||||||
|
|
||||||
|
WEB = "web"
|
||||||
|
DESKTOP = "desktop"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class LocalFilesystemMount:
|
||||||
|
"""Canonical mount mapping provided by desktop runtime."""
|
||||||
|
|
||||||
|
mount_id: str
|
||||||
|
root_path: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
class FilesystemSelection:
    """The filesystem choice resolved for one chat request.

    Defaults describe a web client using the cloud filesystem with no local
    mounts.
    """

    # Which filesystem backend the request runs against.
    mode: FilesystemMode = FilesystemMode.CLOUD
    # Client runtime reported by the caller.
    client_platform: ClientPlatform = ClientPlatform.WEB
    # Local mounts for the request; empty by default.
    local_mounts: tuple[LocalFilesystemMount, ...] = ()

    @property
    def is_local_mode(self) -> bool:
        """True when the selected mode is ``DESKTOP_LOCAL_FOLDER``."""
        selected = self.mode
        return selected == FilesystemMode.DESKTOP_LOCAL_FOLDER
|
||||||
178
surfsense_backend/app/agents/new_chat/filesystem_state.py
Normal file
178
surfsense_backend/app/agents/new_chat/filesystem_state.py
Normal file
|
|
@ -0,0 +1,178 @@
|
||||||
|
"""LangGraph state schema additions used by the SurfSense filesystem agent.
|
||||||
|
|
||||||
|
This schema extends deepagents' upstream :class:`FilesystemState` with the
|
||||||
|
extra fields needed to implement Postgres-backed virtual filesystem semantics:
|
||||||
|
|
||||||
|
* ``cwd`` — current working directory (per-thread checkpointed).
|
||||||
|
* ``staged_dirs`` — pending mkdir requests (cloud only).
|
||||||
|
* ``staged_dir_tool_calls`` — sidecar map ``path -> tool_call_id`` for staged dirs.
|
||||||
|
* ``pending_moves`` — pending move_file requests (cloud only).
|
||||||
|
* ``pending_deletes`` — pending ``rm`` requests (cloud only).
|
||||||
|
* ``pending_dir_deletes`` — pending ``rmdir`` requests (cloud only).
|
||||||
|
* ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads.
|
||||||
|
* ``dirty_paths`` — paths whose state file content differs from DB.
|
||||||
|
* ``dirty_path_tool_calls`` — sidecar map ``path -> latest tool_call_id`` for
|
||||||
|
dirty paths; used to bind the per-path snapshot to an action_id.
|
||||||
|
* ``kb_priority`` — top-K priority hints rendered into a system message.
|
||||||
|
* ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting.
|
||||||
|
* ``kb_anon_doc`` — Redis-loaded anonymous document (if any).
|
||||||
|
* ``tree_version`` — bumped by persistence; invalidates the tree render cache.
|
||||||
|
|
||||||
|
Tools mutate these fields ONLY via ``Command(update=...)`` returns; the
|
||||||
|
reducers in :mod:`app.agents.new_chat.state_reducers` handle merging.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated, Any, NotRequired
|
||||||
|
|
||||||
|
from deepagents.middleware.filesystem import FilesystemState
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from app.agents.new_chat.state_reducers import (
|
||||||
|
_add_unique_reducer,
|
||||||
|
_dict_merge_with_tombstones_reducer,
|
||||||
|
_list_append_reducer,
|
||||||
|
_replace_reducer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PendingMove(TypedDict, total=False):
    """One ``move_file`` request staged until the end-of-turn commit.

    Every key is optional (``total=False``). ``tool_call_id`` in particular
    may be missing on checkpoints written before the snapshot/revert
    pipeline existed; entries created now always carry it so the
    persistence body can resolve an action_id.
    """

    # Path the file is moved from.
    source: str
    # Path the file is moved to.
    dest: str
    # Overwrite flag recorded with the move request.
    overwrite: bool
    # Id of the tool call that staged this move.
    tool_call_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class PendingDelete(TypedDict, total=False):
    """One ``rm`` or ``rmdir`` request staged until the end-of-turn commit.

    New entries must include ``tool_call_id``: it is the key
    :class:`KnowledgeBasePersistenceMiddleware` uses to locate the matching
    :class:`AgentActionLog` row and bind the snapshot to it. The type is
    declared ``total=False`` solely so older checkpoint payloads still load.
    """

    # Path of the file or folder to delete.
    path: str
    # Id of the tool call that staged this delete.
    tool_call_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class KbPriorityEntry(TypedDict, total=False):
|
||||||
|
path: str
|
||||||
|
score: float
|
||||||
|
document_id: int | None
|
||||||
|
title: str
|
||||||
|
mentioned: bool
|
||||||
|
|
||||||
|
|
||||||
|
class KbAnonDoc(TypedDict, total=False):
    """Anonymous-session document held in memory after being loaded from Redis.

    All keys are optional (``total=False``).
    """

    path: str
    title: str
    content: str
    # List of chunk dicts; the per-chunk schema is not fixed by this type.
    chunks: list[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class SurfSenseFilesystemState(FilesystemState):
    """State schema for the SurfSense filesystem agent (cloud + desktop).

    Builds on deepagents' :class:`FilesystemState` (the source of ``files``)
    and adds cloud-mode staging fields plus search-priority hints. Each
    added field is optional and carries a reducer, so ``Command(update=...)``
    payloads merge cleanly from one agent step to the next and across
    checkpoints.
    """

    # Current working directory. Defaults to ``"/documents"`` in cloud mode
    # and ``"/"`` (or first mount) in desktop mode; initialized once per
    # thread by ``KnowledgeTreeMiddleware``.
    cwd: NotRequired[Annotated[str, _replace_reducer]]

    # mkdir paths staged for end-of-turn folder creation (cloud only).
    staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]]

    # ``path -> tool_call_id`` sidecar for ``staged_dirs``.
    # :class:`KnowledgeBasePersistenceMiddleware` uses it to bind the
    # :class:`FolderRevision` snapshot to the originating ``mkdir`` action.
    # It lives apart from ``staged_dirs`` (which remains a unique-string
    # list) so ``_add_unique_reducer`` semantics are not disturbed.
    staged_dir_tool_calls: NotRequired[
        Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
    ]

    # move_file ops staged for end-of-turn commit (cloud only).
    pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]]

    # ``rm`` ops staged for end-of-turn ``DELETE FROM documents`` (cloud
    # only). Entries look like ``{"path": ..., "tool_call_id": ...}``.
    # Per-path uniqueness is enforced by the commit body rather than the
    # reducer, because ``tool_call_id`` is kept per occurrence so snapshot
    # binding works.
    pending_deletes: NotRequired[Annotated[list[PendingDelete], _list_append_reducer]]

    # ``rmdir`` ops staged for end-of-turn ``DELETE FROM folders`` (cloud
    # only). Same entry shape as ``pending_deletes``. Before issuing the
    # DELETE the commit body re-verifies that the folder is empty, both in
    # the DB and with this turn's pending changes accounted for.
    pending_dir_deletes: NotRequired[
        Annotated[list[PendingDelete], _list_append_reducer]
    ]

    # virtual_path -> ``Document.id`` for lazily loaded files. Filled in on
    # the first read of a KB document and used by edit_file/move_file/
    # aafter_agent to map a path back to a real DB row. A ``None`` value
    # acts as a tombstone and deletes the key.
    doc_id_by_path: NotRequired[
        Annotated[dict[str, int], _dict_merge_with_tombstones_reducer]
    ]

    # Paths whose ``state["files"]`` content has been modified this turn.
    dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]]

    # ``path -> latest tool_call_id`` sidecar for ``dirty_paths``. The
    # persistence body coalesces multiple writes/edits to one path into a
    # single snapshot per turn; this map remembers the most recent
    # ``tool_call_id`` so the resulting :class:`DocumentRevision` is bound
    # to the latest action_id (the one the user is most likely to revert).
    dirty_path_tool_calls: NotRequired[
        Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
    ]

    # Top-K priority hints rendered as a system message before the user turn.
    kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]]

    # Internal: ``Document.id`` -> list of matched chunk IDs from hybrid search.
    kb_matched_chunk_ids: NotRequired[
        Annotated[dict[int, list[int]], _replace_reducer]
    ]

    # Anonymous-session document loaded from Redis (read-only, no DB row).
    kb_anon_doc: NotRequired[Annotated[KbAnonDoc | None, _replace_reducer]]

    # Monotonically increasing counter; bumped when commits change the KB tree.
    tree_version: NotRequired[Annotated[int, _replace_reducer]]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"KbAnonDoc",
|
||||||
|
"KbPriorityEntry",
|
||||||
|
"PendingDelete",
|
||||||
|
"PendingMove",
|
||||||
|
"SurfSenseFilesystemState",
|
||||||
|
]
|
||||||
|
|
@ -27,6 +27,7 @@ from litellm import get_model_info
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
|
||||||
from app.services.llm_router_service import (
|
from app.services.llm_router_service import (
|
||||||
AUTO_MODE_ID,
|
AUTO_MODE_ID,
|
||||||
ChatLiteLLMRouter,
|
ChatLiteLLMRouter,
|
||||||
|
|
@ -89,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
# Provider mapping for LiteLLM model string construction
|
# Provider mapping for LiteLLM model string construction.
|
||||||
PROVIDER_MAP = {
|
#
|
||||||
"OPENAI": "openai",
|
# Single source of truth lives in
|
||||||
"ANTHROPIC": "anthropic",
|
# :mod:`app.services.provider_capabilities` so the YAML loader (which
|
||||||
"GROQ": "groq",
|
# runs during ``app.config`` class-body init) can resolve provider
|
||||||
"COHERE": "cohere",
|
# prefixes without dragging the agent / tools tree into module load
|
||||||
"GOOGLE": "gemini",
|
# order. Re-exported here under the historical ``PROVIDER_MAP`` name
|
||||||
"OLLAMA": "ollama_chat",
|
# so existing callers (``llm_router_service``, ``image_gen_router_service``,
|
||||||
"MISTRAL": "mistral",
|
# tests) keep working unchanged.
|
||||||
"AZURE_OPENAI": "azure",
|
from app.services.provider_capabilities import ( # noqa: E402
|
||||||
"OPENROUTER": "openrouter",
|
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
|
||||||
"XAI": "xai",
|
)
|
||||||
"BEDROCK": "bedrock",
|
|
||||||
"VERTEX_AI": "vertex_ai",
|
|
||||||
"TOGETHER_AI": "together_ai",
|
|
||||||
"FIREWORKS_AI": "fireworks_ai",
|
|
||||||
"DEEPSEEK": "openai",
|
|
||||||
"ALIBABA_QWEN": "openai",
|
|
||||||
"MOONSHOT": "openai",
|
|
||||||
"ZHIPU": "openai",
|
|
||||||
"GITHUB_MODELS": "github",
|
|
||||||
"REPLICATE": "replicate",
|
|
||||||
"PERPLEXITY": "perplexity",
|
|
||||||
"ANYSCALE": "anyscale",
|
|
||||||
"DEEPINFRA": "deepinfra",
|
|
||||||
"CEREBRAS": "cerebras",
|
|
||||||
"SAMBANOVA": "sambanova",
|
|
||||||
"AI21": "ai21",
|
|
||||||
"CLOUDFLARE": "cloudflare",
|
|
||||||
"DATABRICKS": "databricks",
|
|
||||||
"COMETAPI": "cometapi",
|
|
||||||
"HUGGINGFACE": "huggingface",
|
|
||||||
"MINIMAX": "openai",
|
|
||||||
"CUSTOM": "custom",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
|
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
|
||||||
|
|
@ -177,6 +155,17 @@ class AgentConfig:
|
||||||
anonymous_enabled: bool = False
|
anonymous_enabled: bool = False
|
||||||
quota_reserve_tokens: int | None = None
|
quota_reserve_tokens: int | None = None
|
||||||
|
|
||||||
|
# Capability flag: best-effort True for the chat selector / catalog.
|
||||||
|
# Resolved via :func:`provider_capabilities.derive_supports_image_input`
|
||||||
|
# which prefers OpenRouter's ``architecture.input_modalities`` and
|
||||||
|
# otherwise consults LiteLLM's authoritative model map. Default True
|
||||||
|
# is the conservative-allow stance — the streaming-task safety net
|
||||||
|
# (``is_known_text_only_chat_model``) is the *only* place a False
|
||||||
|
# actually blocks a request. Setting this to False here without an
|
||||||
|
# authoritative source would silently hide vision-capable models
|
||||||
|
# (the regression we're fixing).
|
||||||
|
supports_image_input: bool = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_auto_mode(cls) -> "AgentConfig":
|
def from_auto_mode(cls) -> "AgentConfig":
|
||||||
"""
|
"""
|
||||||
|
|
@ -202,6 +191,12 @@ class AgentConfig:
|
||||||
is_premium=False,
|
is_premium=False,
|
||||||
anonymous_enabled=False,
|
anonymous_enabled=False,
|
||||||
quota_reserve_tokens=None,
|
quota_reserve_tokens=None,
|
||||||
|
# Auto routes across the configured pool, which usually
|
||||||
|
# contains at least one vision-capable deployment; the router
|
||||||
|
# will surface a 404 from a non-vision deployment as a normal
|
||||||
|
# ``allowed_fails`` event and fail over rather than blocking
|
||||||
|
# the request outright.
|
||||||
|
supports_image_input=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -215,10 +210,24 @@ class AgentConfig:
|
||||||
Returns:
|
Returns:
|
||||||
AgentConfig instance
|
AgentConfig instance
|
||||||
"""
|
"""
|
||||||
return cls(
|
# Lazy import to avoid pulling provider_capabilities (and its
|
||||||
provider=config.provider.value
|
# transitive litellm import) into module-init order.
|
||||||
|
from app.services.provider_capabilities import derive_supports_image_input
|
||||||
|
|
||||||
|
provider_value = (
|
||||||
|
config.provider.value
|
||||||
if hasattr(config.provider, "value")
|
if hasattr(config.provider, "value")
|
||||||
else str(config.provider),
|
else str(config.provider)
|
||||||
|
)
|
||||||
|
litellm_params = config.litellm_params or {}
|
||||||
|
base_model = (
|
||||||
|
litellm_params.get("base_model")
|
||||||
|
if isinstance(litellm_params, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
provider=provider_value,
|
||||||
model_name=config.model_name,
|
model_name=config.model_name,
|
||||||
api_key=config.api_key,
|
api_key=config.api_key,
|
||||||
api_base=config.api_base,
|
api_base=config.api_base,
|
||||||
|
|
@ -234,6 +243,16 @@ class AgentConfig:
|
||||||
is_premium=False,
|
is_premium=False,
|
||||||
anonymous_enabled=False,
|
anonymous_enabled=False,
|
||||||
quota_reserve_tokens=None,
|
quota_reserve_tokens=None,
|
||||||
|
# BYOK rows have no operator-curated capability flag, so we
|
||||||
|
# ask LiteLLM (default-allow on unknown). The streaming
|
||||||
|
# safety net still blocks if the model is *explicitly*
|
||||||
|
# marked text-only.
|
||||||
|
supports_image_input=derive_supports_image_input(
|
||||||
|
provider=provider_value,
|
||||||
|
model_name=config.model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
custom_provider=config.custom_provider,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -252,15 +271,46 @@ class AgentConfig:
|
||||||
Returns:
|
Returns:
|
||||||
AgentConfig instance
|
AgentConfig instance
|
||||||
"""
|
"""
|
||||||
|
# Lazy import to avoid pulling provider_capabilities (and its
|
||||||
|
# transitive litellm import) into module-init order.
|
||||||
|
from app.services.provider_capabilities import derive_supports_image_input
|
||||||
|
|
||||||
# Get system instructions from YAML, default to empty string
|
# Get system instructions from YAML, default to empty string
|
||||||
system_instructions = yaml_config.get("system_instructions", "")
|
system_instructions = yaml_config.get("system_instructions", "")
|
||||||
|
|
||||||
|
provider = yaml_config.get("provider", "").upper()
|
||||||
|
model_name = yaml_config.get("model_name", "")
|
||||||
|
custom_provider = yaml_config.get("custom_provider")
|
||||||
|
litellm_params = yaml_config.get("litellm_params") or {}
|
||||||
|
base_model = (
|
||||||
|
litellm_params.get("base_model")
|
||||||
|
if isinstance(litellm_params, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Explicit YAML override wins; otherwise derive from LiteLLM /
|
||||||
|
# OpenRouter modalities. The YAML loader already populates this
|
||||||
|
# field, but this method is also called from
|
||||||
|
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
|
||||||
|
# so we re-derive here for safety. The bool() coercion preserves
|
||||||
|
# the loader's behaviour for explicit ``true`` / ``false``
|
||||||
|
# strings that PyYAML may surface.
|
||||||
|
if "supports_image_input" in yaml_config:
|
||||||
|
supports_image_input = bool(yaml_config.get("supports_image_input"))
|
||||||
|
else:
|
||||||
|
supports_image_input = derive_supports_image_input(
|
||||||
|
provider=provider,
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
custom_provider=custom_provider,
|
||||||
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
provider=yaml_config.get("provider", "").upper(),
|
provider=provider,
|
||||||
model_name=yaml_config.get("model_name", ""),
|
model_name=model_name,
|
||||||
api_key=yaml_config.get("api_key", ""),
|
api_key=yaml_config.get("api_key", ""),
|
||||||
api_base=yaml_config.get("api_base"),
|
api_base=yaml_config.get("api_base"),
|
||||||
custom_provider=yaml_config.get("custom_provider"),
|
custom_provider=custom_provider,
|
||||||
litellm_params=yaml_config.get("litellm_params"),
|
litellm_params=yaml_config.get("litellm_params"),
|
||||||
# Prompt configuration from YAML (with defaults for backwards compatibility)
|
# Prompt configuration from YAML (with defaults for backwards compatibility)
|
||||||
system_instructions=system_instructions if system_instructions else None,
|
system_instructions=system_instructions if system_instructions else None,
|
||||||
|
|
@ -275,6 +325,7 @@ class AgentConfig:
|
||||||
is_premium=yaml_config.get("billing_tier", "free") == "premium",
|
is_premium=yaml_config.get("billing_tier", "free") == "premium",
|
||||||
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
|
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
|
||||||
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
|
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
|
||||||
|
supports_image_input=supports_image_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -494,6 +545,11 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
||||||
|
|
||||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
_attach_model_profile(llm, model_string)
|
_attach_model_profile(llm, model_string)
|
||||||
|
# Configure LiteLLM-native prompt caching (cache_control_injection_points
|
||||||
|
# for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.).
|
||||||
|
# ``agent_config=None`` here — the YAML path doesn't have provider intent
|
||||||
|
# in a structured form, so we set only the universal injection points.
|
||||||
|
apply_litellm_prompt_caching(llm)
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -518,7 +574,16 @@ def create_chat_litellm_from_agent_config(
|
||||||
print("Error: Auto mode requested but LLM Router not initialized")
|
print("Error: Auto mode requested but LLM Router not initialized")
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
return get_auto_mode_llm()
|
router_llm = get_auto_mode_llm()
|
||||||
|
if router_llm is not None:
|
||||||
|
# Universal cache_control_injection_points only — auto-mode
|
||||||
|
# fans out across providers, so OpenAI-only kwargs (e.g.
|
||||||
|
# ``prompt_cache_key``) are left off here. ``drop_params``
|
||||||
|
# would strip them at the provider boundary anyway, but
|
||||||
|
# there's no point setting them when we don't know the
|
||||||
|
# destination.
|
||||||
|
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
|
||||||
|
return router_llm
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating ChatLiteLLMRouter: {e}")
|
print(f"Error creating ChatLiteLLMRouter: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
@ -549,4 +614,9 @@ def create_chat_litellm_from_agent_config(
|
||||||
|
|
||||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
_attach_model_profile(llm, model_string)
|
_attach_model_profile(llm, model_string)
|
||||||
|
# Build-time prompt caching: sets ``cache_control_injection_points`` for
|
||||||
|
# all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``.
|
||||||
|
# Per-thread ``prompt_cache_key`` is layered on later in
|
||||||
|
# ``create_surfsense_deep_agent`` once ``thread_id`` is known.
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=agent_config)
|
||||||
return llm
|
return llm
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,83 @@
|
||||||
"""Middleware components for the SurfSense new chat agent."""
|
"""Middleware components for the SurfSense new chat agent."""
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.action_log import ActionLogMiddleware
|
||||||
|
from app.agents.new_chat.middleware.anonymous_document import (
|
||||||
|
AnonymousDocumentMiddleware,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.middleware.busy_mutex import BusyMutexMiddleware
|
||||||
|
from app.agents.new_chat.middleware.compaction import (
|
||||||
|
SurfSenseCompactionMiddleware,
|
||||||
|
create_surfsense_compaction_middleware,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.middleware.context_editing import (
|
||||||
|
ClearToolUsesEdit,
|
||||||
|
SpillingContextEditingMiddleware,
|
||||||
|
SpillToBackendEdit,
|
||||||
|
)
|
||||||
from app.agents.new_chat.middleware.dedup_tool_calls import (
|
from app.agents.new_chat.middleware.dedup_tool_calls import (
|
||||||
DedupHITLToolCallsMiddleware,
|
DedupHITLToolCallsMiddleware,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.doom_loop import DoomLoopMiddleware
|
||||||
|
from app.agents.new_chat.middleware.file_intent import (
|
||||||
|
FileIntentMiddleware,
|
||||||
|
)
|
||||||
from app.agents.new_chat.middleware.filesystem import (
|
from app.agents.new_chat.middleware.filesystem import (
|
||||||
SurfSenseFilesystemMiddleware,
|
SurfSenseFilesystemMiddleware,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.kb_persistence import (
|
||||||
|
KnowledgeBasePersistenceMiddleware,
|
||||||
|
commit_staged_filesystem_state,
|
||||||
|
)
|
||||||
from app.agents.new_chat.middleware.knowledge_search import (
|
from app.agents.new_chat.middleware.knowledge_search import (
|
||||||
KnowledgeBaseSearchMiddleware,
|
KnowledgeBaseSearchMiddleware,
|
||||||
|
KnowledgePriorityMiddleware,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.middleware.knowledge_tree import (
|
||||||
|
KnowledgeTreeMiddleware,
|
||||||
)
|
)
|
||||||
from app.agents.new_chat.middleware.memory_injection import (
|
from app.agents.new_chat.middleware.memory_injection import (
|
||||||
MemoryInjectionMiddleware,
|
MemoryInjectionMiddleware,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.noop_injection import NoopInjectionMiddleware
|
||||||
|
from app.agents.new_chat.middleware.otel_span import OtelSpanMiddleware
|
||||||
|
from app.agents.new_chat.middleware.permission import PermissionMiddleware
|
||||||
|
from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware
|
||||||
|
from app.agents.new_chat.middleware.skills_backends import (
|
||||||
|
BuiltinSkillsBackend,
|
||||||
|
SearchSpaceSkillsBackend,
|
||||||
|
build_skills_backend_factory,
|
||||||
|
default_skills_sources,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.middleware.tool_call_repair import (
|
||||||
|
ToolCallNameRepairMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"ActionLogMiddleware",
|
||||||
|
"AnonymousDocumentMiddleware",
|
||||||
|
"BuiltinSkillsBackend",
|
||||||
|
"BusyMutexMiddleware",
|
||||||
|
"ClearToolUsesEdit",
|
||||||
"DedupHITLToolCallsMiddleware",
|
"DedupHITLToolCallsMiddleware",
|
||||||
|
"DoomLoopMiddleware",
|
||||||
|
"FileIntentMiddleware",
|
||||||
|
"KnowledgeBasePersistenceMiddleware",
|
||||||
"KnowledgeBaseSearchMiddleware",
|
"KnowledgeBaseSearchMiddleware",
|
||||||
|
"KnowledgePriorityMiddleware",
|
||||||
|
"KnowledgeTreeMiddleware",
|
||||||
"MemoryInjectionMiddleware",
|
"MemoryInjectionMiddleware",
|
||||||
|
"NoopInjectionMiddleware",
|
||||||
|
"OtelSpanMiddleware",
|
||||||
|
"PermissionMiddleware",
|
||||||
|
"RetryAfterMiddleware",
|
||||||
|
"SearchSpaceSkillsBackend",
|
||||||
|
"SpillToBackendEdit",
|
||||||
|
"SpillingContextEditingMiddleware",
|
||||||
|
"SurfSenseCompactionMiddleware",
|
||||||
"SurfSenseFilesystemMiddleware",
|
"SurfSenseFilesystemMiddleware",
|
||||||
|
"ToolCallNameRepairMiddleware",
|
||||||
|
"build_skills_backend_factory",
|
||||||
|
"commit_staged_filesystem_state",
|
||||||
|
"create_surfsense_compaction_middleware",
|
||||||
|
"default_skills_sources",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
363
surfsense_backend/app/agents/new_chat/middleware/action_log.py
Normal file
363
surfsense_backend/app/agents/new_chat/middleware/action_log.py
Normal file
|
|
@ -0,0 +1,363 @@
|
||||||
|
"""Append-only action-log middleware for the SurfSense agent.
|
||||||
|
|
||||||
|
Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes
|
||||||
|
a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt
|
||||||
|
into reversibility by declaring a ``reverse`` callable on their
|
||||||
|
:class:`~app.agents.new_chat.tools.registry.ToolDefinition`; the rendered
|
||||||
|
descriptor is persisted in ``reverse_descriptor`` for use by
|
||||||
|
``/api/threads/{thread_id}/revert/{action_id}``.
|
||||||
|
|
||||||
|
Design points:
|
||||||
|
|
||||||
|
* **Defensive.** Logging never blocks the agent. We catch every exception
|
||||||
|
on the DB write path and emit a warning; the tool's ``ToolMessage``
|
||||||
|
result is always returned untouched.
|
||||||
|
* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) +
|
||||||
|
``result_id`` + ``reverse_descriptor`` are stored. Tool output text
|
||||||
|
remains in the LangGraph checkpoint / spilled tool-output files.
|
||||||
|
* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)``
|
||||||
|
with the parsed JSON result when the tool's content is a JSON object;
|
||||||
|
otherwise the raw text is passed. Exceptions in the reverse callable
|
||||||
|
are swallowed and logged — a failed descriptor render simply means the
|
||||||
|
action is NOT marked reversible.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langchain_core.callbacks import adispatch_custom_event
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import get_flags
|
||||||
|
from app.agents.new_chat.tools.registry import ToolDefinition
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma: no cover - type-only
|
||||||
|
from langchain.agents.middleware.types import ToolCallRequest
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Cap for the persisted ``args`` JSON to avoid bloating the action log with
|
||||||
|
# accidentally-huge inputs. Values are truncated and a flag is set in the
|
||||||
|
# stored payload so consumers can detect truncation.
|
||||||
|
_MAX_ARGS_PERSIST_BYTES = 32 * 1024 # 32KB
|
||||||
|
|
||||||
|
|
||||||
|
class ActionLogMiddleware(AgentMiddleware):
    """Write one :class:`AgentActionLog` row for every tool call.

    Intended to sit near the OUTERMOST end of the tool-call wrapping stack
    so it observes the *final* :class:`ToolMessage`, after retries,
    permission checks and dedup logic. Concretely: just inside
    :class:`PermissionMiddleware` and outside
    :class:`DedupHITLToolCallsMiddleware`.

    The middleware does nothing when any of the following holds:

    * the master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set
      (read through :func:`get_flags`),
    * the per-feature flag ``enable_action_log`` is off,
    * ``thread_id`` is ``None``, or
    * persistence raises (the failure is logged; tool dispatch proceeds).

    Args:
        thread_id: Primary-key id of the current chat thread. A row can
            only be persisted when this is set; ``None`` turns the
            middleware into a silent no-op.
        search_space_id: Search-space id for cascade-on-delete safety.
        user_id: UUID string of the user driving this turn (nullable in
            anonymous mode).
        tool_definitions: Optional mapping of tool name ->
            :class:`ToolDefinition`, used to find each tool's ``reverse``
            callable. Without it no action is marked reversible.
    """

    tools = ()

    def __init__(
        self,
        *,
        thread_id: int | None,
        search_space_id: int,
        user_id: str | None,
        tool_definitions: dict[str, ToolDefinition] | None = None,
    ) -> None:
        super().__init__()
        self._thread_id = thread_id
        self._search_space_id = search_space_id
        self._user_id = user_id
        # Copy so later mutation of the caller's mapping does not leak in.
        self._tool_definitions = dict(tool_definitions or {})

    def _enabled(self) -> bool:
        """Return whether action logging is active for this instance."""
        flags = get_flags()
        if flags.disable_new_agent_stack:
            return False
        feature_on = bool(flags.enable_action_log)
        return feature_on and self._thread_id is not None

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
    ) -> ToolMessage | Command[Any]:
        """Run ``handler`` and record the outcome in the action log.

        The handler's return value is passed through unchanged. If the
        handler raises, the failure is recorded and the same exception is
        re-raised.
        """
        if not self._enabled():
            return await handler(request)

        try:
            outcome = await handler(request)
        except Exception as exc:
            # Record the failure as well so revert/audit can see it, then
            # re-raise for downstream middleware (RetryAfter, etc.).
            failure = {"type": type(exc).__name__, "message": str(exc)}
            await self._record(request=request, result=None, error_payload=failure)
            raise
        else:
            await self._record(request=request, result=outcome, error_payload=None)
            return outcome

    async def _record(
        self,
        *,
        request: ToolCallRequest,
        result: ToolMessage | Command[Any] | None,
        error_payload: dict[str, Any] | None,
    ) -> None:
        """Persist one ``agent_action_log`` row, then emit an ``action_log`` event.

        Both the DB write and the event dispatch are wrapped in broad
        ``except Exception`` handlers that only log, so failures on either
        path are not propagated to the caller.
        """
        try:
            from app.db import AgentActionLog, shielded_async_session

            tool_name = _resolve_tool_name(request)
            stored_args = _resolve_args_payload(request)
            result_id = _resolve_result_id(result)
            reverse_descriptor, reversible = self._render_reverse(
                tool_name=tool_name,
                args=_resolve_args_dict(request),
                result=result,
            )

            tool_call_id = _resolve_tool_call_id(request)
            chat_turn_id = _resolve_chat_turn_id(request)
            message_id = _resolve_message_id(request)

            row = AgentActionLog(
                thread_id=self._thread_id,
                user_id=self._user_id,
                search_space_id=self._search_space_id,
                # ``turn_id`` is the deprecated alias of ``tool_call_id``,
                # retained for one release for safe rollback. New
                # consumers should read ``tool_call_id`` directly.
                turn_id=tool_call_id,
                tool_call_id=tool_call_id,
                chat_turn_id=chat_turn_id,
                message_id=message_id,
                tool_name=tool_name,
                args=stored_args,
                result_id=result_id,
                reversible=reversible,
                reverse_descriptor=reverse_descriptor,
                error=error_payload,
            )
            async with shielded_async_session() as session:
                session.add(row)
                await session.commit()
                row_id = None if row.id is None else int(row.id)
                row_created_at = row.created_at
        except Exception:
            logger.warning(
                "ActionLogMiddleware failed to persist action log row",
                exc_info=True,
            )
            return

        # Side-channel event so the chat tool card can show a Revert button
        # as soon as the row is durable. ``stream_new_chat`` turns it into a
        # ``data-action-log`` SSE event. The ``reverse_descriptor`` payload
        # is deliberately NOT included; only a presence flag is sent.
        try:
            if row_created_at:
                created_iso = row_created_at.isoformat()
            else:
                created_iso = None
            await adispatch_custom_event(
                "action_log",
                {
                    "id": row_id,
                    "lc_tool_call_id": tool_call_id,
                    "chat_turn_id": chat_turn_id,
                    "tool_name": tool_name,
                    "reversible": bool(reversible),
                    "reverse_descriptor_present": reverse_descriptor is not None,
                    "created_at": created_iso,
                    "error": error_payload is not None,
                },
            )
        except Exception:
            logger.debug(
                "ActionLogMiddleware failed to dispatch action_log event",
                exc_info=True,
            )

    def _render_reverse(
        self,
        *,
        tool_name: str,
        args: dict[str, Any] | None,
        result: ToolMessage | Command[Any] | None,
    ) -> tuple[dict[str, Any] | None, bool]:
        """Invoke the tool's ``reverse`` callable and return its descriptor.

        Returns ``(descriptor, True)`` when the callable yields a dict.
        Returns ``(None, False)`` when ``result`` is not a
        :class:`ToolMessage`, ``args`` is ``None``, the tool has no
        registered ``reverse`` callable, the callable raises, or it returns
        something other than a dict.
        """
        not_reversible: tuple[dict[str, Any] | None, bool] = (None, False)

        if not (result and isinstance(result, ToolMessage)):
            return not_reversible
        if args is None:
            return not_reversible

        tool_def = self._tool_definitions.get(tool_name)
        if tool_def is None:
            return not_reversible
        reverse_fn = tool_def.reverse
        if reverse_fn is None:
            return not_reversible

        try:
            descriptor = reverse_fn(args, _parse_tool_result_content(result))
        except Exception:
            logger.warning(
                "Reverse descriptor render failed for tool %s",
                tool_name,
                exc_info=True,
            )
            return not_reversible

        if isinstance(descriptor, dict):
            return descriptor, True
        return not_reversible
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Resolution helpers — defensive against tool_call request shape variation.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_tool_name(request: Any) -> str:
|
||||||
|
try:
|
||||||
|
tool = getattr(request, "tool", None)
|
||||||
|
if tool is not None:
|
||||||
|
name = getattr(tool, "name", None)
|
||||||
|
if isinstance(name, str) and name:
|
||||||
|
return name
|
||||||
|
call = getattr(request, "tool_call", None) or {}
|
||||||
|
if isinstance(call, dict):
|
||||||
|
name = call.get("name")
|
||||||
|
if isinstance(name, str) and name:
|
||||||
|
return name
|
||||||
|
except Exception: # pragma: no cover - defensive
|
||||||
|
pass
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_args_dict(request: Any) -> dict[str, Any] | None:
|
||||||
|
try:
|
||||||
|
call = getattr(request, "tool_call", None)
|
||||||
|
if not isinstance(call, dict):
|
||||||
|
return None
|
||||||
|
args = call.get("args")
|
||||||
|
if isinstance(args, dict):
|
||||||
|
return args
|
||||||
|
return None
|
||||||
|
except Exception: # pragma: no cover - defensive
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_args_payload(request: Any) -> dict[str, Any] | None:
    """Return a JSON-serializable args dict, truncated if too big.

    Returns:
        * ``None`` when the request carries no args dict.
        * ``{"_repr": ...}`` when the args cannot be JSON-encoded at all.
        * ``{"_truncated": True, "_size": ..., "_preview": ...}`` when the
          encoded form is longer than ``_MAX_ARGS_PERSIST_BYTES``.
        * Otherwise the args after a JSON round-trip, so every value is a
          JSON-native type.
    """
    args = _resolve_args_dict(request)
    if args is None:
        return None
    try:
        encoded = json.dumps(args, default=str)
    except Exception:
        return {"_repr": repr(args)[:_MAX_ARGS_PERSIST_BYTES]}

    # NOTE(review): the limit is compared against the character count of the
    # encoded string, not its encoded byte length — confirm that is intended.
    if len(encoded) > _MAX_ARGS_PERSIST_BYTES:
        return {
            "_truncated": True,
            "_size": len(encoded),
            "_preview": encoded[:_MAX_ARGS_PERSIST_BYTES],
        }

    # ``default=str`` may have coerced non-JSON values (datetimes, UUIDs, ...)
    # while encoding. Returning the raw ``args`` would hand those objects to
    # the caller even though the contract is "JSON-serializable", so return
    # the decoded round-trip instead.
    try:
        decoded = json.loads(encoded)
    except ValueError:  # pragma: no cover - dumps output should always load
        return args
    return decoded if isinstance(decoded, dict) else args
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_tool_call_id(request: Any) -> str | None:
|
||||||
|
"""Return the LangChain ``tool_call.id`` for this request, if any."""
|
||||||
|
try:
|
||||||
|
call = getattr(request, "tool_call", None) or {}
|
||||||
|
if isinstance(call, dict):
|
||||||
|
tid = call.get("id")
|
||||||
|
if isinstance(tid, str):
|
||||||
|
return tid
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Deprecated alias kept for one release. Old callers and tests treated
|
||||||
|
# ``turn_id`` as if it carried the LangChain tool_call id; the new column
|
||||||
|
# lives under ``tool_call_id``. Both resolve to the same value today.
|
||||||
|
_resolve_turn_id = _resolve_tool_call_id
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_chat_turn_id(request: Any) -> str | None:
|
||||||
|
"""Return ``configurable.turn_id`` for this request, if accessible.
|
||||||
|
|
||||||
|
``ToolRuntime.config`` is exposed by LangGraph (see
|
||||||
|
``langgraph/prebuilt/tool_node.py``); the chat-turn correlation id
|
||||||
|
lives at ``runtime.config["configurable"]["turn_id"]``.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
runtime = getattr(request, "runtime", None)
|
||||||
|
if runtime is None:
|
||||||
|
return None
|
||||||
|
config = getattr(runtime, "config", None)
|
||||||
|
if not isinstance(config, dict):
|
||||||
|
return None
|
||||||
|
configurable = config.get("configurable")
|
||||||
|
if not isinstance(configurable, dict):
|
||||||
|
return None
|
||||||
|
value = configurable.get("turn_id")
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
return value
|
||||||
|
except Exception: # pragma: no cover - defensive
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_message_id(request: Any) -> str | None:
    """Return the message correlator for ``request``.

    At this layer the tool-call id is the best correlator available, so this
    simply delegates to :func:`_resolve_tool_call_id`.
    """
    correlator = _resolve_tool_call_id(request)
    return correlator
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_result_id(result: Any) -> str | None:
    """Return ``result.id`` when ``result`` is a ``ToolMessage`` with a string id.

    Every other input (non-``ToolMessage`` values, missing or non-string ids)
    yields ``None``.
    """
    if not isinstance(result, ToolMessage):
        return None
    message_id = getattr(result, "id", None)
    return message_id if isinstance(message_id, str) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_tool_result_content(result: ToolMessage) -> Any:
|
||||||
|
content = result.content
|
||||||
|
if isinstance(content, str):
|
||||||
|
try:
|
||||||
|
return json.loads(content)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
return content
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ActionLogMiddleware"]
|
||||||
|
|
@ -0,0 +1,91 @@
|
||||||
|
"""Lightweight middleware that loads the anonymous-session document into state.
|
||||||
|
|
||||||
|
Anonymous chats receive a single uploaded document via Redis (no DB row,
|
||||||
|
read-only). This middleware loads it once on the first turn into
|
||||||
|
``state['kb_anon_doc']`` so:
|
||||||
|
|
||||||
|
* :class:`KnowledgeTreeMiddleware` can render the synthetic ``/documents``
|
||||||
|
view without touching the DB.
|
||||||
|
* :class:`KnowledgePriorityMiddleware` skips hybrid search and emits a
|
||||||
|
degenerate priority list.
|
||||||
|
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / ``_load_file_data``)
|
||||||
|
recognises the synthetic path.
|
||||||
|
|
||||||
|
The middleware is a no-op when ``anon_session_id`` is not provided or when
|
||||||
|
the document is already cached in state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||||
|
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, safe_filename
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AnonymousDocumentMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Load the anonymous user's uploaded document from Redis into state.

    The document is stored under ``state['kb_anon_doc']``. The middleware does
    nothing when no ``anon_session_id`` was supplied or when the state already
    carries a ``kb_anon_doc`` value.
    """

    tools = ()
    state_schema = SurfSenseFilesystemState

    def __init__(self, *, anon_session_id: str | None) -> None:
        """Remember which anonymous session's document should be loaded."""
        self.anon_session_id = anon_session_id

    async def abefore_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Return a ``kb_anon_doc`` state update, or ``None`` when there is none.

        ``None`` is returned when there is no session id, the document is
        already in state, or the document could not be loaded.
        """
        del runtime
        if not self.anon_session_id:
            return None
        if state.get("kb_anon_doc"):
            return None

        anon_doc = await self._load_anon_document()
        if anon_doc is None:
            return None
        return {"kb_anon_doc": anon_doc}

    async def _load_anon_document(self) -> dict[str, Any] | None:
        """Read ``anon:doc:<session_id>`` from Redis.

        Returns:
            A dict with ``path``, ``title``, ``content`` and ``chunks`` keys,
            or ``None`` when the key is absent/empty, Redis or JSON decoding
            fails, or the stored JSON value is not an object.
        """
        try:
            import redis.asyncio as aioredis  # local import to keep cold paths cheap

            from app.config import config

            redis_client = aioredis.from_url(
                config.REDIS_APP_URL, decode_responses=True
            )
            try:
                redis_key = f"anon:doc:{self.anon_session_id}"
                data = await redis_client.get(redis_key)
                if not data:
                    return None
                payload = json.loads(data)
            finally:
                # Always release the connection, including on early return.
                await redis_client.aclose()
        except Exception as exc:
            logger.warning("Failed to load anonymous document from Redis: %s", exc)
            return None

        # Valid JSON is not necessarily an object. Without this guard a list,
        # string or number would raise AttributeError on ``.get`` below,
        # outside the try block, and propagate out of ``abefore_agent``.
        if not isinstance(payload, dict):
            logger.warning(
                "Anonymous document payload in Redis is not a JSON object "
                "(got %s); ignoring",
                type(payload).__name__,
            )
            return None

        title = str(payload.get("filename") or "uploaded_document")
        content = str(payload.get("content") or "")
        path = f"{DOCUMENTS_ROOT}/{safe_filename(title)}"
        return {
            "path": path,
            "title": title,
            "content": content,
            # Single synthetic chunk; empty documents expose no chunks.
            "chunks": [{"chunk_id": -1, "content": content}] if content else [],
        }
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["AnonymousDocumentMiddleware"]
|
||||||
311
surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py
Normal file
311
surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py
Normal file
|
|
@ -0,0 +1,311 @@
|
||||||
|
"""
|
||||||
|
BusyMutexMiddleware — per-thread asyncio lock + cancel token.
|
||||||
|
|
||||||
|
LangChain has no built-in concept of "this thread is already running a
|
||||||
|
turn — refuse the second concurrent request". Without it, a user
|
||||||
|
double-clicking "send" or refreshing the page mid-stream can spawn two
|
||||||
|
turns racing on the same checkpoint, producing duplicated tool calls
|
||||||
|
and mangled state.
|
||||||
|
|
||||||
|
Ported from OpenCode's ``Stream.scoped(AbortController)`` pattern: a
|
||||||
|
single-process, in-memory lock + cooperative cancellation token keyed by
|
||||||
|
``thread_id``. For multi-worker deployments a distributed lock backend
|
||||||
|
(Redis or PostgreSQL advisory locks) is a phase-2 follow-up.
|
||||||
|
|
||||||
|
What this provides:
|
||||||
|
- A ``WeakValueDictionary[str, asyncio.Lock]`` keyed by ``thread_id``;
|
||||||
|
acquiring the lock during ``before_agent`` blocks any concurrent
|
||||||
|
prompt on the same thread until release.
|
||||||
|
- A per-thread ``asyncio.Event`` (``cancel_event``) that long-running
|
||||||
|
tools can poll to abort cooperatively. The event is reset between
|
||||||
|
turns. Tools should check ``runtime.context.cancel_event.is_set()``
|
||||||
|
in tight inner loops.
|
||||||
|
- A typed :class:`~app.agents.new_chat.errors.BusyError` raised when a
|
||||||
|
second turn arrives while the lock is held.
|
||||||
|
|
||||||
|
Note: SurfSense's ``stream_new_chat`` is the call site that should
|
||||||
|
acquire/release. Wiring this as middleware means the contract is
|
||||||
|
explicit and the lock manager is shared with subagents that compile
|
||||||
|
their own ``create_agent`` runnables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import weakref
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware.types import (
|
||||||
|
AgentMiddleware,
|
||||||
|
AgentState,
|
||||||
|
ContextT,
|
||||||
|
ResponseT,
|
||||||
|
)
|
||||||
|
from langgraph.config import get_config
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from app.agents.new_chat.errors import BusyError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _ThreadLockManager:
|
||||||
|
"""Process-local registry of per-thread asyncio locks + cancel events."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = (
|
||||||
|
weakref.WeakValueDictionary()
|
||||||
|
)
|
||||||
|
self._cancel_events: dict[str, asyncio.Event] = {}
|
||||||
|
self._cancel_requested_at_ms: dict[str, int] = {}
|
||||||
|
self._cancel_attempt_count: dict[str, int] = {}
|
||||||
|
# Monotonic per-thread epoch used to prevent stale middleware
|
||||||
|
# teardown from releasing a newer turn's lock.
|
||||||
|
self._turn_epoch: dict[str, int] = {}
|
||||||
|
|
||||||
|
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
||||||
|
lock = self._locks.get(thread_id)
|
||||||
|
if lock is None:
|
||||||
|
lock = asyncio.Lock()
|
||||||
|
self._locks[thread_id] = lock
|
||||||
|
return lock
|
||||||
|
|
||||||
|
def cancel_event(self, thread_id: str) -> asyncio.Event:
|
||||||
|
event = self._cancel_events.get(thread_id)
|
||||||
|
if event is None:
|
||||||
|
event = asyncio.Event()
|
||||||
|
self._cancel_events[thread_id] = event
|
||||||
|
return event
|
||||||
|
|
||||||
|
def request_cancel(self, thread_id: str) -> bool:
|
||||||
|
event = self._cancel_events.get(thread_id)
|
||||||
|
if event is None:
|
||||||
|
event = asyncio.Event()
|
||||||
|
self._cancel_events[thread_id] = event
|
||||||
|
event.set()
|
||||||
|
now_ms = int(time.time() * 1000)
|
||||||
|
self._cancel_requested_at_ms[thread_id] = now_ms
|
||||||
|
self._cancel_attempt_count[thread_id] = (
|
||||||
|
self._cancel_attempt_count.get(thread_id, 0) + 1
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def is_cancel_requested(self, thread_id: str) -> bool:
|
||||||
|
event = self._cancel_events.get(thread_id)
|
||||||
|
return bool(event and event.is_set())
|
||||||
|
|
||||||
|
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
|
||||||
|
if not self.is_cancel_requested(thread_id):
|
||||||
|
return None
|
||||||
|
attempts = self._cancel_attempt_count.get(thread_id, 1)
|
||||||
|
requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
|
||||||
|
return attempts, requested_at_ms
|
||||||
|
|
||||||
|
def reset(self, thread_id: str) -> None:
|
||||||
|
event = self._cancel_events.get(thread_id)
|
||||||
|
if event is not None:
|
||||||
|
event.clear()
|
||||||
|
self._cancel_requested_at_ms.pop(thread_id, None)
|
||||||
|
self._cancel_attempt_count.pop(thread_id, None)
|
||||||
|
|
||||||
|
def bump_turn_epoch(self, thread_id: str) -> int:
|
||||||
|
epoch = self._turn_epoch.get(thread_id, 0) + 1
|
||||||
|
self._turn_epoch[thread_id] = epoch
|
||||||
|
return epoch
|
||||||
|
|
||||||
|
def current_turn_epoch(self, thread_id: str) -> int:
|
||||||
|
return self._turn_epoch.get(thread_id, 0)
|
||||||
|
|
||||||
|
def end_turn(self, thread_id: str) -> None:
|
||||||
|
"""Best-effort terminal cleanup for a thread turn.
|
||||||
|
|
||||||
|
This is intentionally idempotent and safe to call from outer stream
|
||||||
|
finally-blocks where middleware teardown might be skipped due to abort
|
||||||
|
or disconnect edge-cases.
|
||||||
|
"""
|
||||||
|
# Invalidate any in-flight middleware holder first. This guarantees a
|
||||||
|
# stale ``aafter_agent`` from an older attempt cannot unlock a newer
|
||||||
|
# retry that already acquired the lock for the same thread.
|
||||||
|
self.bump_turn_epoch(thread_id)
|
||||||
|
lock = self._locks.get(thread_id)
|
||||||
|
if lock is not None and lock.locked():
|
||||||
|
lock.release()
|
||||||
|
self.reset(thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton — process-local but reused across all agent
|
||||||
|
# instances built in this process. Subagents created in nested
|
||||||
|
# ``create_agent`` calls also get this so locks are coherent.
|
||||||
|
manager = _ThreadLockManager()


def get_cancel_event(thread_id: str) -> asyncio.Event:
    """Return the shared manager's cancel event for ``thread_id``.

    Intended for long-running tools that poll for cancellation.
    """
    return manager.cancel_event(thread_id)


def request_cancel(thread_id: str) -> bool:
    """Set the cancel event for ``thread_id`` on the shared manager.

    The result is always ``True``.
    """
    return manager.request_cancel(thread_id)


def is_cancel_requested(thread_id: str) -> bool:
    """Tell whether a cancel signal is currently pending for ``thread_id``."""
    return manager.is_cancel_requested(thread_id)


def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
    """Return ``(attempt_count, requested_at_ms)`` while a cancel is pending.

    ``None`` means no cancel is pending for ``thread_id``.
    """
    return manager.cancel_state(thread_id)


def reset_cancel(thread_id: str) -> None:
    """Clear the cancel event for ``thread_id``; used between turns."""
    manager.reset(thread_id)


def end_turn(thread_id: str) -> None:
    """Run the shared manager's end-of-turn cleanup for lock and cancel state."""
    manager.end_turn(thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
    """Reject a second concurrent prompt on the same thread.

    ``abefore_agent`` takes the thread's lock and ``aafter_agent`` gives it
    back. When the lock is already held, :class:`BusyError` is raised with the
    thread id as ``request_id``.

    Args:
        require_thread_id: When True, a turn whose ``thread_id`` cannot be
            resolved raises :class:`BusyError`. When False (the default) such
            a turn skips the mutex entirely.
    """

    def __init__(self, *, require_thread_id: bool = False) -> None:
        super().__init__()
        self._require_thread_id = require_thread_id
        self.tools = []
        # thread_id -> (lock, epoch recorded at acquisition). The epoch lets
        # ``aafter_agent`` detect that the manager has since moved on and that
        # it must not release a lock it no longer owns.
        self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}

    @staticmethod
    def _thread_id(runtime: Runtime[ContextT]) -> str | None:
        """Resolve ``configurable.thread_id`` for the current run.

        Tries :func:`langgraph.config.get_config` first (any exception there
        counts as "not found"), then falls back to a ``config`` attribute on
        ``runtime`` if one exists. The id is returned as ``str``.
        """

        def _extract(candidate: Any) -> str | None:
            if isinstance(candidate, dict):
                raw = (candidate.get("configurable") or {}).get("thread_id")
                if raw is not None:
                    return str(raw)
            return None

        try:
            resolved = _extract(get_config())
        except Exception:
            resolved = None
        if resolved is None:
            # Fallback for runtimes/stubs that carry a config dict directly.
            resolved = _extract(getattr(runtime, "config", None))
        return resolved

    def _claimable_lock(
        self, runtime: Runtime[ContextT], *, log_skip: bool
    ) -> tuple[str, asyncio.Lock] | None:
        """Return ``(thread_id, lock)`` if the thread is free to start a turn.

        Returns ``None`` when there is no thread id and none is required.
        Raises :class:`BusyError` when a required thread id is missing or when
        the thread's lock is already held.
        """
        thread_id = self._thread_id(runtime)
        if thread_id is None:
            if self._require_thread_id:
                raise BusyError("no thread_id configured")
            if log_skip:
                logger.debug(
                    "BusyMutexMiddleware: no thread_id resolved from RunnableConfig; "
                    "skipping per-thread lock for this turn."
                )
            return None

        lock = manager.lock_for(thread_id)
        if lock.locked():
            raise BusyError(request_id=thread_id)
        return thread_id, lock

    async def abefore_agent(  # type: ignore[override]
        self,
        state: AgentState[Any],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Acquire the thread lock, record its epoch, and clear stale cancels."""
        del state
        claim = self._claimable_lock(runtime, log_skip=True)
        if claim is None:
            return None

        thread_id, lock = claim
        await lock.acquire()
        self._held_locks[thread_id] = (lock, manager.bump_turn_epoch(thread_id))
        # Start the turn with no leftover cancel signal.
        reset_cancel(thread_id)
        return None

    async def aafter_agent(  # type: ignore[override]
        self,
        state: AgentState[Any],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Release the lock taken by ``abefore_agent`` if this turn still owns it."""
        del state
        thread_id = self._thread_id(runtime)
        held = None if thread_id is None else self._held_locks.pop(thread_id, None)
        if held is None:
            return None

        lock, held_epoch = held
        # A mismatching epoch means the manager has advanced past this turn;
        # in that case leave the current lock and cancel state alone.
        if held_epoch == manager.current_turn_epoch(thread_id):
            if lock.locked():
                lock.release()
            # Clear the cancel signal so it cannot leak into the next turn.
            reset_cancel(thread_id)
        return None

    # Sync counterparts, provided because the base class permits them.
    def before_agent(  # type: ignore[override]
        self, state: AgentState[Any], runtime: Runtime[ContextT]
    ) -> dict[str, Any] | None:
        """Sync path: never acquires; only rejects when the lock is already held."""
        self._claimable_lock(runtime, log_skip=False)
        return None

    def after_agent(  # type: ignore[override]
        self, state: AgentState[Any], runtime: Runtime[ContextT]
    ) -> dict[str, Any] | None:
        """Sync path: nothing was acquired, so there is nothing to release."""
        return None
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BusyMutexMiddleware",
|
||||||
|
"end_turn",
|
||||||
|
"get_cancel_event",
|
||||||
|
"get_cancel_state",
|
||||||
|
"is_cancel_requested",
|
||||||
|
"manager",
|
||||||
|
"request_cancel",
|
||||||
|
"reset_cancel",
|
||||||
|
]
|
||||||
254
surfsense_backend/app/agents/new_chat/middleware/compaction.py
Normal file
254
surfsense_backend/app/agents/new_chat/middleware/compaction.py
Normal file
|
|
@ -0,0 +1,254 @@
|
||||||
|
"""
|
||||||
|
SurfSense compaction middleware.
|
||||||
|
|
||||||
|
Subclasses :class:`deepagents.middleware.summarization.SummarizationMiddleware`
|
||||||
|
to add SurfSense-specific behavior:
|
||||||
|
|
||||||
|
1. **Structured summary template** (OpenCode-style ``## Goal / Constraints /
|
||||||
|
Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``)
|
||||||
|
— see :data:`SURFSENSE_SUMMARY_PROMPT` below. The base
|
||||||
|
``SummarizationMiddleware`` only ships a freeform "summarize this"
|
||||||
|
prompt; the structured template is ported from OpenCode's
|
||||||
|
``compaction.ts``.
|
||||||
|
2. **Protect SurfSense-specific SystemMessages** so injected hints
|
||||||
|
(``<priority_documents>``, ``<workspace_tree>``, ``<file_operation_contract>``,
|
||||||
|
``<user_memory>``, ``<team_memory>``, ``<user_name>``, ``<memory_warning>``)
|
||||||
|
are *not* summarized away and are kept verbatim in the post-summary
|
||||||
|
message list. Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy
|
||||||
|
(some message types are part of the agent's contract and must survive
|
||||||
|
compaction unchanged).
|
||||||
|
3. **Sanitize ``content=None``** when feeding messages into ``get_buffer_string``
|
||||||
|
(Azure OpenAI / LiteLLM defense — when a provider streams an AIMessage
|
||||||
|
containing only tool_calls and no text, ``content`` can be ``None`` and
|
||||||
|
``get_buffer_string`` crashes iterating over ``None``). SurfSense-specific.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from deepagents.middleware.summarization import (
|
||||||
|
SummarizationMiddleware,
|
||||||
|
compute_summarization_defaults,
|
||||||
|
)
|
||||||
|
from langchain_core.messages import SystemMessage
|
||||||
|
|
||||||
|
from app.observability import otel as ot
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from deepagents.backends.protocol import BACKEND_TYPES
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.messages import AnyMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Structured summary template ported from OpenCode's
|
||||||
|
# ``opencode/packages/opencode/src/session/compaction.ts:40-75``. Kept as a
|
||||||
|
# module-level constant so unit tests can assert on its sections.
|
||||||
|
SURFSENSE_SUMMARY_PROMPT = """<role>
|
||||||
|
SurfSense Conversation Compaction Assistant
|
||||||
|
</role>
|
||||||
|
|
||||||
|
<primary_objective>
|
||||||
|
Extract the most important context from the conversation history below into a structured summary that will replace the older messages.
|
||||||
|
</primary_objective>
|
||||||
|
|
||||||
|
<instructions>
|
||||||
|
You are running because the conversation has grown beyond the model's input window. The conversation history below will be summarized and replaced with your output. Use the structured template that follows; keep each section concise but comprehensive enough that the agent can resume work without losing context. Each section is a checklist — populate it with relevant content or write "None" if there is nothing to report.
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
What is the user's primary goal or request? State it in one or two sentences.
|
||||||
|
|
||||||
|
## Constraints
|
||||||
|
What boundaries must the agent respect (citations rules, visibility scope, allowed tools, user-imposed style, deadlines, deny-listed topics)?
|
||||||
|
|
||||||
|
## Progress
|
||||||
|
What has the agent already accomplished? List each completed step succinctly. Do not reproduce tool output; just record the conclusion.
|
||||||
|
|
||||||
|
## Key Decisions
|
||||||
|
What choices were made and why? Include rejected alternatives and the reasoning behind selecting the current path.
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
What specific tasks remain to achieve the goal? Order them by dependency.
|
||||||
|
|
||||||
|
## Critical Context
|
||||||
|
What facts, IDs, document titles, query keywords, error messages, or partial answers must persist into the next turn? Include verbatim quotes only when the exact wording matters (e.g. a precise filter clause or a literal name).
|
||||||
|
|
||||||
|
## Relevant Files
|
||||||
|
What documents or paths in the SurfSense knowledge base are in play? Use ``/documents/...`` paths exactly as they appeared in the workspace tree.
|
||||||
|
</instructions>
|
||||||
|
|
||||||
|
<messages>
|
||||||
|
Messages to summarize:
|
||||||
|
{messages}
|
||||||
|
</messages>
|
||||||
|
|
||||||
|
Respond ONLY with the structured summary. Do not include any text before or after.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# SystemMessage prefixes that must NOT be summarized away. They are
|
||||||
|
# re-injected on every turn by the corresponding middleware, but the
|
||||||
|
# compaction step happens *before* re-injection in some paths, so we
|
||||||
|
# must preserve them verbatim across the cutoff.
|
||||||
|
PROTECTED_SYSTEM_PREFIXES: tuple[str, ...] = (
|
||||||
|
"<priority_documents>", # KnowledgePriorityMiddleware
|
||||||
|
"<workspace_tree>", # KnowledgeTreeMiddleware
|
||||||
|
"<file_operation_contract>", # FileIntentMiddleware
|
||||||
|
"<user_memory>", # MemoryInjectionMiddleware
|
||||||
|
"<team_memory>", # MemoryInjectionMiddleware
|
||||||
|
"<user_name>", # MemoryInjectionMiddleware
|
||||||
|
"<memory_warning>", # MemoryInjectionMiddleware
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_protected_system_message(msg: AnyMessage) -> bool:
    """Tell whether ``msg`` is a SystemMessage that compaction must keep.

    A message qualifies when it is a ``SystemMessage`` with string content
    that, ignoring leading whitespace, starts with one of
    ``PROTECTED_SYSTEM_PREFIXES``.
    """
    if not isinstance(msg, SystemMessage):
        return False
    body = msg.content
    # ``str.startswith`` accepts the prefix tuple directly.
    return isinstance(body, str) and body.lstrip().startswith(
        PROTECTED_SYSTEM_PREFIXES
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
|
||||||
|
"""Return ``msg`` with ``content=None`` coerced to ``""``.
|
||||||
|
|
||||||
|
Folds in the historical defense from ``safe_summarization.py`` —
|
||||||
|
``get_buffer_string`` reads ``m.text`` which iterates ``self.content``,
|
||||||
|
so a ``None`` content (Azure OpenAI / LiteLLM streaming a tool-only
|
||||||
|
AIMessage) explodes. We return a copy with empty string content so
|
||||||
|
downstream consumers see an empty body without mutating the original.
|
||||||
|
"""
|
||||||
|
if getattr(msg, "content", "not-missing") is not None:
|
||||||
|
return msg
|
||||||
|
try:
|
||||||
|
return msg.model_copy(update={"content": ""})
|
||||||
|
except AttributeError:
|
||||||
|
import copy
|
||||||
|
|
||||||
|
new_msg = copy.copy(msg)
|
||||||
|
try:
|
||||||
|
new_msg.content = ""
|
||||||
|
except Exception:
|
||||||
|
logger.debug(
|
||||||
|
"Could not sanitize content=None on message of type %s",
|
||||||
|
type(msg).__name__,
|
||||||
|
)
|
||||||
|
return msg
|
||||||
|
return new_msg
|
||||||
|
|
||||||
|
|
||||||
|
class SurfSenseCompactionMiddleware(SummarizationMiddleware):
    """``SummarizationMiddleware`` specialised for SurfSense.

    Notes
    -----
    - :meth:`_partition_messages` is overridden so protected SystemMessages
      always end up in the ``preserved_messages`` half, whatever the cutoff.
    - :meth:`_filter_summary_messages` is overridden so the buffer-string
      path never sees ``None`` content.
    - Everything else (auto-trigger, backend offload,
      ``_summarization_event`` plumbing, ``ContextOverflowError`` fallback)
      is inherited unchanged.
    """

    def _partition_messages(  # type: ignore[override]
        self,
        conversation_messages: list[AnyMessage],
        cutoff_index: int,
    ) -> tuple[list[AnyMessage], list[AnyMessage]]:
        """Partition via the parent, then rescue protected SystemMessages.

        Follows the same idea as OpenCode's ``PRUNE_PROTECTED_TOOLS``
        (``opencode/packages/opencode/src/session/compaction.ts``): certain
        messages belong to the agent's working contract and are kept
        verbatim instead of being summarized.

        The work runs inside a ``compaction.run`` OTel span (a no-op when
        OTel is off) so dashboards can count compaction events and message
        volume without instrumenting upstream callers.

        Args:
            conversation_messages: Messages being considered for compaction.
            cutoff_index: Split point forwarded to the parent implementation.

        Returns:
            ``(to_summarize, preserved)``. ``to_summarize`` is the parent's
            first half minus protected messages; ``preserved`` is those
            protected messages followed by the parent's preserved half.
        """
        # Partitioning is the first call SummarizationMiddleware makes once
        # it has decided to summarize, which makes it the natural place to
        # open the span and record the incoming volume.
        span_attributes = {"compaction.cutoff_index": int(cutoff_index)}
        with ot.compaction_span(
            reason="auto",
            messages_in=len(conversation_messages),
            extra=span_attributes,
        ):
            to_summarize, preserved_tail = super()._partition_messages(
                conversation_messages, cutoff_index
            )

            rescued: list[AnyMessage] = []
            summarizable: list[AnyMessage] = []
            for message in to_summarize:
                bucket = (
                    rescued if _is_protected_system_message(message) else summarizable
                )
                bucket.append(message)

            # Rescued blocks go to the *front* of the preserved half so they
            # keep their original ordering relative to the summary
            # HumanMessage that precedes the rest of the preserved tail.
            return summarizable, rescued + list(preserved_tail)

    def _filter_summary_messages(  # type: ignore[override]
        self, messages: list[AnyMessage]
    ) -> list[AnyMessage]:
        """Apply the parent's summary filter, then sanitize ``content=None``.

        Sanitizing at this point covers both the sync and async offload
        paths, since the buffer builder would otherwise fail when it
        iterates ``m.text`` over ``None`` content.

        Args:
            messages: Messages about to be rendered into the summary buffer.

        Returns:
            The parent's filtered messages, each passed through
            ``_sanitize_message_content``.
        """
        surviving = super()._filter_summary_messages(messages)
        return list(map(_sanitize_message_content, surviving))
|
||||||
|
|
||||||
|
|
||||||
|
def create_surfsense_compaction_middleware(
    model: BaseChatModel,
    backend: BACKEND_TYPES,
    *,
    summary_prompt: str | None = None,
    history_path_prefix: str = "/conversation_history",
    **overrides: Any,
) -> SurfSenseCompactionMiddleware:
    """Construct a :class:`SurfSenseCompactionMiddleware` with profile defaults.

    ``trigger`` / ``keep`` / ``truncate_args_settings`` come from
    :func:`deepagents.middleware.summarization.compute_summarization_defaults`
    unless the caller supplies them in ``overrides``, so the result matches
    ``create_summarization_middleware`` plus the SurfSense overrides.

    Args:
        model: Chat model to call for summary generation.
        backend: Backend instance or factory for offloading conversation history.
        summary_prompt: Optional override; a falsy value selects
            :data:`SURFSENSE_SUMMARY_PROMPT`.
        history_path_prefix: Path prefix for offloaded conversation history.
        **overrides: Forwarded to :class:`SurfSenseCompactionMiddleware`.

    Returns:
        The configured middleware instance.
    """
    profile = compute_summarization_defaults(model)

    # Each known knob is popped out of ``overrides`` so that it is not passed
    # a second time through ``**overrides`` below.
    trigger = overrides.pop("trigger", profile["trigger"])
    keep = overrides.pop("keep", profile["keep"])
    trim_tokens = overrides.pop("trim_tokens_to_summarize", None)
    truncate_args = overrides.pop(
        "truncate_args_settings", profile["truncate_args_settings"]
    )
    effective_prompt = summary_prompt or SURFSENSE_SUMMARY_PROMPT

    return SurfSenseCompactionMiddleware(
        model=model,
        backend=backend,
        trigger=trigger,
        keep=keep,
        trim_tokens_to_summarize=trim_tokens,
        truncate_args_settings=truncate_args,
        summary_prompt=effective_prompt,
        history_path_prefix=history_path_prefix,
        **overrides,
    )
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PROTECTED_SYSTEM_PREFIXES",
|
||||||
|
"SURFSENSE_SUMMARY_PROMPT",
|
||||||
|
"SurfSenseCompactionMiddleware",
|
||||||
|
"create_surfsense_compaction_middleware",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,350 @@
|
||||||
|
"""
|
||||||
|
SpillToBackendEdit + SpillingContextEditingMiddleware.
|
||||||
|
|
||||||
|
LangChain's :class:`ClearToolUsesEdit` discards old ``ToolMessage.content``
|
||||||
|
when the context-editing budget triggers, replacing the body with a fixed
|
||||||
|
placeholder. That's lossy: anything the agent might want to revisit is
|
||||||
|
gone. The spill-to-disk pattern (originally from OpenCode's
|
||||||
|
``opencode/packages/opencode/src/tool/truncate.ts``) keeps the prune
|
||||||
|
behavior but writes the full original payload to the runtime backend
|
||||||
|
under ``/tool_outputs/{thread_id}/{message_id}.txt`` first. The
|
||||||
|
placeholder is then upgraded to point at the spill path so the agent
|
||||||
|
(or a subagent) can read it back on demand.
|
||||||
|
|
||||||
|
Why this is a middleware subclass instead of a plain ``ContextEdit``:
|
||||||
|
``ContextEdit.apply`` is sync, but writing to the backend is async. We
|
||||||
|
capture the spill payloads inside ``apply`` and flush them via
|
||||||
|
``await backend.aupload_files(...)`` from ``awrap_model_call`` *before*
|
||||||
|
delegating to the handler, so the explore subagent can always read what
|
||||||
|
the placeholder advertises.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from collections.abc import Awaitable, Callable, Sequence
|
||||||
|
from copy import deepcopy
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware.context_editing import (
|
||||||
|
ClearToolUsesEdit,
|
||||||
|
ContextEdit,
|
||||||
|
ContextEditingMiddleware,
|
||||||
|
TokenCounter,
|
||||||
|
)
|
||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
AnyMessage,
|
||||||
|
BaseMessage,
|
||||||
|
ToolMessage,
|
||||||
|
)
|
||||||
|
from langchain_core.messages.utils import count_tokens_approximately
|
||||||
|
from langgraph.config import get_config
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from deepagents.backends.protocol import BackendProtocol
|
||||||
|
from langchain.agents.middleware.types import (
|
||||||
|
ModelRequest,
|
||||||
|
ModelResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_SPILL_PREFIX = "/tool_outputs"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_spill_placeholder(spill_path: str) -> str:
|
||||||
|
"""Build the user-facing placeholder text shown to the model."""
|
||||||
|
return (
|
||||||
|
f"[cleared — full output at {spill_path}; ask the explore subagent to read it]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_thread_id_or_session() -> str:
    """Best-effort lookup of the current ``thread_id`` for the spill path.

    Reads ``configurable.thread_id`` from the LangGraph config returned by
    ``get_config()``. When no config is available (``get_config`` raises
    ``RuntimeError``, e.g. in unit tests) or no thread id is set, a fixed
    fallback string is returned. The exact value does not matter as long as
    it is stable within one stream, so the placeholder path and the actual
    upload path line up.

    Returns:
        ``str(thread_id)`` when one is configured, otherwise ``"no_thread"``.
    """
    try:
        config = get_config()
        # ``configurable`` may be present but explicitly ``None``; treat that
        # the same as "missing" instead of failing on ``None.get``.
        configurable = config.get("configurable") or {}
        thread_id = configurable.get("thread_id")
        if thread_id is not None:
            return str(thread_id)
    except RuntimeError:
        # No runnable context is active; use the fallback below.
        pass
    return "no_thread"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
class SpillToBackendEdit(ContextEdit):
    """Capture-and-replace context edit that spills full tool output to the backend.

    Behaves like :class:`ClearToolUsesEdit` (same trigger / keep / exclude
    semantics) **and** records the original ``ToolMessage.content`` in
    :attr:`pending_spills` so the wrapping middleware can flush them
    before the model call.

    Args:
        trigger: Token threshold above which the edit fires.
        clear_at_least: Minimum number of tokens to reclaim (best effort).
        keep: Number of most-recent ``ToolMessage`` instances to leave
            untouched.
        exclude_tools: Names of tools whose output is NOT spilled.
        clear_tool_inputs: Also clear the originating ``AIMessage.tool_calls``
            args when their pair is cleared.
        path_prefix: Path under the backend where spills are written.
            Default ``"/tool_outputs"``.
    """

    trigger: int = 100_000
    clear_at_least: int = 0
    keep: int = 3
    clear_tool_inputs: bool = False
    exclude_tools: Sequence[str] = ()
    path_prefix: str = DEFAULT_SPILL_PREFIX

    # ``(spill_path, payload)`` pairs queued by ``apply`` and not yet drained.
    pending_spills: list[tuple[str, bytes]] = field(default_factory=list)
    # Held around every append to / drain of ``pending_spills``.
    _lock: threading.Lock = field(default_factory=threading.Lock)

    def drain_pending(self) -> list[tuple[str, bytes]]:
        """Return and clear the pending-spill list atomically.

        Returns:
            A snapshot of the queued ``(spill_path, payload)`` pairs; the
            internal queue is empty afterwards.
        """
        with self._lock:
            out = list(self.pending_spills)
            self.pending_spills.clear()
        return out

    def apply(
        self,
        messages: list[AnyMessage],
        *,
        count_tokens: TokenCounter,
    ) -> None:
        """Mirror ``ClearToolUsesEdit.apply`` but capture originals first.

        Mutates ``messages`` in place: each eligible ``ToolMessage`` is
        replaced by a copy whose content is the spill placeholder, and its
        original content is queued in :attr:`pending_spills`.

        Args:
            messages: Conversation messages to edit in place.
            count_tokens: Callable returning the token count of a message list.
        """
        tokens = count_tokens(messages)
        if tokens <= self.trigger:
            return

        candidates = [
            (idx, msg)
            for idx, msg in enumerate(messages)
            if isinstance(msg, ToolMessage)
        ]
        if self.keep >= len(candidates):
            return
        if self.keep:
            # Leave the ``keep`` most recent tool messages untouched.
            candidates = candidates[: -self.keep]

        thread_id = _get_thread_id_or_session()
        excluded_tools = set(self.exclude_tools)

        for idx, tool_message in candidates:
            # Already replaced by an earlier pass; do not spill the placeholder.
            if tool_message.response_metadata.get("context_editing", {}).get("cleared"):
                continue

            # Locate the nearest preceding AIMessage *by position*. Keeping
            # the index (instead of calling ``messages.index(...)`` later)
            # avoids an equality-based search that could match an earlier,
            # equal-valued message and rescans the list for every candidate.
            ai_idx = next(
                (
                    pos
                    for pos in range(idx - 1, -1, -1)
                    if isinstance(messages[pos], AIMessage)
                ),
                None,
            )
            if ai_idx is None:
                continue
            ai_message = messages[ai_idx]

            tool_call = next(
                (
                    call
                    for call in ai_message.tool_calls
                    if call.get("id") == tool_message.tool_call_id
                ),
                None,
            )
            if tool_call is None:
                # The nearest AIMessage did not issue this call; leave it alone.
                continue

            tool_name = tool_message.name or tool_call["name"]
            if tool_name in excluded_tools:
                continue

            message_id = tool_message.id or tool_message.tool_call_id or "unknown"
            spill_path = f"{self.path_prefix}/{thread_id}/{message_id}.txt"

            original = tool_message.content
            payload = self._encode_payload(original)
            with self._lock:
                self.pending_spills.append((spill_path, payload))

            messages[idx] = tool_message.model_copy(
                update={
                    "artifact": None,
                    "content": _build_spill_placeholder(spill_path),
                    "response_metadata": {
                        **tool_message.response_metadata,
                        "context_editing": {
                            "cleared": True,
                            "strategy": "spill_to_backend",
                            "spill_path": spill_path,
                        },
                    },
                }
            )

            if self.clear_tool_inputs:
                messages[ai_idx] = self._clear_input_args(
                    ai_message, tool_message.tool_call_id or ""
                )

            if self.clear_at_least > 0:
                # Stop as soon as enough tokens have been reclaimed.
                new_token_count = count_tokens(messages)
                cleared_tokens = max(0, tokens - new_token_count)
                if cleared_tokens >= self.clear_at_least:
                    break

    @staticmethod
    def _encode_payload(content: Any) -> bytes:
        """Serialize ``ToolMessage.content`` to bytes for upload.

        ``bytes`` pass through, ``str`` is UTF-8 encoded, anything else is
        JSON-encoded (non-JSON values stringified), with ``str(content)`` as
        the last resort if JSON encoding raises.
        """
        if isinstance(content, bytes):
            return content
        if isinstance(content, str):
            return content.encode("utf-8")
        try:
            import json

            return json.dumps(content, default=str).encode("utf-8")
        except Exception:
            return str(content).encode("utf-8")

    @staticmethod
    def _clear_input_args(message: AIMessage, tool_call_id: str) -> AIMessage:
        """Return a copy of ``message`` with one tool call's ``args`` emptied.

        Args:
            message: AIMessage that issued the tool call.
            tool_call_id: Id of the tool call whose args should be cleared.

        Returns:
            A copy of ``message``; when a call was cleared, its id is also
            recorded under ``response_metadata["context_editing"]["cleared_tool_inputs"]``.
        """
        updated_tool_calls: list[dict[str, Any]] = []
        cleared_any = False
        for tool_call in message.tool_calls:
            updated = dict(tool_call)
            if updated.get("id") == tool_call_id:
                updated["args"] = {}
                cleared_any = True
            updated_tool_calls.append(updated)

        metadata = dict(getattr(message, "response_metadata", {}))
        if cleared_any:
            ctx = dict(metadata.get("context_editing", {}))
            ids = set(ctx.get("cleared_tool_inputs", []))
            ids.add(tool_call_id)
            ctx["cleared_tool_inputs"] = sorted(ids)
            metadata["context_editing"] = ctx
        return message.model_copy(
            update={
                "tool_calls": updated_tool_calls,
                "response_metadata": metadata,
            }
        )
|
||||||
|
|
||||||
|
|
||||||
|
BackendResolver = "Callable[[Any], BackendProtocol] | BackendProtocol"
|
||||||
|
|
||||||
|
|
||||||
|
class SpillingContextEditingMiddleware(ContextEditingMiddleware):
    """:class:`ContextEditingMiddleware` that flushes :class:`SpillToBackendEdit` writes.

    Runs the configured edits, then flushes any pending spills via the
    supplied backend resolver before delegating to the model handler. Spill
    failures are logged but never abort the model call — the placeholder
    text is already in the message, so the worst case is the agent gets a
    placeholder it cannot follow up on.

    NOTE(review): only the async hook ``awrap_model_call`` is overridden
    here. If the inherited sync ``wrap_model_call`` is ever used with a
    :class:`SpillToBackendEdit`, spills queued there are not flushed by this
    class — confirm the sync path is unused.
    """

    def __init__(
        self,
        *,
        edits: Sequence[ContextEdit],
        backend_resolver: BackendResolver | None = None,
        token_count_method: str = "approximate",
    ) -> None:
        """Store the backend resolver and initialise the parent middleware.

        Args:
            edits: Context edits to apply, in order, before each model call.
            backend_resolver: A backend instance, a callable that builds one
                from a ``ToolRuntime``, or ``None`` to disable uploads.
            token_count_method: ``"approximate"`` uses
                ``count_tokens_approximately``; any other value counts via
                the request's model.
        """
        super().__init__(edits=list(edits), token_count_method=token_count_method)  # type: ignore[arg-type]
        self._backend_resolver = backend_resolver

    def _resolve_backend(self, request: ModelRequest) -> BackendProtocol | None:
        """Return the backend to upload spills to, or ``None`` if unavailable.

        A callable resolver is invoked with a ``ToolRuntime`` assembled from
        the request; any failure while doing so is logged and yields ``None``.
        """
        if self._backend_resolver is None:
            return None
        if callable(self._backend_resolver):
            try:
                from langchain.tools import ToolRuntime

                tool_runtime = ToolRuntime(
                    state=getattr(request, "state", {}),
                    context=getattr(request.runtime, "context", None),
                    stream_writer=getattr(request.runtime, "stream_writer", None),
                    store=getattr(request.runtime, "store", None),
                    config=getattr(request.runtime, "config", None) or {},
                    tool_call_id=None,
                )
                return self._backend_resolver(tool_runtime)
            except Exception:
                logger.exception("Failed to resolve spill backend")
                return None
        return self._backend_resolver  # type: ignore[return-value]

    def _collect_pending(self) -> list[tuple[str, bytes]]:
        """Drain and concatenate the pending spills of every spill edit."""
        out: list[tuple[str, bytes]] = []
        for edit in self.edits:
            if isinstance(edit, SpillToBackendEdit):
                out.extend(edit.drain_pending())
        return out

    async def _flush_spills(
        self, backend: BackendProtocol, pending: list[tuple[str, bytes]]
    ) -> None:
        """Upload ``pending`` to ``backend``, logging failures instead of raising.

        Besides exceptions raised by the upload call, this also inspects the
        returned per-file responses. The inspection is defensive (``getattr``
        with defaults) because the response shape is not visible in this
        module — presumably each item exposes ``path`` and ``error``; verify
        against the backend protocol.
        """
        try:
            responses = await backend.aupload_files(pending)
        except Exception:
            logger.exception(
                "Spill-to-backend upload failed (%d files); placeholders "
                "remain in messages but content is unrecoverable",
                len(pending),
            )
            return

        if not isinstance(responses, (list, tuple)):
            return
        failures = [
            (getattr(item, "path", None), getattr(item, "error", None))
            for item in responses
            if getattr(item, "error", None)
        ]
        if failures:
            logger.warning(
                "Spill-to-backend upload reported %d failed file(s) out of %d: %s",
                len(failures),
                len(pending),
                failures,
            )

    async def awrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> Any:
        """Apply the edits to a copy of the messages, flush spills, call the model.

        Args:
            request: Incoming model request.
            handler: Next handler in the middleware chain.

        Returns:
            Whatever ``handler`` returns for the edited request.
        """
        if not request.messages:
            return await handler(request)

        if self.token_count_method == "approximate":

            def count_tokens(messages: Sequence[BaseMessage]) -> int:
                return count_tokens_approximately(messages)

        else:
            system_msg = [request.system_message] if request.system_message else []

            def count_tokens(messages: Sequence[BaseMessage]) -> int:
                return request.model.get_num_tokens_from_messages(
                    system_msg + list(messages), request.tools
                )

        # Edits mutate the list they are given, so work on a deep copy and
        # leave the request's own messages untouched.
        edited_messages = deepcopy(list(request.messages))
        for edit in self.edits:
            edit.apply(edited_messages, count_tokens=count_tokens)

        pending = self._collect_pending()
        if pending:
            backend = self._resolve_backend(request)
            if backend is not None:
                # Upload before the model call so the advertised paths exist
                # by the time anything tries to read them.
                await self._flush_spills(backend, pending)
            else:
                logger.warning(
                    "SpillToBackendEdit produced %d pending spills but no backend "
                    "resolver was configured; content is unrecoverable",
                    len(pending),
                )

        return await handler(request.override(messages=edited_messages))
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DEFAULT_SPILL_PREFIX",
|
||||||
|
"ClearToolUsesEdit",
|
||||||
|
"SpillToBackendEdit",
|
||||||
|
"SpillingContextEditingMiddleware",
|
||||||
|
"_build_spill_placeholder",
|
||||||
|
]
|
||||||
|
|
@ -2,17 +2,27 @@
|
||||||
|
|
||||||
When the LLM emits multiple calls to the same HITL tool with the same
|
When the LLM emits multiple calls to the same HITL tool with the same
|
||||||
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
|
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
|
||||||
only the first call is kept. Non-HITL tools are never touched.
|
only the first call is kept. Non-HITL tools are never touched.
|
||||||
|
|
||||||
This runs in the ``after_model`` hook — **before** any tool executes — so
|
This runs in the ``after_model`` hook — **before** any tool executes — so
|
||||||
the duplicate call is stripped from the AIMessage that gets checkpointed.
|
the duplicate call is stripped from the AIMessage that gets checkpointed.
|
||||||
That means it is also safe across LangGraph ``interrupt()`` boundaries:
|
That means it is also safe across LangGraph ``interrupt()`` boundaries:
|
||||||
the removed call will never appear on graph resume.
|
the removed call will never appear on graph resume.
|
||||||
|
|
||||||
|
Dedup-key resolution order:
|
||||||
|
|
||||||
|
1. :class:`ToolDefinition.dedup_key` — callable provided by the registry
|
||||||
|
entry. This is the canonical mechanism.
|
||||||
|
2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg name;
|
||||||
|
used by MCP / Composio tools whose schemas the registry doesn't see.
|
||||||
|
|
||||||
|
A tool with no resolver from either path simply opts out of dedup.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
|
@ -20,81 +30,83 @@ from langgraph.runtime import Runtime
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_NATIVE_HITL_TOOL_DEDUP_KEYS: dict[str, str] = {
|
# Resolver type — given the tool ``args`` dict returns a stable
|
||||||
# Gmail
|
# string used to dedupe consecutive calls. ``None`` means no dedup.
|
||||||
"send_gmail_email": "subject",
|
DedupResolver = Callable[[dict[str, Any]], str]
|
||||||
"create_gmail_draft": "subject",
|
|
||||||
"update_gmail_draft": "draft_subject_or_id",
|
|
||||||
"trash_gmail_email": "email_subject_or_id",
|
def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver:
|
||||||
# Google Calendar
|
"""Adapt a string-arg name into a :data:`DedupResolver`.
|
||||||
"create_calendar_event": "title",
|
|
||||||
"update_calendar_event": "event_title_or_id",
|
Convenience helper used by registry entries that just want to dedupe
|
||||||
"delete_calendar_event": "event_title_or_id",
|
on a single arg's lowercased value (the most common case for native
|
||||||
# Google Drive
|
HITL tools like ``send_gmail_email`` keyed on ``subject``).
|
||||||
"create_google_drive_file": "file_name",
|
|
||||||
"delete_google_drive_file": "file_name",
|
Example::
|
||||||
# OneDrive
|
|
||||||
"create_onedrive_file": "file_name",
|
ToolDefinition(
|
||||||
"delete_onedrive_file": "file_name",
|
name="send_gmail_email",
|
||||||
# Dropbox
|
...,
|
||||||
"create_dropbox_file": "file_name",
|
dedup_key=wrap_dedup_key_by_arg_name("subject"),
|
||||||
"delete_dropbox_file": "file_name",
|
)
|
||||||
# Notion
|
"""
|
||||||
"create_notion_page": "title",
|
|
||||||
"update_notion_page": "page_title",
|
def _resolver(args: dict[str, Any]) -> str:
|
||||||
"delete_notion_page": "page_title",
|
return str(args.get(arg_name, "")).lower()
|
||||||
# Linear
|
|
||||||
"create_linear_issue": "title",
|
return _resolver
|
||||||
"update_linear_issue": "issue_ref",
|
|
||||||
"delete_linear_issue": "issue_ref",
|
|
||||||
# Jira
|
# Backwards-compatible alias for code that imported the original
|
||||||
"create_jira_issue": "summary",
|
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
|
||||||
"update_jira_issue": "issue_title_or_key",
|
_wrap_string_key = wrap_dedup_key_by_arg_name
|
||||||
"delete_jira_issue": "issue_title_or_key",
|
|
||||||
# Confluence
|
|
||||||
"create_confluence_page": "title",
|
|
||||||
"update_confluence_page": "page_title_or_id",
|
|
||||||
"delete_confluence_page": "page_title_or_id",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
"""Remove duplicate HITL tool calls from a single LLM response.
|
"""Remove duplicate HITL tool calls from a single LLM response.
|
||||||
|
|
||||||
Only the **first** occurrence of each (tool-name, primary-arg-value)
|
Only the **first** occurrence of each ``(tool-name, dedup_key)``
|
||||||
pair is kept; subsequent duplicates are silently dropped.
|
pair is kept; subsequent duplicates are silently dropped.
|
||||||
|
|
||||||
The dedup map is built from two sources:
|
The dedup-resolver map is built from two sources, in priority order:
|
||||||
|
|
||||||
1. A comprehensive list of native HITL tools (hardcoded above).
|
1. ``tool.metadata["dedup_key"]`` — callable provided by the registry's
|
||||||
2. Any ``StructuredTool`` instances passed via *agent_tools* whose
|
``ToolDefinition.dedup_key``. Receives the args dict and returns
|
||||||
``metadata`` contains ``{"hitl": True, "hitl_dedup_key": "..."}``.
|
a string signature. This is the canonical mechanism.
|
||||||
This is how MCP tools automatically get dedup support.
|
2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg
|
||||||
|
name; primarily used by MCP / Composio tools.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tools = ()
|
tools = ()
|
||||||
|
|
||||||
def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
|
def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
|
||||||
self._dedup_keys: dict[str, str] = dict(_NATIVE_HITL_TOOL_DEDUP_KEYS)
|
self._resolvers: dict[str, DedupResolver] = {}
|
||||||
|
|
||||||
for t in agent_tools or []:
|
for t in agent_tools or []:
|
||||||
meta = getattr(t, "metadata", None) or {}
|
meta = getattr(t, "metadata", None) or {}
|
||||||
|
callable_key = meta.get("dedup_key")
|
||||||
|
if callable(callable_key):
|
||||||
|
self._resolvers[t.name] = callable_key
|
||||||
|
continue
|
||||||
if meta.get("hitl") and meta.get("hitl_dedup_key"):
|
if meta.get("hitl") and meta.get("hitl_dedup_key"):
|
||||||
self._dedup_keys[t.name] = meta["hitl_dedup_key"]
|
self._resolvers[t.name] = wrap_dedup_key_by_arg_name(
|
||||||
|
meta["hitl_dedup_key"]
|
||||||
|
)
|
||||||
|
|
||||||
def after_model(
|
def after_model(
|
||||||
self, state: AgentState, runtime: Runtime[Any]
|
self, state: AgentState, runtime: Runtime[Any]
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
return self._dedup(state, self._dedup_keys)
|
return self._dedup(state, self._resolvers)
|
||||||
|
|
||||||
async def aafter_model(
|
async def aafter_model(
|
||||||
self, state: AgentState, runtime: Runtime[Any]
|
self, state: AgentState, runtime: Runtime[Any]
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
return self._dedup(state, self._dedup_keys)
|
return self._dedup(state, self._resolvers)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _dedup(
|
def _dedup(
|
||||||
state: AgentState,
|
state: AgentState,
|
||||||
dedup_keys: dict[str, str], # type: ignore[type-arg]
|
resolvers: dict[str, DedupResolver],
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
messages = state.get("messages")
|
messages = state.get("messages")
|
||||||
if not messages:
|
if not messages:
|
||||||
|
|
@ -110,9 +122,16 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
|
||||||
for tc in tool_calls:
|
for tc in tool_calls:
|
||||||
name = tc.get("name", "")
|
name = tc.get("name", "")
|
||||||
dedup_key_arg = dedup_keys.get(name)
|
resolver = resolvers.get(name)
|
||||||
if dedup_key_arg is not None:
|
if resolver is not None:
|
||||||
arg_val = str(tc.get("args", {}).get(dedup_key_arg, "")).lower()
|
try:
|
||||||
|
arg_val = resolver(tc.get("args", {}) or {})
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Dedup resolver for tool %s raised; keeping call", name
|
||||||
|
)
|
||||||
|
deduped.append(tc)
|
||||||
|
continue
|
||||||
key = (name, arg_val)
|
key = (name, arg_val)
|
||||||
if key in seen:
|
if key in seen:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
||||||
237
surfsense_backend/app/agents/new_chat/middleware/doom_loop.py
Normal file
237
surfsense_backend/app/agents/new_chat/middleware/doom_loop.py
Normal file
|
|
@ -0,0 +1,237 @@
|
||||||
|
"""
|
||||||
|
DoomLoopMiddleware — pattern-based detector for repeated identical tool calls.
|
||||||
|
|
||||||
|
LangChain has :class:`ToolCallLimitMiddleware` which caps the *total* number
|
||||||
|
of tool calls per turn — but it can't tell apart "10 distinct, useful
|
||||||
|
calls" from "the same call 10 times in a row". This middleware fills that
|
||||||
|
gap with a sliding-window check on tool-call signatures, ported from
|
||||||
|
OpenCode's ``packages/opencode/src/session/processor.ts``.
|
||||||
|
|
||||||
|
When the same tool with the same arguments is called N times in a row,
|
||||||
|
the agent has likely entered an infinite loop. We surface this to the
|
||||||
|
user as an interrupt with ``permission="doom_loop"`` so the UI can
|
||||||
|
render an "Are you stuck? Continue / cancel?" affordance.
|
||||||
|
|
||||||
|
This ships **OFF by default** until the frontend explicitly handles
|
||||||
|
``context.permission == "doom_loop"`` interrupts.
|
||||||
|
|
||||||
|
Wire format: uses SurfSense's existing ``interrupt()`` payload shape
|
||||||
|
(see ``app/agents/new_chat/tools/hitl.py``):
|
||||||
|
|
||||||
|
{
|
||||||
|
"type": "permission_ask",
|
||||||
|
"action": {"tool": <name>, "params": <args>},
|
||||||
|
"context": {"permission": "doom_loop", "recent_signatures": [...]},
|
||||||
|
}
|
||||||
|
|
||||||
|
so the frontend that already handles HITL prompts can render this with
|
||||||
|
no changes beyond a string check.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections import deque
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware.types import (
|
||||||
|
AgentMiddleware,
|
||||||
|
AgentState,
|
||||||
|
ContextT,
|
||||||
|
ResponseT,
|
||||||
|
)
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from langgraph.config import get_config
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
from langgraph.types import interrupt
|
||||||
|
|
||||||
|
from app.observability import otel as ot
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _signature(name: str, args: Any) -> str:
|
||||||
|
"""Hash a tool call ``(name, args)`` to a short signature."""
|
||||||
|
try:
|
||||||
|
canonical = json.dumps(args, sort_keys=True, default=str)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
canonical = repr(args)
|
||||||
|
digest = hashlib.sha1(f"{name}::{canonical}".encode()).hexdigest()
|
||||||
|
return digest[:16]
|
||||||
|
|
||||||
|
|
||||||
|
class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||||
|
"""Detect repeated identical tool calls and prompt the user.
|
||||||
|
|
||||||
|
Tracks a sliding window of the most-recent ``threshold`` tool-call
|
||||||
|
signatures across the live request. When all entries match, raise
|
||||||
|
a SurfSense-style HITL interrupt with ``permission="doom_loop"``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
threshold: How many consecutive identical signatures count as a
|
||||||
|
doom loop. Default 3 (matches OpenCode's processor.ts).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, threshold: int = 3) -> None:
    """Set up the detector.

    Args:
        threshold: Number of consecutive identical signatures that counts
            as a doom loop.

    Raises:
        ValueError: If ``threshold`` is below 2.
    """
    super().__init__()
    if threshold < 2:
        raise ValueError("DoomLoopMiddleware threshold must be >= 2")

    # One sliding window per thread, held in memory and keyed by thread_id.
    # Putting this in graph state would require state-schema gymnastics;
    # an in-process map is fine for a single process lifetime.
    self._windows: dict[str, deque[str]] = {}
    self.tools = []
    self._threshold = threshold
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _thread_id_from_runtime(runtime: Runtime[ContextT]) -> str:
|
||||||
|
"""Resolve the thread id for sliding-window keying.
|
||||||
|
|
||||||
|
Prefer LangGraph's ``get_config()`` (the only way to read
|
||||||
|
``RunnableConfig`` inside a node — :class:`Runtime` does NOT carry
|
||||||
|
a ``config`` attribute). Fall back to ``runtime.config`` for unit
|
||||||
|
tests that synthesize a config-bearing stub. Default
|
||||||
|
``"no_thread"`` is intentionally only used when both lookups fail
|
||||||
|
— it would collapse all threads into one window so we keep the
|
||||||
|
debug log loud.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _from_dict(cfg: Any) -> str | None:
|
||||||
|
if not isinstance(cfg, dict):
|
||||||
|
return None
|
||||||
|
tid = (cfg.get("configurable") or {}).get("thread_id")
|
||||||
|
return str(tid) if tid is not None else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
tid = _from_dict(get_config())
|
||||||
|
except Exception:
|
||||||
|
tid = None
|
||||||
|
if tid is not None:
|
||||||
|
return tid
|
||||||
|
|
||||||
|
tid = _from_dict(getattr(runtime, "config", None))
|
||||||
|
if tid is not None:
|
||||||
|
return tid
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"DoomLoopMiddleware: no thread_id resolved from RunnableConfig; "
|
||||||
|
"falling back to shared 'no_thread' window."
|
||||||
|
)
|
||||||
|
return "no_thread"
|
||||||
|
|
||||||
|
def _window(self, thread_id: str) -> deque[str]:
|
||||||
|
win = self._windows.get(thread_id)
|
||||||
|
if win is None:
|
||||||
|
win = deque(maxlen=self._threshold)
|
||||||
|
self._windows[thread_id] = win
|
||||||
|
return win
|
||||||
|
|
||||||
|
def _detect(
|
||||||
|
self, message: AIMessage, runtime: Runtime[ContextT]
|
||||||
|
) -> tuple[bool, list[str], dict[str, Any] | None]:
|
||||||
|
if not message.tool_calls:
|
||||||
|
return False, [], None
|
||||||
|
|
||||||
|
thread_id = self._thread_id_from_runtime(runtime)
|
||||||
|
window = self._window(thread_id)
|
||||||
|
|
||||||
|
triggered_call: dict[str, Any] | None = None
|
||||||
|
for call in message.tool_calls:
|
||||||
|
name = (
|
||||||
|
call.get("name")
|
||||||
|
if isinstance(call, dict)
|
||||||
|
else getattr(call, "name", None)
|
||||||
|
)
|
||||||
|
args = (
|
||||||
|
call.get("args")
|
||||||
|
if isinstance(call, dict)
|
||||||
|
else getattr(call, "args", {})
|
||||||
|
)
|
||||||
|
if not isinstance(name, str):
|
||||||
|
continue
|
||||||
|
sig = _signature(name, args)
|
||||||
|
window.append(sig)
|
||||||
|
if len(window) >= self._threshold and len(set(window)) == 1:
|
||||||
|
triggered_call = {"name": name, "params": args or {}}
|
||||||
|
break
|
||||||
|
|
||||||
|
if triggered_call is None:
|
||||||
|
return False, list(window), None
|
||||||
|
return True, list(window), triggered_call
|
||||||
|
|
||||||
|
def after_model( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
state: AgentState[ResponseT],
|
||||||
|
runtime: Runtime[ContextT],
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
messages = state.get("messages") or []
|
||||||
|
if not messages:
|
||||||
|
return None
|
||||||
|
last = messages[-1]
|
||||||
|
if not isinstance(last, AIMessage):
|
||||||
|
return None
|
||||||
|
|
||||||
|
triggered, signatures, action = self._detect(last, runtime)
|
||||||
|
if not triggered:
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"Doom loop detected: tool %s called %d times in a row (sig=%s)",
|
||||||
|
action["name"] if action else "<unknown>",
|
||||||
|
self._threshold,
|
||||||
|
signatures[-1] if signatures else "<empty>",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Open an interrupt.raised span with permission=doom_loop attribute
|
||||||
|
# so dashboards can break out doom-loop interrupts from regular
|
||||||
|
# permission asks via the ``interrupt.permission`` attribute.
|
||||||
|
with ot.interrupt_span(
|
||||||
|
interrupt_type="permission_ask",
|
||||||
|
extra={
|
||||||
|
"interrupt.permission": "doom_loop",
|
||||||
|
"interrupt.threshold": self._threshold,
|
||||||
|
"interrupt.tool": (action or {}).get("tool", "<unknown>"),
|
||||||
|
},
|
||||||
|
):
|
||||||
|
decision = interrupt(
|
||||||
|
{
|
||||||
|
"type": "permission_ask",
|
||||||
|
"action": action or {"tool": "<unknown>", "params": {}},
|
||||||
|
"context": {
|
||||||
|
"permission": "doom_loop",
|
||||||
|
"recent_signatures": signatures,
|
||||||
|
"threshold": self._threshold,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reset window so the next decision (continue/cancel) starts fresh.
|
||||||
|
thread_id = self._thread_id_from_runtime(runtime)
|
||||||
|
self._windows.pop(thread_id, None)
|
||||||
|
|
||||||
|
# Decision shape mirrors ``tools/hitl.py``: {"decision_type": "..."}
|
||||||
|
# If the user cancelled, jump to end. Otherwise return ``None`` so the
|
||||||
|
# tool call proceeds. The frontend's exact reply names may differ —
|
||||||
|
# we tolerate any shape that contains a string with "reject"/"cancel".
|
||||||
|
if isinstance(decision, dict):
|
||||||
|
kind = str(
|
||||||
|
decision.get("decision_type") or decision.get("type") or ""
|
||||||
|
).lower()
|
||||||
|
if "reject" in kind or "cancel" in kind:
|
||||||
|
return {"jump_to": "end"}
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def aafter_model( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
state: AgentState[ResponseT],
|
||||||
|
runtime: Runtime[ContextT],
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
return self.after_model(state, runtime)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DoomLoopMiddleware",
|
||||||
|
"_signature",
|
||||||
|
]
|
||||||
334
surfsense_backend/app/agents/new_chat/middleware/file_intent.py
Normal file
334
surfsense_backend/app/agents/new_chat/middleware/file_intent.py
Normal file
|
|
@ -0,0 +1,334 @@
|
||||||
|
"""Semantic file-intent routing middleware for new chat turns.
|
||||||
|
|
||||||
|
This middleware classifies the latest human turn into a small intent set:
|
||||||
|
- chat_only
|
||||||
|
- file_write
|
||||||
|
- file_read
|
||||||
|
|
||||||
|
For ``file_write`` turns it injects a strict system contract so the model
|
||||||
|
uses filesystem tools before claiming success, and provides a deterministic
|
||||||
|
fallback path when no filename is specified by the user.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FileOperationIntent(StrEnum):
|
||||||
|
CHAT_ONLY = "chat_only"
|
||||||
|
FILE_WRITE = "file_write"
|
||||||
|
FILE_READ = "file_read"
|
||||||
|
|
||||||
|
|
||||||
|
class FileIntentPlan(BaseModel):
    """Validated classifier output: the intent plus optional path hints.

    Only ``intent`` is required. ``confidence`` is constrained to the
    closed interval [0.0, 1.0]; the three ``suggested_*`` hints default to
    ``None`` when the classifier does not supply them.
    """

    intent: FileOperationIntent = Field(description="Primary user intent for this turn.")
    confidence: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Model confidence in the selected intent.",
    )
    suggested_filename: str | None = Field(
        description="Optional filename (e.g. notes.md) inferred from user request.",
        default=None,
    )
    suggested_directory: str | None = Field(
        description=(
            "Optional directory path (e.g. /reports/q2 or reports/q2) inferred from "
            "user request."
        ),
        default=None,
    )
    suggested_path: str | None = Field(
        description=(
            "Optional full file path (e.g. /reports/q2/summary.md). If present, this "
            "takes precedence over suggested_directory + suggested_filename."
        ),
        default=None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text_from_message(message: BaseMessage) -> str:
|
||||||
|
content = getattr(message, "content", "")
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict) and item.get("type") == "text":
|
||||||
|
parts.append(str(item.get("text", "")))
|
||||||
|
return "\n".join(part for part in parts if part)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_payload(text: str) -> str:
|
||||||
|
stripped = text.strip()
|
||||||
|
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
|
||||||
|
if fenced:
|
||||||
|
return fenced.group(1)
|
||||||
|
start = stripped.find("{")
|
||||||
|
end = stripped.rfind("}")
|
||||||
|
if start != -1 and end != -1 and end > start:
|
||||||
|
return stripped[start : end + 1]
|
||||||
|
return stripped
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_filename(value: str) -> str:
|
||||||
|
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
||||||
|
name = re.sub(r"\s+", "-", name)
|
||||||
|
name = name.strip("._-")
|
||||||
|
if not name:
|
||||||
|
name = "note"
|
||||||
|
if len(name) > 80:
|
||||||
|
name = name[:80].rstrip("-_.")
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_path_segment(value: str) -> str:
|
||||||
|
segment = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
||||||
|
segment = re.sub(r"\s+", "_", segment)
|
||||||
|
segment = segment.strip("._-")
|
||||||
|
return segment
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_directory(value: str) -> str:
    """Normalise a directory string to ``seg/seg/...`` with no outer slashes.

    Backslashes are treated as ``/``. Each component is passed through
    ``_sanitize_path_segment``; blank components and components that
    sanitise to nothing are dropped. Returns ``""`` when nothing is left.
    """
    trimmed = value.strip().replace("\\", "/").strip("/")
    if not trimmed:
        return ""

    cleaned: list[str] = []
    for piece in trimmed.split("/"):
        if not piece.strip():
            continue
        safe = _sanitize_path_segment(piece)
        if safe:
            cleaned.append(safe)
    return "/".join(cleaned)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_file_path(value: str) -> str:
    """Normalise a path to an absolute-style ``/seg/seg`` string.

    Backslashes are treated as ``/`` and every component is sanitised via
    ``_sanitize_path_segment``. A trailing slash on the input is kept on
    the output, which lets callers tell "directory" from "file". Returns
    ``""`` when no usable component remains.
    """
    unified = value.strip().replace("\\", "/").strip()
    # Must be captured before the outer slashes are stripped below.
    wants_directory = unified.endswith("/")

    core = unified.strip("/")
    if not core:
        return ""

    segments: list[str] = []
    for piece in core.split("/"):
        if not piece.strip():
            continue
        safe = _sanitize_path_segment(piece)
        if safe:
            segments.append(safe)
    if not segments:
        return ""

    joined = f"/{'/'.join(segments)}"
    return f"{joined}/" if wants_directory else joined
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_directory_from_user_text(user_text: str) -> str | None:
    """Guess a target directory from phrases like "in the reports folder".

    The lower-cased text is tried against two patterns in order: an
    ``in|inside|under ... folder`` form, then a bare ``in|inside|under
    <words>`` form. For each pattern only the first match is considered;
    a capture that is just an article, or that normalises to nothing,
    moves on to the next pattern.

    Returns:
        The normalised directory, or ``None`` when no pattern yields one.
    """
    haystack = user_text.lower()
    articles = {"the", "a", "an"}
    folder_hint_patterns = (
        r"\b(?:in|inside|under)\s+(?:the\s+)?([a-zA-Z0-9 _\-/]+?)\s+folder\b",
        r"\b(?:in|inside|under)\s+([a-zA-Z0-9 _\-/]+?)\b",
    )

    for regex in folder_hint_patterns:
        found = re.search(regex, haystack, flags=re.IGNORECASE)
        if found is None:
            continue
        phrase = found.group(1).strip()
        if phrase in articles:
            continue
        directory = _normalize_directory(phrase)
        if directory:
            return directory
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _fallback_path(
    suggested_filename: str | None,
    *,
    suggested_directory: str | None = None,
    suggested_path: str | None = None,
    user_text: str,
) -> str:
    """Build the default target path for a file-write turn.

    Precedence:
      1. ``suggested_path`` — used as-is after normalisation, unless it
         ends with ``/``, in which case the filename is appended.
      2. ``suggested_directory``, else a directory inferred from
         ``user_text``, combined with the filename.
      3. The filename at the root.

    The filename is sanitised, a ``.txt`` suffix is swapped for ``.md``,
    an extension-less name gets ``.md``, and a missing name becomes
    ``notes.md``.
    """
    directory_from_text = _infer_directory_from_user_text(user_text)

    if suggested_filename:
        filename = _sanitize_filename(suggested_filename)
        if filename.lower().endswith(".txt"):
            filename = f"{filename[:-4]}.md"
    else:
        filename = ""

    if not filename:
        filename = "notes.md"
    elif "." not in filename:
        filename = f"{filename}.md"

    explicit_path = _normalize_file_path(suggested_path) if suggested_path else ""
    if explicit_path:
        if not explicit_path.endswith("/"):
            return explicit_path
        # Trailing slash means "directory": place the filename inside it.
        return f"{explicit_path.rstrip('/')}/{filename}"

    directory = _normalize_directory(suggested_directory or "") or directory_from_text
    if directory:
        return f"/{directory}/{filename}"
    return f"/{filename}"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_classifier_prompt(*, recent_conversation: str, user_text: str) -> str:
|
||||||
|
return (
|
||||||
|
"Classify the latest user request into a filesystem intent for an AI agent.\n"
|
||||||
|
"Return JSON only with this exact schema:\n"
|
||||||
|
'{"intent":"chat_only|file_write|file_read","confidence":0.0,"suggested_filename":"string or null","suggested_directory":"string or null","suggested_path":"string or null"}\n\n'
|
||||||
|
"Rules:\n"
|
||||||
|
"- Use semantic intent, not literal keywords.\n"
|
||||||
|
"- file_write: user asks to create/save/write/update/edit content as a file.\n"
|
||||||
|
"- file_read: user asks to open/read/list/search existing files.\n"
|
||||||
|
"- chat_only: conversational/analysis responses without required file operations.\n"
|
||||||
|
"- For file_write, choose a concise semantic suggested_filename and match the requested format.\n"
|
||||||
|
"- If the user mentions a folder/directory, populate suggested_directory.\n"
|
||||||
|
"- If user specifies an explicit full path, populate suggested_path.\n"
|
||||||
|
"- Use extensions that match user intent (e.g. .md, .json, .yaml, .csv, .py, .ts, .js, .html, .css, .sql).\n"
|
||||||
|
"- Do not use .txt; prefer .md for generic text notes.\n"
|
||||||
|
"- Do not include dates or timestamps in suggested_filename unless explicitly requested.\n"
|
||||||
|
"- Never include markdown or explanation.\n\n"
|
||||||
|
f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
|
||||||
|
f"Latest user message:\n{user_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_recent_conversation(
    messages: list[BaseMessage], *, max_messages: int = 6
) -> str:
    """Render the tail of the dialogue as ``role: text`` lines.

    Only ``HumanMessage`` (as ``user``) and ``AIMessage`` without tool
    calls (as ``assistant``) are eligible. The last ``max_messages``
    eligible messages are kept; each has its whitespace collapsed and is
    truncated to 280 characters. Messages with no text are omitted.
    """
    eligible: list[tuple[str, BaseMessage]] = []
    for message in messages:
        if isinstance(message, HumanMessage):
            eligible.append(("user", message))
        elif isinstance(message, AIMessage) and not getattr(
            message, "tool_calls", None
        ):
            eligible.append(("assistant", message))

    lines: list[str] = []
    for speaker, message in eligible[-max_messages:]:
        flattened = re.sub(r"\s+", " ", _extract_text_from_message(message)).strip()
        if flattened:
            lines.append(f"{speaker}: {flattened[:280]}")
    return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
class FileIntentMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Classify file intent and inject a strict file-write contract.

    Args:
        llm: Chat model used for classification. When ``None`` every turn
            is classified as ``chat_only`` with confidence 0.0.
    """

    tools = ()

    def __init__(self, *, llm: BaseChatModel | None = None) -> None:
        self.llm = llm

    async def _classify_intent(
        self, *, messages: list[BaseMessage], user_text: str
    ) -> FileIntentPlan:
        """Ask the LLM for a :class:`FileIntentPlan`; never raises.

        Any failure — no model configured, invalid JSON, schema validation
        error, or an unexpected exception from the model call — is logged
        (except the no-model case) and degrades to a ``chat_only`` plan
        with confidence 0.0.
        """
        if self.llm is None:
            return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)

        prompt = _build_classifier_prompt(
            recent_conversation=_build_recent_conversation(messages),
            user_text=user_text,
        )
        try:
            response = await self.llm.ainvoke(
                [HumanMessage(content=prompt)],
                config={"tags": ["surfsense:internal"]},
            )
            payload = json.loads(
                _extract_json_payload(_extract_text_from_message(response))
            )
            plan = FileIntentPlan.model_validate(payload)
            return plan
        except (json.JSONDecodeError, ValidationError, ValueError) as exc:
            logger.warning("File intent classifier returned invalid output: %s", exc)
        except Exception as exc:  # pragma: no cover - defensive fallback
            logger.warning("File intent classifier failed: %s", exc)

        return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)

    async def abefore_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Classify the latest human turn and publish a file-operation contract.

        Returns:
            ``None`` when there are no messages, no ``HumanMessage``, or the
            latest human message has no text. Otherwise a state update that
            always carries ``file_operation_contract``; for ``file_write``
            turns it additionally carries ``messages`` with a
            ``SystemMessage`` contract inserted immediately before the
            latest human message.
        """
        del runtime
        messages = state.get("messages") or []
        if not messages:
            return None

        # Scan from the end and remember WHERE the latest human turn sits,
        # so the contract can be placed directly in front of it even when
        # it is not the final message.
        last_human: HumanMessage | None = None
        last_human_index = -1
        for index in reversed(range(len(messages))):
            candidate = messages[index]
            if isinstance(candidate, HumanMessage):
                last_human = candidate
                last_human_index = index
                break
        if last_human is None:
            return None

        user_text = _extract_text_from_message(last_human).strip()
        if not user_text:
            return None

        plan = await self._classify_intent(messages=messages, user_text=user_text)
        suggested_path = _fallback_path(
            plan.suggested_filename,
            suggested_directory=plan.suggested_directory,
            suggested_path=plan.suggested_path,
            user_text=user_text,
        )
        contract = {
            "intent": plan.intent.value,
            "confidence": plan.confidence,
            "suggested_path": suggested_path,
            "timestamp": datetime.now(UTC).isoformat(),
            "turn_id": state.get("turn_id", ""),
        }

        # Only file_write turns get the in-prompt contract; other intents
        # just record the classification in state.
        if plan.intent != FileOperationIntent.FILE_WRITE:
            return {"file_operation_contract": contract}

        contract_msg = SystemMessage(
            content=(
                "<file_operation_contract>\n"
                "This turn intent is file_write.\n"
                f"Suggested default path: {suggested_path}\n"
                "Rules:\n"
                "- You MUST call write_file or edit_file before claiming success.\n"
                "- If no path is provided by the user, use the suggested default path.\n"
                "- Do not claim a file was created/updated unless tool output confirms it.\n"
                "- If the write/edit fails, clearly report failure instead of success.\n"
                "- Do not include timestamps or dates in generated file content unless the user explicitly asks for them.\n"
                "- For open-ended requests (e.g., random note), generate useful concrete content, not placeholders.\n"
                "</file_operation_contract>"
            )
        )

        # Insert just before the latest human turn so it applies to this request.
        # Using the recorded index (rather than ``len - 1``) keeps this true
        # when other messages follow the human turn.
        new_messages = list(messages)
        new_messages.insert(last_human_index, contract_msg)
        return {"messages": new_messages, "file_operation_contract": contract}
|
||||||
File diff suppressed because it is too large
Load diff
1469
surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py
Normal file
1469
surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py
Normal file
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -1,10 +1,24 @@
|
||||||
"""Knowledge-base pre-search middleware for the SurfSense new chat agent.
|
"""Hybrid-search priority middleware for the SurfSense new chat agent.
|
||||||
|
|
||||||
This middleware runs before the main agent loop and seeds a virtual filesystem
|
This middleware runs ``before_agent`` on every turn and writes:
|
||||||
(`files` state) with relevant documents retrieved via hybrid search. On each
|
|
||||||
turn the filesystem is *expanded* — new results merge with documents loaded
|
* ``state["kb_priority"]`` — the top-K most relevant documents for the
|
||||||
during prior turns — and a synthetic ``ls`` result is injected into the message
|
current user message, used to render a ``<priority_documents>`` system
|
||||||
history so the LLM is immediately aware of the current filesystem structure.
|
message immediately before the user turn.
|
||||||
|
* ``state["kb_matched_chunk_ids"]`` — internal hand-off mapping
|
||||||
|
(``Document.id`` → matched chunk IDs) consumed by
|
||||||
|
:class:`KBPostgresBackend._load_file_data` when the agent first reads each
|
||||||
|
document, so the XML wrapper can flag matched sections in
|
||||||
|
``<chunk_index>``.
|
||||||
|
|
||||||
|
The previous "scoped filesystem" behaviour (synthetic ``ls`` + state
|
||||||
|
``files`` seeding) is intentionally removed: documents are now lazy-loaded
|
||||||
|
from Postgres on demand, with the full workspace tree rendered separately
|
||||||
|
by :class:`KnowledgeTreeMiddleware`.
|
||||||
|
|
||||||
|
In anonymous mode the middleware skips hybrid search entirely and emits a
|
||||||
|
single-entry priority list pointing at the Redis-loaded document
|
||||||
|
(``state["kb_anon_doc"]``).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -13,26 +27,33 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents import create_agent
|
||||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
from litellm import token_counter
|
from litellm import token_counter
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import get_flags
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||||
|
from app.agents.new_chat.path_resolver import (
|
||||||
|
PathIndex,
|
||||||
|
build_path_index,
|
||||||
|
doc_to_virtual_path,
|
||||||
|
)
|
||||||
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
|
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
|
||||||
from app.db import (
|
from app.db import (
|
||||||
NATIVE_TO_LEGACY_DOCTYPE,
|
NATIVE_TO_LEGACY_DOCTYPE,
|
||||||
Chunk,
|
Chunk,
|
||||||
Document,
|
Document,
|
||||||
Folder,
|
|
||||||
shielded_async_session,
|
shielded_async_session,
|
||||||
)
|
)
|
||||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||||
|
|
@ -69,7 +90,6 @@ class KBSearchPlan(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_from_message(message: BaseMessage) -> str:
|
def _extract_text_from_message(message: BaseMessage) -> str:
|
||||||
"""Extract plain text from a message content."""
|
|
||||||
content = getattr(message, "content", "")
|
content = getattr(message, "content", "")
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
return content
|
return content
|
||||||
|
|
@ -84,19 +104,6 @@ def _extract_text_from_message(message: BaseMessage) -> str:
|
||||||
return str(content)
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
|
|
||||||
"""Convert arbitrary text into a filesystem-safe filename."""
|
|
||||||
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
|
||||||
name = re.sub(r"\s+", " ", name)
|
|
||||||
if not name:
|
|
||||||
name = fallback
|
|
||||||
if len(name) > 180:
|
|
||||||
name = name[:180].rstrip()
|
|
||||||
if not name.lower().endswith(".xml"):
|
|
||||||
name = f"{name}.xml"
|
|
||||||
return name
|
|
||||||
|
|
||||||
|
|
||||||
def _render_recent_conversation(
|
def _render_recent_conversation(
|
||||||
messages: Sequence[BaseMessage],
|
messages: Sequence[BaseMessage],
|
||||||
*,
|
*,
|
||||||
|
|
@ -106,10 +113,9 @@ def _render_recent_conversation(
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Render recent dialogue for internal planning under a token budget.
|
"""Render recent dialogue for internal planning under a token budget.
|
||||||
|
|
||||||
Prefers the latest messages and uses the project's existing model-aware
|
Filters to ``HumanMessage`` and ``AIMessage`` (without tool_calls) so that
|
||||||
token budgeting hooks when available on the LLM (`_count_tokens`,
|
injected ``SystemMessage`` artefacts (priority list, workspace tree,
|
||||||
`_get_max_input_tokens`). Falls back to the prior fixed-message heuristic
|
file-write contract) don't pollute the planner prompt.
|
||||||
if token counting is unavailable.
|
|
||||||
"""
|
"""
|
||||||
rendered: list[tuple[str, str]] = []
|
rendered: list[tuple[str, str]] = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
|
|
@ -132,8 +138,6 @@ def _render_recent_conversation(
|
||||||
if not rendered:
|
if not rendered:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Exclude the latest user message from "recent conversation" because it is
|
|
||||||
# already passed separately as "Latest user message" in the planner prompt.
|
|
||||||
if rendered and rendered[-1][0] == "user" and rendered[-1][1] == user_text.strip():
|
if rendered and rendered[-1][0] == "user" and rendered[-1][1] == user_text.strip():
|
||||||
rendered = rendered[:-1]
|
rendered = rendered[:-1]
|
||||||
|
|
||||||
|
|
@ -215,8 +219,6 @@ def _render_recent_conversation(
|
||||||
selected_lines = candidate_lines
|
selected_lines = candidate_lines
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If the full message does not fit, keep as much of this most-recent
|
|
||||||
# older message as possible via binary search.
|
|
||||||
lo, hi = 1, len(text)
|
lo, hi = 1, len(text)
|
||||||
best_line: str | None = None
|
best_line: str | None = None
|
||||||
while lo <= hi:
|
while lo <= hi:
|
||||||
|
|
@ -248,7 +250,6 @@ def _build_kb_planner_prompt(
|
||||||
recent_conversation: str,
|
recent_conversation: str,
|
||||||
user_text: str,
|
user_text: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build a compact internal prompt for KB query rewriting and date scoping."""
|
|
||||||
today = datetime.now(UTC).date().isoformat()
|
today = datetime.now(UTC).date().isoformat()
|
||||||
return (
|
return (
|
||||||
"You optimize internal knowledge-base search inputs for document retrieval.\n"
|
"You optimize internal knowledge-base search inputs for document retrieval.\n"
|
||||||
|
|
@ -274,12 +275,10 @@ def _build_kb_planner_prompt(
|
||||||
|
|
||||||
|
|
||||||
def _extract_json_payload(text: str) -> str:
|
def _extract_json_payload(text: str) -> str:
|
||||||
"""Extract a JSON object from a raw LLM response."""
|
|
||||||
stripped = text.strip()
|
stripped = text.strip()
|
||||||
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
|
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
|
||||||
if fenced:
|
if fenced:
|
||||||
return fenced.group(1)
|
return fenced.group(1)
|
||||||
|
|
||||||
start = stripped.find("{")
|
start = stripped.find("{")
|
||||||
end = stripped.rfind("}")
|
end = stripped.rfind("}")
|
||||||
if start != -1 and end != -1 and end > start:
|
if start != -1 and end != -1 and end > start:
|
||||||
|
|
@ -288,7 +287,6 @@ def _extract_json_payload(text: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
def _parse_kb_search_plan_response(response_text: str) -> KBSearchPlan:
|
def _parse_kb_search_plan_response(response_text: str) -> KBSearchPlan:
|
||||||
"""Parse and validate the planner's JSON response."""
|
|
||||||
payload = json.loads(_extract_json_payload(response_text))
|
payload = json.loads(_extract_json_payload(response_text))
|
||||||
return KBSearchPlan.model_validate(payload)
|
return KBSearchPlan.model_validate(payload)
|
||||||
|
|
||||||
|
|
@ -297,212 +295,19 @@ def _normalize_optional_date_range(
|
||||||
start_date: str | None,
|
start_date: str | None,
|
||||||
end_date: str | None,
|
end_date: str | None,
|
||||||
) -> tuple[datetime | None, datetime | None]:
|
) -> tuple[datetime | None, datetime | None]:
|
||||||
"""Normalize optional planner dates into a UTC datetime range."""
|
|
||||||
parsed_start = parse_date_or_datetime(start_date) if start_date else None
|
parsed_start = parse_date_or_datetime(start_date) if start_date else None
|
||||||
parsed_end = parse_date_or_datetime(end_date) if end_date else None
|
parsed_end = parse_date_or_datetime(end_date) if end_date else None
|
||||||
|
|
||||||
if parsed_start is None and parsed_end is None:
|
if parsed_start is None and parsed_end is None:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
resolved_start, resolved_end = resolve_date_range(parsed_start, parsed_end)
|
return resolve_date_range(parsed_start, parsed_end)
|
||||||
return resolved_start, resolved_end
|
|
||||||
|
|
||||||
|
|
||||||
def _build_document_xml(
|
|
||||||
document: dict[str, Any],
|
|
||||||
matched_chunk_ids: set[int] | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Build citation-friendly XML with a ``<chunk_index>`` for smart seeking.
|
|
||||||
|
|
||||||
The ``<chunk_index>`` at the top of each document lists every chunk with its
|
|
||||||
line range inside ``<document_content>`` and flags chunks that directly
|
|
||||||
matched the search query (``matched="true"``). This lets the LLM jump
|
|
||||||
straight to the most relevant section via ``read_file(offset=…, limit=…)``
|
|
||||||
instead of reading sequentially from the start.
|
|
||||||
"""
|
|
||||||
matched = matched_chunk_ids or set()
|
|
||||||
|
|
||||||
doc_meta = document.get("document") or {}
|
|
||||||
metadata = (doc_meta.get("metadata") or {}) if isinstance(doc_meta, dict) else {}
|
|
||||||
document_id = doc_meta.get("id", document.get("document_id", "unknown"))
|
|
||||||
document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN"))
|
|
||||||
title = doc_meta.get("title") or metadata.get("title") or "Untitled Document"
|
|
||||||
url = (
|
|
||||||
metadata.get("url") or metadata.get("source") or metadata.get("page_url") or ""
|
|
||||||
)
|
|
||||||
metadata_json = json.dumps(metadata, ensure_ascii=False)
|
|
||||||
|
|
||||||
# --- 1. Metadata header (fixed structure) ---
|
|
||||||
metadata_lines: list[str] = [
|
|
||||||
"<document>",
|
|
||||||
"<document_metadata>",
|
|
||||||
f" <document_id>{document_id}</document_id>",
|
|
||||||
f" <document_type>{document_type}</document_type>",
|
|
||||||
f" <title><![CDATA[{title}]]></title>",
|
|
||||||
f" <url><![CDATA[{url}]]></url>",
|
|
||||||
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
|
|
||||||
"</document_metadata>",
|
|
||||||
"",
|
|
||||||
]
|
|
||||||
|
|
||||||
# --- 2. Pre-build chunk XML strings to compute line counts ---
|
|
||||||
chunks = document.get("chunks") or []
|
|
||||||
chunk_entries: list[tuple[int | None, str]] = [] # (chunk_id, xml_string)
|
|
||||||
if isinstance(chunks, list):
|
|
||||||
for chunk in chunks:
|
|
||||||
if not isinstance(chunk, dict):
|
|
||||||
continue
|
|
||||||
chunk_id = chunk.get("chunk_id") or chunk.get("id")
|
|
||||||
chunk_content = str(chunk.get("content", "")).strip()
|
|
||||||
if not chunk_content:
|
|
||||||
continue
|
|
||||||
if chunk_id is None:
|
|
||||||
xml = f" <chunk><![CDATA[{chunk_content}]]></chunk>"
|
|
||||||
else:
|
|
||||||
xml = f" <chunk id='{chunk_id}'><![CDATA[{chunk_content}]]></chunk>"
|
|
||||||
chunk_entries.append((chunk_id, xml))
|
|
||||||
|
|
||||||
# --- 3. Compute line numbers for every chunk ---
|
|
||||||
# Layout (1-indexed lines for read_file):
|
|
||||||
# metadata_lines -> len(metadata_lines) lines
|
|
||||||
# <chunk_index> -> 1 line
|
|
||||||
# index entries -> len(chunk_entries) lines
|
|
||||||
# </chunk_index> -> 1 line
|
|
||||||
# (empty line) -> 1 line
|
|
||||||
# <document_content> -> 1 line
|
|
||||||
# chunk xml lines…
|
|
||||||
# </document_content> -> 1 line
|
|
||||||
# </document> -> 1 line
|
|
||||||
index_overhead = (
|
|
||||||
1 + len(chunk_entries) + 1 + 1 + 1
|
|
||||||
) # tags + empty + <document_content>
|
|
||||||
first_chunk_line = len(metadata_lines) + index_overhead + 1 # 1-indexed
|
|
||||||
|
|
||||||
current_line = first_chunk_line
|
|
||||||
index_entry_lines: list[str] = []
|
|
||||||
for cid, xml_str in chunk_entries:
|
|
||||||
num_lines = xml_str.count("\n") + 1
|
|
||||||
end_line = current_line + num_lines - 1
|
|
||||||
matched_attr = ' matched="true"' if cid is not None and cid in matched else ""
|
|
||||||
if cid is not None:
|
|
||||||
index_entry_lines.append(
|
|
||||||
f' <entry chunk_id="{cid}" lines="{current_line}-{end_line}"{matched_attr}/>'
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
index_entry_lines.append(
|
|
||||||
f' <entry lines="{current_line}-{end_line}"{matched_attr}/>'
|
|
||||||
)
|
|
||||||
current_line = end_line + 1
|
|
||||||
|
|
||||||
# --- 4. Assemble final XML ---
|
|
||||||
lines = metadata_lines.copy()
|
|
||||||
lines.append("<chunk_index>")
|
|
||||||
lines.extend(index_entry_lines)
|
|
||||||
lines.append("</chunk_index>")
|
|
||||||
lines.append("")
|
|
||||||
lines.append("<document_content>")
|
|
||||||
for _, xml_str in chunk_entries:
|
|
||||||
lines.append(xml_str)
|
|
||||||
lines.extend(["</document_content>", "</document>"])
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_folder_paths(
|
|
||||||
session: AsyncSession, search_space_id: int
|
|
||||||
) -> dict[int, str]:
|
|
||||||
"""Return a map of folder_id -> virtual folder path under /documents."""
|
|
||||||
result = await session.execute(
|
|
||||||
select(Folder.id, Folder.name, Folder.parent_id).where(
|
|
||||||
Folder.search_space_id == search_space_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
rows = result.all()
|
|
||||||
by_id = {row.id: {"name": row.name, "parent_id": row.parent_id} for row in rows}
|
|
||||||
|
|
||||||
cache: dict[int, str] = {}
|
|
||||||
|
|
||||||
def resolve_path(folder_id: int) -> str:
|
|
||||||
if folder_id in cache:
|
|
||||||
return cache[folder_id]
|
|
||||||
parts: list[str] = []
|
|
||||||
cursor: int | None = folder_id
|
|
||||||
visited: set[int] = set()
|
|
||||||
while cursor is not None and cursor in by_id and cursor not in visited:
|
|
||||||
visited.add(cursor)
|
|
||||||
entry = by_id[cursor]
|
|
||||||
parts.append(
|
|
||||||
_safe_filename(str(entry["name"]), fallback="folder").removesuffix(
|
|
||||||
".xml"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cursor = entry["parent_id"]
|
|
||||||
parts.reverse()
|
|
||||||
path = "/documents/" + "/".join(parts) if parts else "/documents"
|
|
||||||
cache[folder_id] = path
|
|
||||||
return path
|
|
||||||
|
|
||||||
for folder_id in by_id:
|
|
||||||
resolve_path(folder_id)
|
|
||||||
return cache
|
|
||||||
|
|
||||||
|
|
||||||
def _build_synthetic_ls(
|
|
||||||
existing_files: dict[str, Any] | None,
|
|
||||||
new_files: dict[str, Any],
|
|
||||||
*,
|
|
||||||
mentioned_paths: set[str] | None = None,
|
|
||||||
) -> tuple[AIMessage, ToolMessage]:
|
|
||||||
"""Build a synthetic ls("/documents") tool-call + result for the LLM context.
|
|
||||||
|
|
||||||
Mentioned files are listed first. A separate header tells the LLM which
|
|
||||||
files the user explicitly selected; the path list itself stays clean so
|
|
||||||
paths can be passed directly to ``read_file`` without stripping tags.
|
|
||||||
"""
|
|
||||||
_mentioned = mentioned_paths or set()
|
|
||||||
merged: dict[str, Any] = {**(existing_files or {}), **new_files}
|
|
||||||
doc_paths = [
|
|
||||||
p for p, v in merged.items() if p.startswith("/documents/") and v is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
new_set = set(new_files)
|
|
||||||
mentioned_list = [p for p in doc_paths if p in _mentioned]
|
|
||||||
new_non_mentioned = [p for p in doc_paths if p in new_set and p not in _mentioned]
|
|
||||||
old_paths = [p for p in doc_paths if p not in new_set]
|
|
||||||
ordered = mentioned_list + new_non_mentioned + old_paths
|
|
||||||
|
|
||||||
parts: list[str] = []
|
|
||||||
if mentioned_list:
|
|
||||||
parts.append(
|
|
||||||
"USER-MENTIONED documents (read these thoroughly before answering):"
|
|
||||||
)
|
|
||||||
for p in mentioned_list:
|
|
||||||
parts.append(f" {p}")
|
|
||||||
parts.append("")
|
|
||||||
parts.append(str(ordered) if ordered else "No documents found.")
|
|
||||||
|
|
||||||
tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}"
|
|
||||||
ai_msg = AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[{"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id}],
|
|
||||||
)
|
|
||||||
tool_msg = ToolMessage(
|
|
||||||
content="\n".join(parts),
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
)
|
|
||||||
return ai_msg, tool_msg
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_search_types(
|
def _resolve_search_types(
|
||||||
available_connectors: list[str] | None,
|
available_connectors: list[str] | None,
|
||||||
available_document_types: list[str] | None,
|
available_document_types: list[str] | None,
|
||||||
) -> list[str] | None:
|
) -> list[str] | None:
|
||||||
"""Build a flat list of document-type strings for the chunk retriever.
|
|
||||||
|
|
||||||
Includes legacy equivalents from ``NATIVE_TO_LEGACY_DOCTYPE`` so that
|
|
||||||
old documents indexed under Composio names are still found.
|
|
||||||
|
|
||||||
Returns ``None`` when no filtering is desired (search all types).
|
|
||||||
"""
|
|
||||||
types: set[str] = set()
|
types: set[str] = set()
|
||||||
if available_document_types:
|
if available_document_types:
|
||||||
types.update(available_document_types)
|
types.update(available_document_types)
|
||||||
|
|
@ -530,13 +335,8 @@ async def browse_recent_documents(
|
||||||
start_date: datetime | None = None,
|
start_date: datetime | None = None,
|
||||||
end_date: datetime | None = None,
|
end_date: datetime | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Return documents ordered by recency (newest first), no relevance ranking.
|
"""Return documents ordered by recency (newest first), no relevance ranking."""
|
||||||
|
from sqlalchemy import func
|
||||||
Used when the user's intent is temporal ("latest file", "most recent upload")
|
|
||||||
and hybrid search would produce poor results because the query has no
|
|
||||||
meaningful topical signal.
|
|
||||||
"""
|
|
||||||
from sqlalchemy import func, select
|
|
||||||
|
|
||||||
from app.db import DocumentType
|
from app.db import DocumentType
|
||||||
|
|
||||||
|
|
@ -580,7 +380,6 @@ async def browse_recent_documents(
|
||||||
return []
|
return []
|
||||||
|
|
||||||
doc_ids = [d.id for d in documents]
|
doc_ids = [d.id for d in documents]
|
||||||
|
|
||||||
numbered = (
|
numbered = (
|
||||||
select(
|
select(
|
||||||
Chunk.id.label("chunk_id"),
|
Chunk.id.label("chunk_id"),
|
||||||
|
|
@ -631,6 +430,7 @@ async def browse_recent_documents(
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
|
"folder_id": getattr(doc, "folder_id", None),
|
||||||
},
|
},
|
||||||
"source": (
|
"source": (
|
||||||
doc.document_type.value
|
doc.document_type.value
|
||||||
|
|
@ -639,12 +439,6 @@ async def browse_recent_documents(
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"browse_recent_documents: %d docs returned for space=%d",
|
|
||||||
len(results),
|
|
||||||
search_space_id,
|
|
||||||
)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -658,17 +452,11 @@ async def search_knowledge_base(
|
||||||
start_date: datetime | None = None,
|
start_date: datetime | None = None,
|
||||||
end_date: datetime | None = None,
|
end_date: datetime | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Run a single unified hybrid search against the knowledge base.
|
"""Run a single unified hybrid search against the knowledge base."""
|
||||||
|
|
||||||
Uses one ``ChucksHybridSearchRetriever`` call across all document types
|
|
||||||
instead of fanning out per-connector. This reduces the number of DB
|
|
||||||
queries from ~10 to 2 (one RRF query + one chunk fetch).
|
|
||||||
"""
|
|
||||||
if not query:
|
if not query:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
[embedding] = embed_texts([query])
|
[embedding] = embed_texts([query])
|
||||||
|
|
||||||
doc_types = _resolve_search_types(available_connectors, available_document_types)
|
doc_types = _resolve_search_types(available_connectors, available_document_types)
|
||||||
retriever_top_k = min(top_k * 3, 30)
|
retriever_top_k = min(top_k * 3, 30)
|
||||||
|
|
||||||
|
|
@ -692,14 +480,7 @@ async def fetch_mentioned_documents(
|
||||||
document_ids: list[int],
|
document_ids: list[int],
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Fetch explicitly mentioned documents with *all* their chunks.
|
"""Fetch explicitly mentioned documents."""
|
||||||
|
|
||||||
Returns the same dict structure as ``search_knowledge_base`` so results
|
|
||||||
can be merged directly into ``build_scoped_filesystem``. Unlike search
|
|
||||||
results, every chunk is included (no top-K limiting) and none are marked
|
|
||||||
as ``matched`` since the entire document is relevant by virtue of the
|
|
||||||
user's explicit mention.
|
|
||||||
"""
|
|
||||||
if not document_ids:
|
if not document_ids:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -749,6 +530,7 @@ async def fetch_mentioned_documents(
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
|
"folder_id": getattr(doc, "folder_id", None),
|
||||||
},
|
},
|
||||||
"source": (
|
"source": (
|
||||||
doc.document_type.value
|
doc.document_type.value
|
||||||
|
|
@ -761,115 +543,100 @@ async def fetch_mentioned_documents(
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def build_scoped_filesystem(
|
def _render_priority_message(priority: list[dict[str, Any]]) -> SystemMessage:
|
||||||
*,
|
"""Render the priority list as a single ``<priority_documents>`` system message."""
|
||||||
documents: Sequence[dict[str, Any]],
|
if not priority:
|
||||||
search_space_id: int,
|
body = "(no priority documents for this turn)"
|
||||||
) -> tuple[dict[str, dict[str, str]], dict[int, str]]:
|
else:
|
||||||
"""Build a StateBackend-compatible files dict from search results.
|
lines: list[str] = []
|
||||||
|
for entry in priority:
|
||||||
Returns ``(files, doc_id_to_path)`` so callers can reliably map a
|
score = entry.get("score")
|
||||||
document id back to its filesystem path without guessing by title.
|
mentioned = entry.get("mentioned")
|
||||||
Paths are collision-proof: when two documents resolve to the same
|
score_str = f"{score:.3f}" if isinstance(score, int | float) else "n/a"
|
||||||
path the doc-id is appended to disambiguate.
|
mark = " [USER-MENTIONED]" if mentioned else ""
|
||||||
"""
|
lines.append(f"- {entry.get('path', '')} (score={score_str}){mark}")
|
||||||
async with shielded_async_session() as session:
|
body = "\n".join(lines)
|
||||||
folder_paths = await _get_folder_paths(session, search_space_id)
|
return SystemMessage(
|
||||||
doc_ids = [
|
content=(
|
||||||
(doc.get("document") or {}).get("id")
|
"<priority_documents>\n"
|
||||||
for doc in documents
|
"These documents are most relevant to the latest user message; "
|
||||||
if isinstance(doc, dict)
|
"read them first. Matched sections are flagged inside each "
|
||||||
]
|
"document's <chunk_index>.\n"
|
||||||
doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)]
|
f"{body}\n"
|
||||||
folder_by_doc_id: dict[int, int | None] = {}
|
"</priority_documents>"
|
||||||
if doc_ids:
|
)
|
||||||
doc_rows = await session.execute(
|
)
|
||||||
select(Document.id, Document.folder_id).where(
|
|
||||||
Document.search_space_id == search_space_id,
|
|
||||||
Document.id.in_(doc_ids),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
folder_by_doc_id = {
|
|
||||||
row.id: row.folder_id for row in doc_rows.all() if row.id is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
files: dict[str, dict[str, str]] = {}
|
|
||||||
doc_id_to_path: dict[int, str] = {}
|
|
||||||
for document in documents:
|
|
||||||
doc_meta = document.get("document") or {}
|
|
||||||
title = str(doc_meta.get("title") or "untitled")
|
|
||||||
doc_id = doc_meta.get("id")
|
|
||||||
folder_id = folder_by_doc_id.get(doc_id) if isinstance(doc_id, int) else None
|
|
||||||
base_folder = folder_paths.get(folder_id, "/documents")
|
|
||||||
file_name = _safe_filename(title)
|
|
||||||
path = f"{base_folder}/{file_name}"
|
|
||||||
if path in files:
|
|
||||||
stem = file_name.removesuffix(".xml")
|
|
||||||
path = f"{base_folder}/{stem} ({doc_id}).xml"
|
|
||||||
matched_ids = set(document.get("matched_chunk_ids") or [])
|
|
||||||
xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids)
|
|
||||||
files[path] = {
|
|
||||||
"content": xml_content.split("\n"),
|
|
||||||
"encoding": "utf-8",
|
|
||||||
"created_at": "",
|
|
||||||
"modified_at": "",
|
|
||||||
}
|
|
||||||
if isinstance(doc_id, int):
|
|
||||||
doc_id_to_path[doc_id] = path
|
|
||||||
return files, doc_id_to_path
|
|
||||||
|
|
||||||
|
|
||||||
def _build_anon_scoped_filesystem(
|
class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
documents: Sequence[dict[str, Any]],
|
"""Compute hybrid-search priority hints for the current turn."""
|
||||||
) -> dict[str, dict[str, str]]:
|
|
||||||
"""Build a scoped filesystem for anonymous documents without DB queries.
|
|
||||||
|
|
||||||
Anonymous uploads have no folders, so all files go under /documents.
|
|
||||||
"""
|
|
||||||
files: dict[str, dict[str, str]] = {}
|
|
||||||
for document in documents:
|
|
||||||
doc_meta = document.get("document") or {}
|
|
||||||
title = str(doc_meta.get("title") or "untitled")
|
|
||||||
file_name = _safe_filename(title)
|
|
||||||
path = f"/documents/{file_name}"
|
|
||||||
if path in files:
|
|
||||||
doc_id = doc_meta.get("id", "dup")
|
|
||||||
stem = file_name.removesuffix(".xml")
|
|
||||||
path = f"/documents/{stem} ({doc_id}).xml"
|
|
||||||
matched_ids = set(document.get("matched_chunk_ids") or [])
|
|
||||||
xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids)
|
|
||||||
files[path] = {
|
|
||||||
"content": xml_content.split("\n"),
|
|
||||||
"encoding": "utf-8",
|
|
||||||
"created_at": "",
|
|
||||||
"modified_at": "",
|
|
||||||
}
|
|
||||||
return files
|
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|
||||||
"""Pre-agent middleware that always searches the KB and seeds a scoped filesystem."""
|
|
||||||
|
|
||||||
tools = ()
|
tools = ()
|
||||||
|
state_schema = SurfSenseFilesystemState
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
llm: BaseChatModel | None = None,
|
llm: BaseChatModel | None = None,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
|
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||||
available_connectors: list[str] | None = None,
|
available_connectors: list[str] | None = None,
|
||||||
available_document_types: list[str] | None = None,
|
available_document_types: list[str] | None = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
mentioned_document_ids: list[int] | None = None,
|
mentioned_document_ids: list[int] | None = None,
|
||||||
anon_session_id: str | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
self.search_space_id = search_space_id
|
self.search_space_id = search_space_id
|
||||||
|
self.filesystem_mode = filesystem_mode
|
||||||
self.available_connectors = available_connectors
|
self.available_connectors = available_connectors
|
||||||
self.available_document_types = available_document_types
|
self.available_document_types = available_document_types
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.mentioned_document_ids = mentioned_document_ids or []
|
self.mentioned_document_ids = mentioned_document_ids or []
|
||||||
self.anon_session_id = anon_session_id
|
# Build the kb-planner private Runnable ONCE here so we don't pay
|
||||||
|
# the ``create_agent`` compile cost (50-200ms) on every turn.
|
||||||
|
# Disabled by default behind ``enable_kb_planner_runnable``; when
|
||||||
|
# off the planner falls back to the legacy ``self.llm.ainvoke``
|
||||||
|
# path.
|
||||||
|
self._planner: Runnable | None = None
|
||||||
|
self._planner_compile_failed = False
|
||||||
|
|
||||||
|
def _build_kb_planner_runnable(self) -> Runnable | None:
|
||||||
|
"""Compile the kb-planner private :class:`Runnable` once.
|
||||||
|
|
||||||
|
Returns ``None`` when the feature flag is disabled, when the LLM is
|
||||||
|
unavailable, or when ``create_agent`` raises (we fall back to the
|
||||||
|
legacy ``self.llm.ainvoke`` path in that case). Compilation happens
|
||||||
|
lazily on first call, then memoized via ``self._planner``.
|
||||||
|
|
||||||
|
The compiled agent is constructed without tools — the planner's
|
||||||
|
contract is "answer with structured JSON" — but it inherits the
|
||||||
|
:class:`RetryAfterMiddleware` so transient rate-limit errors
|
||||||
|
from the planner LLM call don't fail the whole turn.
|
||||||
|
"""
|
||||||
|
if self._planner is not None or self._planner_compile_failed:
|
||||||
|
return self._planner
|
||||||
|
if self.llm is None:
|
||||||
|
return None
|
||||||
|
flags = get_flags()
|
||||||
|
if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack:
|
||||||
|
return None
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._planner = create_agent(
|
||||||
|
self.llm,
|
||||||
|
tools=[],
|
||||||
|
middleware=[RetryAfterMiddleware(max_retries=2)],
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive
|
||||||
|
logger.warning(
|
||||||
|
"kb-planner Runnable compile failed; falling back to llm.ainvoke: %s",
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
self._planner_compile_failed = True
|
||||||
|
self._planner = None
|
||||||
|
return self._planner
|
||||||
|
|
||||||
async def _plan_search_inputs(
|
async def _plan_search_inputs(
|
||||||
self,
|
self,
|
||||||
|
|
@ -877,10 +644,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
messages: Sequence[BaseMessage],
|
messages: Sequence[BaseMessage],
|
||||||
user_text: str,
|
user_text: str,
|
||||||
) -> tuple[str, datetime | None, datetime | None, bool]:
|
) -> tuple[str, datetime | None, datetime | None, bool]:
|
||||||
"""Rewrite the KB query and infer optional date filters with the LLM.
|
|
||||||
|
|
||||||
Returns (optimized_query, start_date, end_date, is_recency_query).
|
|
||||||
"""
|
|
||||||
if self.llm is None:
|
if self.llm is None:
|
||||||
return user_text, None, None, False
|
return user_text, None, None, False
|
||||||
|
|
||||||
|
|
@ -896,11 +659,32 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
t0 = loop.time()
|
t0 = loop.time()
|
||||||
|
|
||||||
|
# Prefer the compiled-once planner Runnable when enabled; otherwise
|
||||||
|
# fall back to ``self.llm.ainvoke``. The ``surfsense:internal`` tag
|
||||||
|
# is preserved on both paths so ``_stream_agent_events`` still
|
||||||
|
# suppresses the planner's intermediate events from the UI.
|
||||||
|
planner = self._build_kb_planner_runnable()
|
||||||
try:
|
try:
|
||||||
response = await self.llm.ainvoke(
|
if planner is not None:
|
||||||
[HumanMessage(content=prompt)],
|
planner_state = await planner.ainvoke(
|
||||||
config={"tags": ["surfsense:internal"]},
|
{"messages": [HumanMessage(content=prompt)]},
|
||||||
)
|
config={"tags": ["surfsense:internal"]},
|
||||||
|
)
|
||||||
|
response_messages = (
|
||||||
|
planner_state.get("messages", [])
|
||||||
|
if isinstance(planner_state, dict)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
response = (
|
||||||
|
response_messages[-1]
|
||||||
|
if response_messages
|
||||||
|
else AIMessage(content="")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = await self.llm.ainvoke(
|
||||||
|
[HumanMessage(content=prompt)],
|
||||||
|
config={"tags": ["surfsense:internal"]},
|
||||||
|
)
|
||||||
plan = _parse_kb_search_plan_response(_extract_text_from_message(response))
|
plan = _parse_kb_search_plan_response(_extract_text_from_message(response))
|
||||||
optimized_query = (
|
optimized_query = (
|
||||||
re.sub(r"\s+", " ", plan.optimized_query).strip() or user_text
|
re.sub(r"\s+", " ", plan.optimized_query).strip() or user_text
|
||||||
|
|
@ -911,7 +695,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
)
|
)
|
||||||
is_recency = plan.is_recency_query
|
is_recency = plan.is_recency_query
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[kb_fs_middleware] planner in %.3fs query=%r optimized=%r "
|
"[kb_priority] planner in %.3fs query=%r optimized=%r "
|
||||||
"start=%s end=%s recency=%s",
|
"start=%s end=%s recency=%s",
|
||||||
loop.time() - t0,
|
loop.time() - t0,
|
||||||
user_text[:80],
|
user_text[:80],
|
||||||
|
|
@ -943,103 +727,68 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
pass
|
pass
|
||||||
return asyncio.run(self.abefore_agent(state, runtime))
|
return asyncio.run(self.abefore_agent(state, runtime))
|
||||||
|
|
||||||
async def _load_anon_document(self) -> dict[str, Any] | None:
|
|
||||||
"""Load the anonymous user's uploaded document from Redis."""
|
|
||||||
if not self.anon_session_id:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
import redis.asyncio as aioredis
|
|
||||||
|
|
||||||
from app.config import config
|
|
||||||
|
|
||||||
redis_client = aioredis.from_url(
|
|
||||||
config.REDIS_APP_URL, decode_responses=True
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
redis_key = f"anon:doc:{self.anon_session_id}"
|
|
||||||
data = await redis_client.get(redis_key)
|
|
||||||
if not data:
|
|
||||||
return None
|
|
||||||
doc = json.loads(data)
|
|
||||||
return {
|
|
||||||
"document_id": -1,
|
|
||||||
"content": doc.get("content", ""),
|
|
||||||
"score": 1.0,
|
|
||||||
"chunks": [
|
|
||||||
{
|
|
||||||
"chunk_id": -1,
|
|
||||||
"content": doc.get("content", ""),
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"matched_chunk_ids": [-1],
|
|
||||||
"document": {
|
|
||||||
"id": -1,
|
|
||||||
"title": doc.get("filename", "uploaded_document"),
|
|
||||||
"document_type": "FILE",
|
|
||||||
"metadata": {"source": "anonymous_upload"},
|
|
||||||
},
|
|
||||||
"source": "FILE",
|
|
||||||
"_user_mentioned": True,
|
|
||||||
}
|
|
||||||
finally:
|
|
||||||
await redis_client.aclose()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("Failed to load anonymous document from Redis: %s", exc)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def abefore_agent( # type: ignore[override]
|
async def abefore_agent( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
state: AgentState,
|
state: AgentState,
|
||||||
runtime: Runtime[Any],
|
runtime: Runtime[Any],
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
del runtime
|
del runtime
|
||||||
|
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||||
|
return None
|
||||||
|
|
||||||
messages = state.get("messages") or []
|
messages = state.get("messages") or []
|
||||||
if not messages:
|
if not messages:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
last_human = None
|
last_human: HumanMessage | None = None
|
||||||
for msg in reversed(messages):
|
for msg in reversed(messages):
|
||||||
if isinstance(msg, HumanMessage):
|
if isinstance(msg, HumanMessage):
|
||||||
last_human = msg
|
last_human = msg
|
||||||
break
|
break
|
||||||
if last_human is None:
|
if last_human is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
user_text = _extract_text_from_message(last_human).strip()
|
user_text = _extract_text_from_message(last_human).strip()
|
||||||
if not user_text:
|
if not user_text:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
t0 = _perf_log and asyncio.get_event_loop().time()
|
anon_doc = state.get("kb_anon_doc")
|
||||||
existing_files = state.get("files")
|
if anon_doc:
|
||||||
|
return self._anon_priority(state, anon_doc)
|
||||||
|
|
||||||
# --- Anonymous session: load Redis doc and skip DB queries ---
|
return await self._authenticated_priority(state, messages, user_text)
|
||||||
if self.anon_session_id:
|
|
||||||
merged: list[dict[str, Any]] = []
|
|
||||||
anon_doc = await self._load_anon_document()
|
|
||||||
if anon_doc:
|
|
||||||
merged.append(anon_doc)
|
|
||||||
|
|
||||||
if merged:
|
def _anon_priority(
|
||||||
new_files = _build_anon_scoped_filesystem(merged)
|
self,
|
||||||
mentioned_paths = set(new_files.keys())
|
state: AgentState,
|
||||||
else:
|
anon_doc: dict[str, Any],
|
||||||
new_files = {}
|
) -> dict[str, Any]:
|
||||||
mentioned_paths = set()
|
path = str(anon_doc.get("path") or "")
|
||||||
|
title = str(anon_doc.get("title") or "uploaded_document")
|
||||||
|
priority = [
|
||||||
|
{
|
||||||
|
"path": path,
|
||||||
|
"score": 1.0,
|
||||||
|
"document_id": None,
|
||||||
|
"title": title,
|
||||||
|
"mentioned": True,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
new_messages = list(state.get("messages") or [])
|
||||||
|
insert_at = max(len(new_messages) - 1, 0)
|
||||||
|
new_messages.insert(insert_at, _render_priority_message(priority))
|
||||||
|
return {
|
||||||
|
"kb_priority": priority,
|
||||||
|
"kb_matched_chunk_ids": {},
|
||||||
|
"messages": new_messages,
|
||||||
|
}
|
||||||
|
|
||||||
ai_msg, tool_msg = _build_synthetic_ls(
|
async def _authenticated_priority(
|
||||||
existing_files,
|
self,
|
||||||
new_files,
|
state: AgentState,
|
||||||
mentioned_paths=mentioned_paths,
|
messages: Sequence[BaseMessage],
|
||||||
)
|
user_text: str,
|
||||||
if t0 is not None:
|
) -> dict[str, Any]:
|
||||||
_perf_log.info(
|
t0 = asyncio.get_event_loop().time()
|
||||||
"[kb_fs_middleware] anon completed in %.3fs new_files=%d",
|
|
||||||
asyncio.get_event_loop().time() - t0,
|
|
||||||
len(new_files),
|
|
||||||
)
|
|
||||||
return {"files": new_files, "messages": [ai_msg, tool_msg]}
|
|
||||||
|
|
||||||
# --- Authenticated session: full KB search ---
|
|
||||||
(
|
(
|
||||||
planned_query,
|
planned_query,
|
||||||
start_date,
|
start_date,
|
||||||
|
|
@ -1050,7 +799,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
user_text=user_text,
|
user_text=user_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- 1. Fetch mentioned documents (user-selected, all chunks) ---
|
|
||||||
mentioned_results: list[dict[str, Any]] = []
|
mentioned_results: list[dict[str, Any]] = []
|
||||||
if self.mentioned_document_ids:
|
if self.mentioned_document_ids:
|
||||||
mentioned_results = await fetch_mentioned_documents(
|
mentioned_results = await fetch_mentioned_documents(
|
||||||
|
|
@ -1059,7 +807,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
)
|
)
|
||||||
self.mentioned_document_ids = []
|
self.mentioned_document_ids = []
|
||||||
|
|
||||||
# --- 2. Run KB search (recency browse or hybrid) ---
|
|
||||||
if is_recency:
|
if is_recency:
|
||||||
doc_types = _resolve_search_types(
|
doc_types = _resolve_search_types(
|
||||||
self.available_connectors, self.available_document_types
|
self.available_connectors, self.available_document_types
|
||||||
|
|
@ -1082,48 +829,108 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- 3. Merge: mentioned first, then search (dedup by doc id) ---
|
|
||||||
seen_doc_ids: set[int] = set()
|
seen_doc_ids: set[int] = set()
|
||||||
merged_auth: list[dict[str, Any]] = []
|
merged: list[dict[str, Any]] = []
|
||||||
for doc in mentioned_results:
|
for doc in mentioned_results:
|
||||||
doc_id = (doc.get("document") or {}).get("id")
|
doc_id = (doc.get("document") or {}).get("id")
|
||||||
if doc_id is not None:
|
if isinstance(doc_id, int):
|
||||||
seen_doc_ids.add(doc_id)
|
seen_doc_ids.add(doc_id)
|
||||||
merged_auth.append(doc)
|
merged.append(doc)
|
||||||
for doc in search_results:
|
for doc in search_results:
|
||||||
doc_id = (doc.get("document") or {}).get("id")
|
doc_id = (doc.get("document") or {}).get("id")
|
||||||
if doc_id is not None and doc_id in seen_doc_ids:
|
if isinstance(doc_id, int) and doc_id in seen_doc_ids:
|
||||||
continue
|
continue
|
||||||
merged_auth.append(doc)
|
merged.append(doc)
|
||||||
|
|
||||||
# --- 4. Build scoped filesystem ---
|
priority, matched_chunk_ids = await self._materialize_priority(merged)
|
||||||
new_files, doc_id_to_path = await build_scoped_filesystem(
|
|
||||||
documents=merged_auth,
|
new_messages = list(messages)
|
||||||
search_space_id=self.search_space_id,
|
insert_at = max(len(new_messages) - 1, 0)
|
||||||
|
new_messages.insert(insert_at, _render_priority_message(priority))
|
||||||
|
|
||||||
|
_perf_log.info(
|
||||||
|
"[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d",
|
||||||
|
asyncio.get_event_loop().time() - t0,
|
||||||
|
user_text[:80],
|
||||||
|
len(priority),
|
||||||
|
len(mentioned_results),
|
||||||
)
|
)
|
||||||
|
|
||||||
mentioned_doc_ids = {
|
return {
|
||||||
(d.get("document") or {}).get("id") for d in mentioned_results
|
"kb_priority": priority,
|
||||||
}
|
"kb_matched_chunk_ids": matched_chunk_ids,
|
||||||
mentioned_paths = {
|
"messages": new_messages,
|
||||||
doc_id_to_path[did] for did in mentioned_doc_ids if did in doc_id_to_path
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ai_msg, tool_msg = _build_synthetic_ls(
|
async def _materialize_priority(
|
||||||
existing_files,
|
self, merged: list[dict[str, Any]]
|
||||||
new_files,
|
) -> tuple[list[dict[str, Any]], dict[int, list[int]]]:
|
||||||
mentioned_paths=mentioned_paths,
|
"""Resolve canonical paths and matched chunk ids for the priority list."""
|
||||||
)
|
priority: list[dict[str, Any]] = []
|
||||||
|
matched_chunk_ids: dict[int, list[int]] = {}
|
||||||
|
|
||||||
if t0 is not None:
|
if not merged:
|
||||||
_perf_log.info(
|
return priority, matched_chunk_ids
|
||||||
"[kb_fs_middleware] completed in %.3fs query=%r optimized=%r "
|
|
||||||
"mentioned=%d new_files=%d total=%d",
|
async with shielded_async_session() as session:
|
||||||
asyncio.get_event_loop().time() - t0,
|
index: PathIndex = await build_path_index(session, self.search_space_id)
|
||||||
user_text[:80],
|
doc_ids = [
|
||||||
planned_query[:120],
|
(doc.get("document") or {}).get("id")
|
||||||
len(mentioned_results),
|
for doc in merged
|
||||||
len(new_files),
|
if isinstance(doc, dict)
|
||||||
len(new_files) + len(existing_files or {}),
|
]
|
||||||
|
doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)]
|
||||||
|
folder_by_doc_id: dict[int, int | None] = {}
|
||||||
|
if doc_ids:
|
||||||
|
folder_rows = await session.execute(
|
||||||
|
select(Document.id, Document.folder_id).where(
|
||||||
|
Document.search_space_id == self.search_space_id,
|
||||||
|
Document.id.in_(doc_ids),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
folder_by_doc_id = {row.id: row.folder_id for row in folder_rows.all()}
|
||||||
|
|
||||||
|
for doc in merged:
|
||||||
|
doc_meta = doc.get("document") or {}
|
||||||
|
doc_id = doc_meta.get("id")
|
||||||
|
title = doc_meta.get("title") or "untitled"
|
||||||
|
folder_id = (
|
||||||
|
folder_by_doc_id.get(doc_id)
|
||||||
|
if isinstance(doc_id, int)
|
||||||
|
else doc_meta.get("folder_id")
|
||||||
)
|
)
|
||||||
return {"files": new_files, "messages": [ai_msg, tool_msg]}
|
path = doc_to_virtual_path(
|
||||||
|
doc_id=doc_id if isinstance(doc_id, int) else None,
|
||||||
|
title=str(title),
|
||||||
|
folder_id=folder_id if isinstance(folder_id, int) else None,
|
||||||
|
index=index,
|
||||||
|
)
|
||||||
|
priority.append(
|
||||||
|
{
|
||||||
|
"path": path,
|
||||||
|
"score": float(doc.get("score") or 0.0),
|
||||||
|
"document_id": doc_id if isinstance(doc_id, int) else None,
|
||||||
|
"title": str(title),
|
||||||
|
"mentioned": bool(doc.get("_user_mentioned")),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if isinstance(doc_id, int):
|
||||||
|
chunk_ids = doc.get("matched_chunk_ids") or []
|
||||||
|
if chunk_ids:
|
||||||
|
matched_chunk_ids[doc_id] = [
|
||||||
|
int(cid) for cid in chunk_ids if isinstance(cid, int | str)
|
||||||
|
]
|
||||||
|
return priority, matched_chunk_ids
|
||||||
|
|
||||||
|
|
||||||
|
# Backwards-compatible alias for any external imports.
|
||||||
|
KnowledgeBaseSearchMiddleware = KnowledgePriorityMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"KnowledgeBaseSearchMiddleware",
|
||||||
|
"KnowledgePriorityMiddleware",
|
||||||
|
"browse_recent_documents",
|
||||||
|
"fetch_mentioned_documents",
|
||||||
|
"search_knowledge_base",
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,310 @@
|
||||||
|
"""Workspace-tree middleware for the SurfSense agent.
|
||||||
|
|
||||||
|
Renders the full ``Folder``+``Document`` tree under ``/documents/`` once per
|
||||||
|
turn (cloud only), caches it by ``(search_space_id, tree_version)``, and
|
||||||
|
injects the result as a ``<workspace_tree>`` system message immediately
|
||||||
|
before the latest human turn.
|
||||||
|
|
||||||
|
The render is bounded by two truncation layers:
|
||||||
|
|
||||||
|
1. **Entry cap** — at most ``MAX_TREE_ENTRIES`` lines. The remainder is
|
||||||
|
replaced with a "use ls" hint.
|
||||||
|
2. **Token cap** — at most ``MAX_TREE_TOKENS`` tokens (using the LLM's
|
||||||
|
token-count profile when available). If the entry-truncated tree still
|
||||||
|
exceeds the token cap we fall back to a root-only summary.
|
||||||
|
|
||||||
|
Anonymous mode renders only ``state['kb_anon_doc']`` (no DB calls).
|
||||||
|
|
||||||
|
This middleware also performs a one-time initialization of ``state['cwd']``
|
||||||
|
to ``"/documents"`` so subsequent middlewares and tools always see a valid
|
||||||
|
cwd in cloud mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.messages import SystemMessage
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||||
|
from app.agents.new_chat.path_resolver import (
|
||||||
|
DOCUMENTS_ROOT,
|
||||||
|
PathIndex,
|
||||||
|
build_path_index,
|
||||||
|
doc_to_virtual_path,
|
||||||
|
)
|
||||||
|
from app.db import Document, shielded_async_session
|
||||||
|
|
||||||
|
try:
|
||||||
|
from litellm import token_counter
|
||||||
|
except Exception: # pragma: no cover - optional dep
|
||||||
|
token_counter = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
MAX_TREE_ENTRIES = 500
|
||||||
|
MAX_TREE_TOKENS = 4000
|
||||||
|
|
||||||
|
|
||||||
|
def _approx_tokens(text: str) -> int:
|
||||||
|
"""Cheap fallback token estimate (1 token ~= 4 chars)."""
|
||||||
|
return max(1, (len(text) + 3) // 4)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_tokens(text: str, *, llm: BaseChatModel | None) -> int:
    """Count the tokens in ``text``, preferring the most model-specific counter.

    Strategies are tried in this order, each falling through on failure:

    1. No ``llm``: the character-based estimate from ``_approx_tokens``.
    2. A callable ``llm._count_tokens``, invoked with a single user message.
    3. ``token_counter`` with a model name taken from the first entry of
       ``llm.profile["token_count_models"]``, then
       ``llm.profile["token_count_model"]``, then ``llm.model``.
    4. The character-based estimate again.
    """
    if llm is None:
        return _approx_tokens(text)

    native_counter = getattr(llm, "_count_tokens", None)
    if callable(native_counter):
        try:
            return int(native_counter([{"role": "user", "content": text}]))
        except Exception:
            # Best effort: fall through to the profile/model based counter.
            pass

    candidates: list[str] = []
    profile = getattr(llm, "profile", None)
    if isinstance(profile, dict):
        listed = profile.get("token_count_models")
        if isinstance(listed, list):
            for name in listed:
                if isinstance(name, str) and name:
                    candidates.append(name)
        single = profile.get("token_count_model")
        if isinstance(single, str) and single and single not in candidates:
            candidates.append(single)

    chosen = candidates[0] if candidates else getattr(llm, "model", None)
    have_model = isinstance(chosen, str) and bool(chosen)
    if not have_model or token_counter is None:
        return _approx_tokens(text)

    try:
        counted = token_counter(
            messages=[{"role": "user", "content": text}],
            model=chosen,
        )
        return int(counted)
    except Exception:
        return _approx_tokens(text)
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeTreeMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Inject the workspace folder/document tree into the agent's context.

    In cloud filesystem mode the middleware renders a ``<workspace_tree>``
    block (from the database, or from ``state['kb_anon_doc']`` when present)
    and inserts it as a system message just before the last message. It also
    initializes ``state['cwd']`` to ``DOCUMENTS_ROOT`` when unset.
    """

    tools = ()
    state_schema = SurfSenseFilesystemState

    def __init__(
        self,
        *,
        search_space_id: int,
        filesystem_mode: FilesystemMode,
        llm: BaseChatModel | None = None,
        max_entries: int = MAX_TREE_ENTRIES,
        max_tokens: int = MAX_TREE_TOKENS,
    ) -> None:
        """Store configuration and create an empty render cache.

        Args:
            search_space_id: Search space whose documents are rendered.
            filesystem_mode: Only ``FilesystemMode.CLOUD`` triggers rendering.
            llm: Optional model used for token counting.
            max_entries: Maximum number of tree lines before truncation.
            max_tokens: Token budget above which a root summary is used.
        """
        self.search_space_id = search_space_id
        self.filesystem_mode = filesystem_mode
        self.llm = llm
        self.max_entries = max_entries
        self.max_tokens = max_tokens
        # Keyed by (search_space_id, tree_version, is_anonymous).
        self._cache: dict[tuple[int, int, bool], str] = {}

    async def abefore_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Build the state update carrying the workspace tree message.

        Returns:
            ``None`` outside cloud mode; otherwise a dict with ``messages``
            (the tree system message inserted before the last message) and,
            when ``state['cwd']`` was unset, ``cwd``.
        """
        del runtime
        if self.filesystem_mode != FilesystemMode.CLOUD:
            return None

        update: dict[str, Any] = {}
        if not state.get("cwd"):
            update["cwd"] = DOCUMENTS_ROOT

        anon_doc = state.get("kb_anon_doc")
        if anon_doc:
            tree_msg = self._render_anon_tree(anon_doc)
        else:
            tree_msg = await self._render_kb_tree(state)

        messages = list(state.get("messages") or [])
        # Insert just before the final message (index 0 for an empty list).
        insert_at = max(len(messages) - 1, 0)
        messages.insert(insert_at, SystemMessage(content=tree_msg))
        update["messages"] = messages
        return update

    def before_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Synchronous entry point.

        Returns ``None`` when called from inside a running event loop
        (``asyncio.run`` cannot be nested there); otherwise drives
        ``abefore_agent`` to completion on a fresh loop.
        """
        try:
            # Only succeeds when a loop is already running in this thread.
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(self.abefore_agent(state, runtime))
        return None

    # ------------------------------------------------------------------ render

    def _render_anon_tree(self, anon_doc: dict[str, Any]) -> str:
        """Render the single-document tree used when ``kb_anon_doc`` is set."""
        path = str(anon_doc.get("path") or "")
        title = str(anon_doc.get("title") or "uploaded_document")
        return (
            "<workspace_tree>\n"
            "Anonymous session — only one read-only document is available.\n"
            f"{DOCUMENTS_ROOT}/\n"
            f" {path} — {title}\n"
            "</workspace_tree>"
        )

    async def _render_kb_tree(self, state: AgentState) -> str:
        """Render (or fetch from cache) the tree for this search space.

        The cache key uses ``state['tree_version']`` (0 when missing). A
        database failure yields an "(unavailable)" placeholder that is
        deliberately not cached, so the next call retries.
        """
        version = int(state.get("tree_version") or 0)
        cache_key = (self.search_space_id, version, False)
        cached = self._cache.get(cache_key)
        if cached is not None:
            return cached

        try:
            async with shielded_async_session() as session:
                index = await build_path_index(session, self.search_space_id)
                doc_rows = await session.execute(
                    select(Document.id, Document.title, Document.folder_id).where(
                        Document.search_space_id == self.search_space_id
                    )
                )
                docs = list(doc_rows.all())
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning("knowledge_tree: DB error %s", exc)
            return "<workspace_tree>\n(unavailable)\n</workspace_tree>"

        rendered = self._format_tree(index, docs)
        self._cache[cache_key] = rendered
        return rendered

    def _format_tree(self, index: PathIndex, docs: list[Any]) -> str:
        """Format folders and documents as an indented listing.

        Emits at most ``max_entries`` lines (plus a "more entries" hint) and
        falls back to ``_format_root_summary`` when the rendered text exceeds
        ``max_tokens``.
        """
        folder_paths = sorted(set(index.folder_paths.values()))
        doc_paths = sorted(
            doc_to_virtual_path(
                doc_id=row.id,
                title=str(row.title or "untitled"),
                folder_id=row.folder_id,
                index=index,
            )
            for row in docs
        )
        all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))

        # Folders with no descendant get an explicit "(empty)" marker so the
        # LLM doesn't have to infer emptiness from indentation alone.
        non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths)
        # Set lookup keeps the per-path directory test O(1) instead of
        # scanning the folder list once per rendered line.
        folder_set = set(folder_paths)

        lines: list[str] = []
        for path in all_paths:
            depth = (
                0
                if path == DOCUMENTS_ROOT
                else len([p for p in path[len(DOCUMENTS_ROOT) :].split("/") if p])
            )
            indent = " " * depth
            is_dir = path == DOCUMENTS_ROOT or path in folder_set
            display = (
                path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
            )
            if is_dir:
                if path != DOCUMENTS_ROOT and path not in non_empty_folders:
                    lines.append(f"{indent}{display}/ (empty)")
                else:
                    lines.append(f"{indent}{display}/")
            else:
                lines.append(f"{indent}{display}")
            if len(lines) >= self.max_entries:
                remaining = len(all_paths) - len(lines)
                if remaining > 0:
                    lines.append(
                        f"... {remaining} more entries — use "
                        "ls('/documents/<folder>', offset, limit) to expand"
                    )
                break

        body = "\n".join(lines)
        rendered = f"<workspace_tree>\n{body}\n</workspace_tree>"

        token_count = _count_tokens(rendered, llm=self.llm)
        if token_count <= self.max_tokens:
            return rendered

        return self._format_root_summary(folder_paths, doc_paths)

    @staticmethod
    def _compute_non_empty_folders(
        folder_paths: list[str], doc_paths: list[str]
    ) -> set[str]:
        """Return the folder paths that contain at least one descendant.

        A document marks every known ancestor folder non-empty, walking all
        the way up to ``DOCUMENTS_ROOT``. A folder marks its ancestors
        non-empty too, but that walk stops at the first ancestor that is not
        itself a known folder. Consequently, in a chain of nested folders
        holding no documents only the innermost one is left out of the
        result (and so rendered as ``(empty)``).
        """
        non_empty: set[str] = set()
        folder_set = set(folder_paths)

        for doc_path in doc_paths:
            parent = doc_path.rsplit("/", 1)[0]
            while parent and parent != DOCUMENTS_ROOT:
                if parent in folder_set:
                    non_empty.add(parent)
                parent = parent.rsplit("/", 1)[0]

        for child in folder_paths:
            parent = child.rsplit("/", 1)[0]
            while parent and parent != DOCUMENTS_ROOT and parent in folder_set:
                non_empty.add(parent)
                parent = parent.rsplit("/", 1)[0]

        return non_empty

    def _format_root_summary(
        self, folder_paths: list[str], doc_paths: list[str]
    ) -> str:
        """Render a compact fallback: top-level folders with document counts.

        Documents directly under the root are reported as a single "loose
        document" count; folders without documents appear with a count of 0.
        """
        top_level: dict[str, int] = {}
        loose_docs = 0
        for path in doc_paths:
            rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
            if "/" in rel:
                top = rel.split("/", 1)[0]
                top_level[top] = top_level.get(top, 0) + 1
            else:
                loose_docs += 1
        for path in folder_paths:
            rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
            if not rel:
                continue
            top = rel.split("/", 1)[0]
            top_level.setdefault(top, 0)

        lines = [DOCUMENTS_ROOT + "/"]
        for name in sorted(top_level):
            count = top_level[name]
            lines.append(f" {name}/ ({count} document{'s' if count != 1 else ''})")
        if loose_docs:
            lines.append(
                f" ({loose_docs} loose document{'s' if loose_docs != 1 else ''})"
            )
        lines.append(
            "Tree is large; use list_tree('/documents/<folder>') to drill in "
            "or ls('/documents/<folder>', offset, limit) for paginated listings."
        )
        return "<workspace_tree>\n" + "\n".join(lines) + "\n</workspace_tree>"
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["KnowledgeTreeMiddleware"]
|
||||||
|
|
@ -0,0 +1,613 @@
|
||||||
|
"""Desktop local-folder filesystem backend for deepagents tools."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import fnmatch
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from collections import deque
|
||||||
|
from contextlib import ExitStack
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from deepagents.backends.protocol import (
|
||||||
|
EditResult,
|
||||||
|
FileDownloadResponse,
|
||||||
|
FileInfo,
|
||||||
|
FileUploadResponse,
|
||||||
|
GrepMatch,
|
||||||
|
WriteResult,
|
||||||
|
)
|
||||||
|
from deepagents.backends.utils import (
|
||||||
|
create_file_data,
|
||||||
|
format_read_response,
|
||||||
|
perform_string_replacement,
|
||||||
|
)
|
||||||
|
|
||||||
|
_INVALID_PATH = "invalid_path"
|
||||||
|
_FILE_NOT_FOUND = "file_not_found"
|
||||||
|
_IS_DIRECTORY = "is_directory"
|
||||||
|
|
||||||
|
|
||||||
|
class LocalFolderBackend:
|
||||||
|
"""Filesystem backend rooted to a single local folder."""
|
||||||
|
|
||||||
|
def __init__(self, root_path: str) -> None:
|
||||||
|
root = Path(root_path).expanduser().resolve()
|
||||||
|
if not root.exists() or not root.is_dir():
|
||||||
|
msg = f"Local filesystem root does not exist or is not a directory: {root_path}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
self._root = root
|
||||||
|
self._locks: dict[str, threading.Lock] = {}
|
||||||
|
self._locks_mu = threading.Lock()
|
||||||
|
|
||||||
|
def _lock_for(self, path: str) -> threading.Lock:
|
||||||
|
with self._locks_mu:
|
||||||
|
if path not in self._locks:
|
||||||
|
self._locks[path] = threading.Lock()
|
||||||
|
return self._locks[path]
|
||||||
|
|
||||||
|
def _resolve_virtual(self, virtual_path: str, *, allow_root: bool = False) -> Path:
|
||||||
|
if not virtual_path.startswith("/"):
|
||||||
|
msg = f"Invalid path (must be absolute): {virtual_path}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
rel = virtual_path.lstrip("/")
|
||||||
|
candidate = self._root if rel == "" else (self._root / rel)
|
||||||
|
resolved = candidate.resolve()
|
||||||
|
if not allow_root and resolved == self._root:
|
||||||
|
msg = "Path must refer to a file or child directory under root"
|
||||||
|
raise ValueError(msg)
|
||||||
|
if not resolved.is_relative_to(self._root):
|
||||||
|
msg = f"Path escapes local filesystem root: {virtual_path}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_virtual(path: Path, root: Path) -> str:
|
||||||
|
rel = path.relative_to(root).as_posix()
|
||||||
|
return "/" if rel == "." else f"/{rel}"
|
||||||
|
|
||||||
|
def _write_text_atomic(self, path: Path, content: str) -> None:
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
temp_path = path.with_suffix(f"{path.suffix}.tmp")
|
||||||
|
temp_path.write_text(content, encoding="utf-8")
|
||||||
|
os.replace(temp_path, path)
|
||||||
|
|
||||||
|
def _acquire_path_locks(self, *paths: str) -> ExitStack:
|
||||||
|
ordered_paths = sorted(set(paths))
|
||||||
|
stack = ExitStack()
|
||||||
|
for path in ordered_paths:
|
||||||
|
stack.enter_context(self._lock_for(path))
|
||||||
|
return stack
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _clamp_page_size(page_size: int) -> int:
|
||||||
|
return max(1, min(page_size, 1000))
|
||||||
|
|
||||||
|
def _read_dir_entries(self, directory_path: str) -> list[dict[str, Any]]:
|
||||||
|
directory = Path(directory_path)
|
||||||
|
try:
|
||||||
|
children = sorted(
|
||||||
|
directory.iterdir(),
|
||||||
|
key=lambda p: (not p.is_dir(), p.name.lower()),
|
||||||
|
)
|
||||||
|
except OSError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
entries: list[dict[str, Any]] = []
|
||||||
|
for child in children:
|
||||||
|
try:
|
||||||
|
stat_result = child.stat()
|
||||||
|
except OSError:
|
||||||
|
continue
|
||||||
|
entries.append(
|
||||||
|
{
|
||||||
|
"path": self._to_virtual(child, self._root),
|
||||||
|
"is_dir": child.is_dir(),
|
||||||
|
"size": stat_result.st_size if child.is_file() else 0,
|
||||||
|
"modified_at": str(stat_result.st_mtime),
|
||||||
|
"absolute_path": str(child),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return entries
|
||||||
|
|
||||||
|
def ls_info(self, path: str) -> list[FileInfo]:
    """List the direct children of the directory at virtual ``path``.

    Returns ``[]`` for an invalid path, a missing path, or a non-directory.
    The listing is delegated to ``_read_dir_entries`` so that children which
    cannot be stat'ed (for example dangling symlinks) and unreadable
    directories are handled without raising.

    Returns:
        ``FileInfo`` records (``path``, ``is_dir``, ``size``,
        ``modified_at``), directories first, then by lower-cased name.
    """
    try:
        target = self._resolve_virtual(path, allow_root=True)
    except ValueError:
        return []
    if not target.exists() or not target.is_dir():
        return []
    return [
        FileInfo(
            path=entry["path"],
            is_dir=entry["is_dir"],
            size=entry["size"],
            modified_at=entry["modified_at"],
        )
        for entry in self._read_dir_entries(str(target))
    ]
|
||||||
|
|
||||||
|
async def als_info(self, path: str) -> list[FileInfo]:
    """Async variant of ``ls_info``; runs it in a worker thread."""
    listing = await asyncio.to_thread(self.ls_info, path)
    return listing
|
||||||
|
|
||||||
|
def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
    """Read a file and return it formatted by ``format_read_response``.

    Undecodable bytes are replaced rather than raising. Every failure is
    reported as an ``"Error: ..."`` string instead of an exception,
    including I/O errors raised while reading the file.

    Args:
        file_path: Absolute virtual path of the file.
        offset: Passed through to ``format_read_response``.
        limit: Passed through to ``format_read_response``.
    """
    try:
        path = self._resolve_virtual(file_path)
    except ValueError:
        return f"Error: Invalid path '{file_path}'"
    if not path.exists():
        return f"Error: File '{file_path}' not found"
    if not path.is_file():
        return f"Error: Path '{file_path}' is not a file"
    try:
        content = path.read_text(encoding="utf-8", errors="replace")
    except OSError as exc:
        # e.g. permission denied, or the file vanished after the checks.
        return f"Error: failed to read '{file_path}': {exc}"
    file_data = create_file_data(content)
    return format_read_response(file_data, offset, limit)
|
||||||
|
|
||||||
|
async def aread(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
    """Async variant of ``read``; runs it in a worker thread."""
    call_args = (file_path, offset, limit)
    return await asyncio.to_thread(self.read, *call_args)
|
||||||
|
|
||||||
|
def read_raw(self, file_path: str) -> str:
    """Read raw file text without line-number formatting.

    Undecodable bytes are replaced rather than raising. Every failure is
    reported as an ``"Error: ..."`` string instead of an exception,
    including I/O errors raised while reading the file.
    """
    try:
        path = self._resolve_virtual(file_path)
    except ValueError:
        return f"Error: Invalid path '{file_path}'"
    if not path.exists():
        return f"Error: File '{file_path}' not found"
    if not path.is_file():
        return f"Error: Path '{file_path}' is not a file"
    try:
        return path.read_text(encoding="utf-8", errors="replace")
    except OSError as exc:
        # e.g. permission denied, or the file vanished after the checks.
        return f"Error: failed to read '{file_path}': {exc}"
|
||||||
|
|
||||||
|
async def aread_raw(self, file_path: str) -> str:
    """Async counterpart of ``read_raw``, executed in a worker thread."""
    raw_text = await asyncio.to_thread(self.read_raw, file_path)
    return raw_text
|
||||||
|
|
||||||
|
def write(self, file_path: str, content: str) -> WriteResult:
    """Create a new file at ``file_path`` containing ``content``.

    Refuses to overwrite an existing path and refuses to write when the
    parent directory does not exist. All failures, including I/O errors
    raised while writing, are returned as ``WriteResult(error=...)`` rather
    than raised.

    Returns:
        ``WriteResult(path=file_path, files_update=None)`` on success.
    """
    try:
        path = self._resolve_virtual(file_path)
    except ValueError:
        return WriteResult(error=f"Error: Invalid path '{file_path}'")
    # Hold the per-path lock across the existence check and the write.
    with self._lock_for(file_path):
        if path.exists():
            return WriteResult(
                error=(
                    f"Cannot write to {file_path} because it already exists. "
                    "Read and then make an edit, or write to a new path."
                )
            )
        parent = path.parent
        if not parent.exists() or not parent.is_dir():
            return WriteResult(
                error=(
                    f"Error: parent directory for '{file_path}' does not exist. "
                    "Create the folder first or write to an existing directory."
                )
            )
        try:
            self._write_text_atomic(path, content)
        except OSError as exc:
            return WriteResult(error=f"Error: failed to write '{file_path}': {exc}")
        return WriteResult(path=file_path, files_update=None)
|
||||||
|
|
||||||
|
async def awrite(self, file_path: str, content: str) -> WriteResult:
    """Async variant of ``write``; runs it in a worker thread."""
    outcome = await asyncio.to_thread(self.write, file_path, content)
    return outcome
|
||||||
|
|
||||||
|
def list_tree(
    self,
    path: str = "/",
    *,
    max_depth: int | None = 8,
    page_size: int = 500,
    include_files: bool = True,
    include_dirs: bool = True,
) -> dict[str, Any]:
    """List entries beneath ``path`` breadth-first, up to one page.

    Args:
        path: Virtual path of the directory (or single file) to list.
        max_depth: Deepest level to report, where direct children are depth
            1; ``None`` means unlimited. Negative values are treated as 0.
        page_size: Maximum number of entries, clamped by ``_clamp_page_size``.
        include_files: Whether file entries are reported.
        include_dirs: Whether directory entries are reported (directories
            are still traversed when this is false).

    Returns:
        ``{"entries": [...], "truncated": bool}`` where each entry has
        ``path``, ``is_dir``, ``size``, ``modified_at`` and ``depth``, or
        ``{"error": "..."}`` for an invalid or missing path. ``truncated``
        is true only when at least one further matching entry was left out.
    """
    if not include_files and not include_dirs:
        return {"entries": [], "truncated": False}

    normalized_depth = None if max_depth is None else max(0, int(max_depth))
    page_limit = self._clamp_page_size(int(page_size))
    try:
        start = self._resolve_virtual(path, allow_root=True)
    except ValueError:
        return {"error": f"Error: invalid path '{path}'"}
    if not start.exists():
        return {"error": f"Error: path '{path}' not found"}
    if start.is_file():
        if not include_files:
            return {"entries": [], "truncated": False}
        stat_result = start.stat()
        return {
            "entries": [
                {
                    "path": self._to_virtual(start, self._root),
                    "is_dir": False,
                    "size": stat_result.st_size,
                    "modified_at": str(stat_result.st_mtime),
                    "depth": 0,
                }
            ],
            "truncated": False,
        }

    pending_dirs: deque[tuple[str, int]] = deque([(str(start), 0)])
    entries: list[dict[str, Any]] = []
    truncated = False
    while pending_dirs and not truncated:
        next_dir_path, next_depth = pending_dirs.popleft()
        item_depth = next_depth + 1
        if normalized_depth is not None and item_depth > normalized_depth:
            # Children would be beyond the limit: don't even read the dir.
            continue
        # Only descend into directories whose children can still be reported.
        descend = normalized_depth is None or item_depth < normalized_depth
        for item in self._read_dir_entries(next_dir_path):
            is_dir = item["is_dir"]
            if is_dir and descend:
                pending_dirs.append((item["absolute_path"], item_depth))
            if not (include_dirs if is_dir else include_files):
                continue
            if len(entries) >= page_limit:
                # A further matching entry exists but the page is full.
                truncated = True
                break
            entries.append(
                {
                    "path": item["path"],
                    "is_dir": is_dir,
                    "size": 0 if is_dir else item["size"],
                    "modified_at": item["modified_at"],
                    "depth": item_depth,
                }
            )

    return {"entries": entries, "truncated": truncated}
|
||||||
|
|
||||||
|
async def alist_tree(
    self,
    path: str = "/",
    *,
    max_depth: int | None = 8,
    page_size: int = 500,
    include_files: bool = True,
    include_dirs: bool = True,
) -> dict[str, Any]:
    """Async variant of ``list_tree``; forwards every option to a worker thread."""
    options = {
        "max_depth": max_depth,
        "page_size": page_size,
        "include_files": include_files,
        "include_dirs": include_dirs,
    }
    return await asyncio.to_thread(self.list_tree, path, **options)
|
||||||
|
|
||||||
|
def move(
    self,
    source_path: str,
    destination_path: str,
    overwrite: bool = False,
) -> WriteResult:
    """Move (rename) ``source_path`` to ``destination_path`` under the root.

    Missing parent directories of the destination are created. An existing
    destination is only replaced when ``overwrite`` is true and both source
    and destination are files. All failures, including errors while
    creating the destination's parent directories, are returned as
    ``WriteResult(error=...)`` rather than raised.

    Returns:
        ``WriteResult`` whose ``path`` is the destination's virtual path.
    """
    try:
        source = self._resolve_virtual(source_path)
        destination = self._resolve_virtual(destination_path)
    except ValueError:
        return WriteResult(
            error=(
                f"Error: invalid source '{source_path}' or destination "
                f"'{destination_path}' path"
            )
        )
    if source == destination:
        return WriteResult(error="Error: source and destination paths are the same")
    with self._acquire_path_locks(source_path, destination_path):
        if not source.exists():
            return WriteResult(
                error=f"Error: source path '{source_path}' not found"
            )
        if destination.exists():
            if not overwrite:
                return WriteResult(
                    error=(
                        f"Error: destination path '{destination_path}' already exists. "
                        "Set overwrite=True to replace files."
                    )
                )
            if source.is_dir() or destination.is_dir():
                return WriteResult(
                    error=(
                        "Error: overwrite=True is only supported for file-to-file moves."
                    )
                )
        try:
            # Inside the try: mkdir fails with OSError when, for example, a
            # component of the destination's parent is a regular file.
            destination.parent.mkdir(parents=True, exist_ok=True)
            if overwrite:
                os.replace(source, destination)
            else:
                source.rename(destination)
        except OSError as exc:
            return WriteResult(
                error=f"Error: failed to move '{source_path}': {exc}"
            )
        return WriteResult(
            path=self._to_virtual(destination, self._root), files_update=None
        )
|
||||||
|
|
||||||
|
async def amove(
    self,
    source_path: str,
    destination_path: str,
    overwrite: bool = False,
) -> WriteResult:
    """Async variant of ``move``; runs it in a worker thread."""
    move_args = (source_path, destination_path, overwrite)
    return await asyncio.to_thread(self.move, *move_args)
|
||||||
|
|
||||||
|
def delete_file(self, file_path: str) -> WriteResult:
    """Permanently remove one file located under the root.

    Invalid paths, missing paths and directories are rejected, and an
    ``OSError`` from the unlink is reported, all as
    ``WriteResult(error=...)``. Recursive deletion and glob expansion are
    intentionally not supported.
    """
    try:
        target = self._resolve_virtual(file_path)
    except ValueError:
        return WriteResult(error=f"Error: Invalid path '{file_path}'")

    with self._lock_for(file_path):
        if not target.exists():
            problem = f"Error: File '{file_path}' not found"
        elif target.is_dir():
            problem = (
                f"Error: '{file_path}' is a directory. "
                "Use rmdir for empty directories."
            )
        else:
            problem = None
            try:
                os.unlink(target)
            except OSError as exc:
                problem = f"Error: failed to delete '{file_path}': {exc}"
        if problem is not None:
            return WriteResult(error=problem)
        return WriteResult(path=file_path, files_update=None)
|
||||||
|
|
||||||
|
async def adelete_file(self, file_path: str) -> WriteResult:
    """Async variant of ``delete_file``; the blocking call runs on a worker thread."""
    outcome = await asyncio.to_thread(self.delete_file, file_path)
    return outcome
|
||||||
|
|
||||||
|
def rmdir(self, dir_path: str) -> WriteResult:
    """Hard-delete an empty directory under root.

    Refuses unresolvable paths (``_resolve_virtual`` raising ``ValueError``),
    missing paths, files, and non-empty directories. ``os.rmdir`` is
    naturally empty-only; the emptiness pre-check exists so the error is
    clearer for the agent.

    Args:
        dir_path: Virtual path of the directory to remove.

    Returns:
        ``WriteResult(path=dir_path)`` on success, otherwise a
        ``WriteResult`` whose ``error`` describes the failure.
    """
    try:
        path = self._resolve_virtual(dir_path)
    except ValueError:
        return WriteResult(error=f"Error: Invalid path '{dir_path}'")

    with self._lock_for(dir_path):
        if not path.exists():
            return WriteResult(error=f"Error: Directory '{dir_path}' not found")
        if not path.is_dir():
            return WriteResult(error=f"Error: '{dir_path}' is not a directory")

        # Listing the directory can itself fail (e.g. missing read
        # permission). Report that as an error result, like every other
        # failure path, instead of letting the exception escape.
        try:
            first_entry = next(iter(path.iterdir()), None)
        except OSError as exc:
            return WriteResult(error=f"Error: failed to rmdir '{dir_path}': {exc}")
        if first_entry is not None:
            return WriteResult(
                error=(
                    f"Error: directory '{dir_path}' is not empty. "
                    "Remove its contents first."
                )
            )

        try:
            os.rmdir(path)
        except OSError as exc:
            return WriteResult(error=f"Error: failed to rmdir '{dir_path}': {exc}")
        return WriteResult(path=dir_path, files_update=None)
|
||||||
|
|
||||||
|
async def armdir(self, dir_path: str) -> WriteResult:
    """Async variant of ``rmdir``; the blocking call runs on a worker thread."""
    outcome = await asyncio.to_thread(self.rmdir, dir_path)
    return outcome
|
||||||
|
|
||||||
|
def edit(
    self,
    file_path: str,
    old_string: str,
    new_string: str,
    replace_all: bool = False,
) -> EditResult:
    """Replace ``old_string`` with ``new_string`` inside one file.

    The file is read as UTF-8 (undecodable bytes are replaced), the
    substitution is delegated to ``perform_string_replacement``, and the
    result is written back through ``_write_text_atomic`` while holding the
    per-path lock.

    Args:
        file_path: Virtual path of the file to edit.
        old_string: Text to search for.
        new_string: Replacement text.
        replace_all: Forwarded to ``perform_string_replacement``.

    Returns:
        ``EditResult`` with ``path`` and ``occurrences`` on success, or with
        ``error`` set when the path is invalid, the file is missing, the
        replacement helper reports a problem (it returns a string), or the
        read/write fails with ``OSError``.
    """
    try:
        path = self._resolve_virtual(file_path)
    except ValueError:
        return EditResult(error=f"Error: Invalid path '{file_path}'")

    lock = self._lock_for(file_path)
    with lock:
        if not path.exists() or not path.is_file():
            return EditResult(error=f"Error: File '{file_path}' not found")

        # I/O failures are reported as error results, matching how the
        # sibling mutators (move/delete_file/rmdir) treat OSError.
        try:
            content = path.read_text(encoding="utf-8", errors="replace")
        except OSError as exc:
            return EditResult(error=f"Error: failed to read '{file_path}': {exc}")

        result = perform_string_replacement(
            content, old_string, new_string, replace_all
        )
        # The helper signals failure by returning a message string.
        if isinstance(result, str):
            return EditResult(error=result)
        updated_content, occurrences = result

        try:
            self._write_text_atomic(path, updated_content)
        except OSError as exc:
            return EditResult(error=f"Error: failed to write '{file_path}': {exc}")

    return EditResult(
        path=file_path, files_update=None, occurrences=int(occurrences)
    )
|
||||||
|
|
||||||
|
async def aedit(
    self,
    file_path: str,
    old_string: str,
    new_string: str,
    replace_all: bool = False,
) -> EditResult:
    """Async variant of ``edit``; the blocking call runs on a worker thread."""
    edit_args = (file_path, old_string, new_string, replace_all)
    return await asyncio.to_thread(self.edit, *edit_args)
|
||||||
|
|
||||||
|
def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
    """List entries matching a glob pattern.

    A pattern starting with ``/`` is anchored at the backend root (leading
    slashes are stripped before globbing); any other pattern is evaluated
    relative to ``path``.

    Args:
        pattern: Glob pattern understood by ``pathlib.Path.glob``.
        path: Virtual directory that relative patterns are evaluated from.

    Returns:
        One ``FileInfo`` per match that resolves inside the root. An invalid
        ``path``, an empty pattern (e.g. ``"/"``), or a pattern rejected by
        ``Path.glob`` yields an empty list. Matches that cannot be resolved
        or stat'ed (such as dangling symlinks) are skipped.
    """
    try:
        base = self._resolve_virtual(path, allow_root=True)
    except ValueError:
        return []

    if pattern.startswith("/"):
        search_base = self._root
        normalized_pattern = pattern.lstrip("/")
    else:
        search_base = base
        normalized_pattern = pattern

    # Path.glob("") raises ValueError; nothing can match an empty pattern.
    if not normalized_pattern:
        return []

    try:
        hits = list(search_base.glob(normalized_pattern))
    except (ValueError, NotImplementedError):
        # Malformed or non-relative pattern: treat as "no matches".
        return []

    matches: list[FileInfo] = []
    for hit in hits:
        try:
            resolved = hit.resolve()
            if not resolved.is_relative_to(self._root):
                continue
            # Stat once, inside the guard, so a vanished file or dangling
            # symlink skips this entry instead of aborting the listing.
            stat_result = resolved.stat()
        except Exception:
            continue
        matches.append(
            FileInfo(
                path=self._to_virtual(resolved, self._root),
                is_dir=resolved.is_dir(),
                size=stat_result.st_size if resolved.is_file() else 0,
                modified_at=str(stat_result.st_mtime),
            )
        )
    return matches
|
||||||
|
|
||||||
|
async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
    """Async variant of ``glob_info``; the blocking call runs on a worker thread."""
    infos = await asyncio.to_thread(self.glob_info, pattern, path)
    return infos
|
||||||
|
|
||||||
|
def _iter_candidate_files(self, path: str | None, glob: str | None) -> list[Path]:
|
||||||
|
base_virtual = path or "/"
|
||||||
|
try:
|
||||||
|
base = self._resolve_virtual(base_virtual, allow_root=True)
|
||||||
|
except ValueError:
|
||||||
|
return []
|
||||||
|
if not base.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
candidates = [p for p in base.rglob("*") if p.is_file()]
|
||||||
|
if glob:
|
||||||
|
candidates = [
|
||||||
|
p
|
||||||
|
for p in candidates
|
||||||
|
if fnmatch.fnmatch(self._to_virtual(p, self._root), glob)
|
||||||
|
or fnmatch.fnmatch(p.name, glob)
|
||||||
|
]
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
def grep_raw(
    self, pattern: str, path: str | None = None, glob: str | None = None
) -> list[GrepMatch] | str:
    """Find lines containing ``pattern`` as a literal substring.

    Args:
        pattern: Text to look for; matching is a plain ``in`` test, not a regex.
        path: Virtual location to search, forwarded to ``_iter_candidate_files``.
        glob: Optional file filter, forwarded to ``_iter_candidate_files``.

    Returns:
        A list of ``GrepMatch`` entries (virtual path, 1-based line number,
        line text), or an error string when ``pattern`` is empty. Files that
        cannot be read are skipped.
    """
    if not pattern:
        return "Error: pattern cannot be empty"

    found: list[GrepMatch] = []
    for candidate in self._iter_candidate_files(path, glob):
        try:
            file_lines = candidate.read_text(
                encoding="utf-8", errors="replace"
            ).splitlines()
        except Exception:
            continue
        found.extend(
            GrepMatch(
                path=self._to_virtual(candidate, self._root),
                line=line_number,
                text=line_text,
            )
            for line_number, line_text in enumerate(file_lines, start=1)
            if pattern in line_text
        )
    return found
|
||||||
|
|
||||||
|
async def agrep_raw(
    self, pattern: str, path: str | None = None, glob: str | None = None
) -> list[GrepMatch] | str:
    """Async variant of ``grep_raw``; the blocking call runs on a worker thread."""
    grep_args = (pattern, path, glob)
    return await asyncio.to_thread(self.grep_raw, *grep_args)
|
||||||
|
|
||||||
|
def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
    """Write each ``(virtual_path, content)`` pair to disk.

    Content is first written to a sibling ``<name>.tmp`` file and then moved
    over the target with ``os.replace``. If any step fails, the staged
    temporary file is removed so failed uploads do not leave stray ``.tmp``
    files behind.

    Args:
        files: Pairs of virtual path and raw bytes to store.

    Returns:
        One ``FileUploadResponse`` per input, in input order. ``error`` is
        ``None`` on success, otherwise one of the module's error codes
        (``_FILE_NOT_FOUND``, ``_IS_DIRECTORY``, ``_INVALID_PATH``).
    """
    responses: list[FileUploadResponse] = []
    for virtual_path, content in files:
        temp_path = None
        try:
            target = self._resolve_virtual(virtual_path)
            target.parent.mkdir(parents=True, exist_ok=True)
            temp_path = target.with_suffix(f"{target.suffix}.tmp")
            temp_path.write_bytes(content)
            os.replace(temp_path, target)
            # The temp file has become the target; nothing left to clean up.
            temp_path = None
            error = None
        except FileNotFoundError:
            error = _FILE_NOT_FOUND
        except IsADirectoryError:
            error = _IS_DIRECTORY
        except Exception:
            error = _INVALID_PATH
        finally:
            if temp_path is not None:
                # Best-effort cleanup; the upload error above is what the
                # caller needs to see, not a secondary unlink failure.
                try:
                    temp_path.unlink(missing_ok=True)
                except OSError:
                    pass
        responses.append(FileUploadResponse(path=virtual_path, error=error))
    return responses
|
||||||
|
|
||||||
|
async def aupload_files(
    self, files: list[tuple[str, bytes]]
) -> list[FileUploadResponse]:
    """Async variant of ``upload_files``; the blocking call runs on a worker thread."""
    responses = await asyncio.to_thread(self.upload_files, files)
    return responses
|
||||||
|
|
||||||
|
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
    """Read the raw bytes of each requested virtual path.

    Args:
        paths: Virtual paths of the files to read.

    Returns:
        One ``FileDownloadResponse`` per input, in input order. ``content``
        holds the bytes on success; otherwise ``content`` is ``None`` and
        ``error`` is ``_FILE_NOT_FOUND`` (nothing at the path),
        ``_IS_DIRECTORY`` (path is a directory) or ``_INVALID_PATH`` (any
        exception while resolving or reading).
    """

    def _fetch(virtual_path: str) -> FileDownloadResponse:
        """Build the response for a single path."""
        try:
            target = self._resolve_virtual(virtual_path)
            if not target.exists():
                failure = _FILE_NOT_FOUND
            elif target.is_dir():
                failure = _IS_DIRECTORY
            else:
                return FileDownloadResponse(
                    path=virtual_path, content=target.read_bytes(), error=None
                )
            return FileDownloadResponse(
                path=virtual_path, content=None, error=failure
            )
        except Exception:
            return FileDownloadResponse(
                path=virtual_path, content=None, error=_INVALID_PATH
            )

    return [_fetch(virtual_path) for virtual_path in paths]
|
||||||
|
|
||||||
|
async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
    """Async variant of ``download_files``; the blocking call runs on a worker thread."""
    responses = await asyncio.to_thread(self.download_files, paths)
    return responses
|
||||||
|
|
@ -0,0 +1,489 @@
|
||||||
|
"""Aggregate multiple LocalFolderBackend roots behind mount-prefixed virtual paths."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from deepagents.backends.protocol import (
|
||||||
|
EditResult,
|
||||||
|
FileDownloadResponse,
|
||||||
|
FileInfo,
|
||||||
|
FileUploadResponse,
|
||||||
|
GrepMatch,
|
||||||
|
WriteResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend
|
||||||
|
|
||||||
|
_INVALID_PATH = "invalid_path"
|
||||||
|
_FILE_NOT_FOUND = "file_not_found"
|
||||||
|
_IS_DIRECTORY = "is_directory"
|
||||||
|
|
||||||
|
|
||||||
|
class MultiRootLocalFolderBackend:
    """Route filesystem operations to one of several mounted local roots.

    Virtual paths are namespaced as ``/<mount>/...`` where ``<mount>`` is
    derived from each selected root folder name. Every operation strips the
    mount prefix, delegates to that mount's ``LocalFolderBackend`` using the
    remaining mount-local path, and re-applies the prefix to any paths in
    the result. The bare root ``/`` is a synthetic directory that lists the
    mounts themselves.
    """

    def __init__(self, mounts: tuple[tuple[str, str], ...]) -> None:
        """Create one ``LocalFolderBackend`` per ``(mount_id, root_dir)`` pair.

        Args:
            mounts: Pairs of mount id and filesystem root. Mount ids are
                stripped of surrounding whitespace; roots are user-expanded
                and resolved before being handed to the backend.

        Raises:
            ValueError: If ``mounts`` is empty, a mount id is blank, or a
                mount id is repeated.
        """
        if not mounts:
            msg = "At least one local mount is required"
            raise ValueError(msg)
        self._mount_to_backend: dict[str, LocalFolderBackend] = {}
        for raw_mount, raw_root in mounts:
            mount = raw_mount.strip()
            if not mount:
                msg = "Mount id cannot be empty"
                raise ValueError(msg)
            if mount in self._mount_to_backend:
                msg = f"Duplicate mount id: {mount}"
                raise ValueError(msg)
            normalized_root = str(Path(raw_root).expanduser().resolve())
            self._mount_to_backend[mount] = LocalFolderBackend(normalized_root)
        # Insertion order of the dict, frozen for stable listings.
        self._mount_order = tuple(self._mount_to_backend.keys())

    def list_mounts(self) -> tuple[str, ...]:
        """Return the mount ids in registration order."""
        return self._mount_order

    def default_mount(self) -> str:
        """Return the first registered mount id."""
        return self._mount_order[0]

    def _mount_error(self) -> str:
        """Build the message shown when a path names no known mount."""
        mounts = ", ".join(f"/{mount}" for mount in self._mount_order)
        return (
            "Path must start with one of the selected folders: "
            f"{mounts}. Example: /{self._mount_order[0]}/file.txt"
        )

    def _split_mount_path(self, virtual_path: str) -> tuple[str, str]:
        """Split ``/<mount>/rest`` into ``(mount, "/rest")``.

        A path naming only the mount maps to the mount-local root ``"/"``.

        Raises:
            ValueError: If the path is not absolute, is the bare root, or
                its first segment is not a registered mount.
        """
        if not virtual_path.startswith("/"):
            msg = f"Invalid path (must be absolute): {virtual_path}"
            raise ValueError(msg)
        rel = virtual_path.lstrip("/")
        if not rel:
            raise ValueError(self._mount_error())
        mount, _, remainder = rel.partition("/")
        if self._mount_to_backend.get(mount) is None:
            raise ValueError(self._mount_error())
        local_path = f"/{remainder}" if remainder else "/"
        return mount, local_path

    @staticmethod
    def _prefix_mount_path(mount: str, local_path: str) -> str:
        """Turn a mount-local path back into a ``/<mount>/...`` virtual path."""
        if local_path == "/":
            return f"/{mount}"
        return f"/{mount}{local_path}"

    @staticmethod
    def _get_value(item: Any, key: str) -> Any:
        """Read ``key`` from a dict or, failing that, as an attribute."""
        if isinstance(item, dict):
            return item.get(key)
        return getattr(item, key, None)

    @classmethod
    def _get_str(cls, item: Any, key: str) -> str:
        """Read ``key`` as a string; non-string values become ``""``."""
        value = cls._get_value(item, key)
        return value if isinstance(value, str) else ""

    @classmethod
    def _get_int(cls, item: Any, key: str) -> int:
        """Read ``key`` as an int; non-numeric values become ``0``."""
        value = cls._get_value(item, key)
        return int(value) if isinstance(value, int | float) else 0

    @classmethod
    def _get_bool(cls, item: Any, key: str) -> bool:
        """Read ``key`` and coerce it with ``bool()``."""
        value = cls._get_value(item, key)
        return bool(value)

    def _list_mount_roots(self) -> list[FileInfo]:
        """Return one synthetic directory entry per mount."""
        return [
            FileInfo(path=f"/{mount}", is_dir=True, size=0, modified_at="0")
            for mount in self._mount_order
        ]

    def _transform_infos(self, mount: str, infos: list[FileInfo]) -> list[FileInfo]:
        """Re-emit backend ``FileInfo`` entries with mount-prefixed paths."""
        return [
            FileInfo(
                path=self._prefix_mount_path(mount, self._get_str(info, "path")),
                is_dir=self._get_bool(info, "is_dir"),
                size=self._get_int(info, "size"),
                modified_at=self._get_str(info, "modified_at"),
            )
            for info in infos
        ]

    def _transform_matches(self, mount: str, matches: list[GrepMatch]) -> list[GrepMatch]:
        """Re-emit backend ``GrepMatch`` entries with mount-prefixed paths."""
        return [
            GrepMatch(
                path=self._prefix_mount_path(mount, self._get_str(match, "path")),
                line=self._get_int(match, "line"),
                text=self._get_str(match, "text"),
            )
            for match in matches
        ]

    def ls_info(self, path: str) -> list[FileInfo]:
        """List a directory; ``/`` lists the mounts, unknown paths give ``[]``."""
        if path == "/":
            return self._list_mount_roots()
        try:
            mount, local_path = self._split_mount_path(path)
        except ValueError:
            return []
        return self._transform_infos(
            mount, self._mount_to_backend[mount].ls_info(local_path)
        )

    async def als_info(self, path: str) -> list[FileInfo]:
        """Async variant of ``ls_info`` (runs on a worker thread)."""
        return await asyncio.to_thread(self.ls_info, path)

    def list_tree(
        self,
        path: str = "/",
        *,
        max_depth: int | None = 8,
        page_size: int = 500,
        include_files: bool = True,
        include_dirs: bool = True,
    ) -> dict[str, Any]:
        """Return a flattened tree listing below ``path``.

        At ``/`` only the mount roots are returned (depth 0, and only when
        ``include_dirs`` is true); the other options are not applied there.
        Otherwise the call is delegated to the owning backend and its
        entries are re-emitted with mount-prefixed paths.

        Returns:
            ``{"entries": [...], "truncated": bool}`` on success, or a dict
            with an ``"error"`` key (built here for an unknown mount, or
            passed through unchanged from the backend).
        """
        if path == "/":
            entries = [
                {
                    "path": f"/{mount}",
                    "is_dir": True,
                    "size": 0,
                    "modified_at": "0",
                    "depth": 0,
                }
                for mount in self._mount_order
            ]
            return {
                "entries": entries if include_dirs else [],
                "truncated": False,
            }

        try:
            mount, local_path = self._split_mount_path(path)
        except ValueError as exc:
            return {"error": f"Error: {exc}"}

        result = self._mount_to_backend[mount].list_tree(
            local_path,
            max_depth=max_depth,
            page_size=page_size,
            include_files=include_files,
            include_dirs=include_dirs,
        )
        if result.get("error"):
            return result

        entries = [
            {
                "path": self._prefix_mount_path(mount, self._get_str(entry, "path")),
                "is_dir": self._get_bool(entry, "is_dir"),
                "size": self._get_int(entry, "size"),
                "modified_at": self._get_str(entry, "modified_at"),
                "depth": self._get_int(entry, "depth"),
            }
            for entry in result.get("entries", [])
        ]
        return {
            "entries": entries,
            "truncated": self._get_bool(result, "truncated"),
        }

    async def alist_tree(
        self,
        path: str = "/",
        *,
        max_depth: int | None = 8,
        page_size: int = 500,
        include_files: bool = True,
        include_dirs: bool = True,
    ) -> dict[str, Any]:
        """Async variant of ``list_tree`` (runs on a worker thread)."""
        return await asyncio.to_thread(
            self.list_tree,
            path,
            max_depth=max_depth,
            page_size=page_size,
            include_files=include_files,
            include_dirs=include_dirs,
        )

    def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
        """Delegate ``read`` to the owning backend; bad mounts give an error string."""
        try:
            mount, local_path = self._split_mount_path(file_path)
        except ValueError as exc:
            return f"Error: {exc}"
        return self._mount_to_backend[mount].read(local_path, offset, limit)

    async def aread(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
        """Async variant of ``read`` (runs on a worker thread)."""
        return await asyncio.to_thread(self.read, file_path, offset, limit)

    def read_raw(self, file_path: str) -> str:
        """Delegate ``read_raw`` to the owning backend; bad mounts give an error string."""
        try:
            mount, local_path = self._split_mount_path(file_path)
        except ValueError as exc:
            return f"Error: {exc}"
        return self._mount_to_backend[mount].read_raw(local_path)

    async def aread_raw(self, file_path: str) -> str:
        """Async variant of ``read_raw`` (runs on a worker thread)."""
        return await asyncio.to_thread(self.read_raw, file_path)

    def write(self, file_path: str, content: str) -> WriteResult:
        """Delegate ``write`` and prefix the resulting path with the mount."""
        try:
            mount, local_path = self._split_mount_path(file_path)
        except ValueError as exc:
            return WriteResult(error=f"Error: {exc}")
        result = self._mount_to_backend[mount].write(local_path, content)
        if result.path:
            result.path = self._prefix_mount_path(mount, result.path)
        return result

    async def awrite(self, file_path: str, content: str) -> WriteResult:
        """Async variant of ``write`` (runs on a worker thread)."""
        return await asyncio.to_thread(self.write, file_path, content)

    def move(
        self,
        source_path: str,
        destination_path: str,
        overwrite: bool = False,
    ) -> WriteResult:
        """Move within a single mount; cross-mount moves are rejected."""
        try:
            source_mount, source_local_path = self._split_mount_path(source_path)
            destination_mount, destination_local_path = self._split_mount_path(
                destination_path
            )
        except ValueError as exc:
            return WriteResult(error=f"Error: {exc}")
        if source_mount != destination_mount:
            return WriteResult(
                error=(
                    "Error: cross-mount moves are not supported. "
                    "Source and destination must be under the same mounted root."
                )
            )
        result = self._mount_to_backend[source_mount].move(
            source_local_path,
            destination_local_path,
            overwrite=overwrite,
        )
        if result.path:
            result.path = self._prefix_mount_path(source_mount, result.path)
        return result

    async def amove(
        self,
        source_path: str,
        destination_path: str,
        overwrite: bool = False,
    ) -> WriteResult:
        """Async variant of ``move`` (runs on a worker thread)."""
        return await asyncio.to_thread(
            self.move,
            source_path,
            destination_path,
            overwrite,
        )

    def delete_file(self, file_path: str) -> WriteResult:
        """Delegate ``delete_file`` and prefix the resulting path with the mount."""
        try:
            mount, local_path = self._split_mount_path(file_path)
        except ValueError as exc:
            return WriteResult(error=f"Error: {exc}")
        result = self._mount_to_backend[mount].delete_file(local_path)
        if result.path:
            result.path = self._prefix_mount_path(mount, result.path)
        return result

    async def adelete_file(self, file_path: str) -> WriteResult:
        """Async variant of ``delete_file`` (runs on a worker thread)."""
        return await asyncio.to_thread(self.delete_file, file_path)

    def rmdir(self, dir_path: str) -> WriteResult:
        """Delegate ``rmdir``; a mount root itself can never be removed."""
        try:
            mount, local_path = self._split_mount_path(dir_path)
        except ValueError as exc:
            return WriteResult(error=f"Error: {exc}")
        if local_path == "/":
            return WriteResult(error=f"Error: cannot rmdir mount root '{dir_path}'")
        result = self._mount_to_backend[mount].rmdir(local_path)
        if result.path:
            result.path = self._prefix_mount_path(mount, result.path)
        return result

    async def armdir(self, dir_path: str) -> WriteResult:
        """Async variant of ``rmdir`` (runs on a worker thread)."""
        return await asyncio.to_thread(self.rmdir, dir_path)

    def edit(
        self,
        file_path: str,
        old_string: str,
        new_string: str,
        replace_all: bool = False,
    ) -> EditResult:
        """Delegate ``edit`` and prefix the resulting path with the mount."""
        try:
            mount, local_path = self._split_mount_path(file_path)
        except ValueError as exc:
            return EditResult(error=f"Error: {exc}")
        result = self._mount_to_backend[mount].edit(
            local_path, old_string, new_string, replace_all
        )
        if result.path:
            result.path = self._prefix_mount_path(mount, result.path)
        return result

    async def aedit(
        self,
        file_path: str,
        old_string: str,
        new_string: str,
        replace_all: bool = False,
    ) -> EditResult:
        """Async variant of ``edit`` (runs on a worker thread)."""
        return await asyncio.to_thread(
            self.edit, file_path, old_string, new_string, replace_all
        )

    def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
        """Glob within one mount, or across all mounts when ``path`` is ``/``.

        At the root, an absolute pattern ``/<mount>/rest`` is routed to that
        mount with ``/rest`` as the pattern; a pattern naming only a mount
        (``/<mount>``) returns that mount's synthetic root entry. Any other
        pattern is fanned out to every mount. Unknown mounts give ``[]``.
        """
        if path == "/":
            if pattern.startswith("/"):
                mount, _, remainder = pattern.lstrip("/").partition("/")
                backend = self._mount_to_backend.get(mount)
                if backend is None:
                    return []
                if not remainder:
                    # The pattern is the mount itself. Forwarding "/" would
                    # reach the backend as an empty glob pattern, which
                    # pathlib rejects with ValueError.
                    return [
                        FileInfo(
                            path=f"/{mount}", is_dir=True, size=0, modified_at="0"
                        )
                    ]
                return self._transform_infos(
                    mount, backend.glob_info(f"/{remainder}", path="/")
                )
            prefixed_results: list[FileInfo] = []
            for mount, backend in self._mount_to_backend.items():
                prefixed_results.extend(
                    self._transform_infos(mount, backend.glob_info(pattern, path="/"))
                )
            return prefixed_results

        try:
            mount, local_path = self._split_mount_path(path)
        except ValueError:
            return []
        return self._transform_infos(
            mount, self._mount_to_backend[mount].glob_info(pattern, path=local_path)
        )

    async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
        """Async variant of ``glob_info`` (runs on a worker thread)."""
        return await asyncio.to_thread(self.glob_info, pattern, path)

    def grep_raw(
        self, pattern: str, path: str | None = None, glob: str | None = None
    ) -> list[GrepMatch] | str:
        """Grep one mount, or every mount when ``path`` is ``None`` or ``/``.

        Returns the mount-prefixed matches, or an error string when the
        pattern is empty, the mount is unknown, or a backend reports an error.
        """
        if not pattern:
            return "Error: pattern cannot be empty"
        if path is None or path == "/":
            all_matches: list[GrepMatch] = []
            for mount, backend in self._mount_to_backend.items():
                result = backend.grep_raw(pattern, path="/", glob=glob)
                if isinstance(result, str):
                    return result
                all_matches.extend(self._transform_matches(mount, result))
            return all_matches

        try:
            mount, local_path = self._split_mount_path(path)
        except ValueError as exc:
            return f"Error: {exc}"

        result = self._mount_to_backend[mount].grep_raw(
            pattern, path=local_path, glob=glob
        )
        if isinstance(result, str):
            return result
        return self._transform_matches(mount, result)

    async def agrep_raw(
        self, pattern: str, path: str | None = None, glob: str | None = None
    ) -> list[GrepMatch] | str:
        """Async variant of ``grep_raw`` (runs on a worker thread)."""
        return await asyncio.to_thread(self.grep_raw, pattern, path, glob)

    def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
        """Upload files, batching per mount while preserving input order.

        Returns:
            One ``FileUploadResponse`` per input, positioned so that
            ``responses[i]`` describes ``files[i]``. Paths that do not name
            a known mount get ``_INVALID_PATH``.
        """
        responses: list[FileUploadResponse | None] = [None] * len(files)
        # mount -> [(input index, mount-local path, content)]
        grouped: dict[str, list[tuple[int, str, bytes]]] = {}
        for index, (virtual_path, content) in enumerate(files):
            try:
                mount, local_path = self._split_mount_path(virtual_path)
            except ValueError:
                responses[index] = FileUploadResponse(
                    path=virtual_path, error=_INVALID_PATH
                )
                continue
            grouped.setdefault(mount, []).append((index, local_path, content))

        for mount, batch in grouped.items():
            result = self._mount_to_backend[mount].upload_files(
                [(local_path, content) for _, local_path, content in batch]
            )
            # Relies on the backend answering one-for-one, in request order.
            for (index, _, _), item in zip(batch, result):
                responses[index] = FileUploadResponse(
                    path=self._prefix_mount_path(mount, self._get_str(item, "path")),
                    error=self._get_str(item, "error") or None,
                )
        return [response for response in responses if response is not None]

    async def aupload_files(
        self, files: list[tuple[str, bytes]]
    ) -> list[FileUploadResponse]:
        """Async variant of ``upload_files`` (runs on a worker thread)."""
        return await asyncio.to_thread(self.upload_files, files)

    def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
        """Download files, batching per mount while preserving input order.

        Returns:
            One ``FileDownloadResponse`` per input, positioned so that
            ``responses[i]`` describes ``paths[i]``. Paths that do not name
            a known mount get ``_INVALID_PATH``.
        """
        responses: list[FileDownloadResponse | None] = [None] * len(paths)
        # mount -> [(input index, mount-local path)]
        grouped: dict[str, list[tuple[int, str]]] = {}
        for index, virtual_path in enumerate(paths):
            try:
                mount, local_path = self._split_mount_path(virtual_path)
            except ValueError:
                responses[index] = FileDownloadResponse(
                    path=virtual_path, content=None, error=_INVALID_PATH
                )
                continue
            grouped.setdefault(mount, []).append((index, local_path))

        for mount, batch in grouped.items():
            result = self._mount_to_backend[mount].download_files(
                [local_path for _, local_path in batch]
            )
            # Relies on the backend answering one-for-one, in request order.
            for (index, _), item in zip(batch, result):
                responses[index] = FileDownloadResponse(
                    path=self._prefix_mount_path(mount, self._get_str(item, "path")),
                    content=self._get_value(item, "content"),
                    error=self._get_str(item, "error") or None,
                )
        return [response for response in responses if response is not None]

    async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
        """Async variant of ``download_files`` (runs on a worker thread)."""
        return await asyncio.to_thread(self.download_files, paths)
|
||||||
|
|
@ -0,0 +1,141 @@
|
||||||
|
"""
|
||||||
|
``_noop`` provider-compatibility tool + injection middleware.
|
||||||
|
|
||||||
|
Some providers (LiteLLM, Bedrock, Copilot) 400 when a model call has
|
||||||
|
empty ``tools`` but the message history includes prior ``tool_calls`` —
|
||||||
|
they treat that shape as malformed even though it's perfectly valid
|
||||||
|
LangChain. SurfSense hits this on the compaction summarize call (no
|
||||||
|
tools, history full of tool calls).
|
||||||
|
|
||||||
|
Ported from OpenCode's ``packages/opencode/src/session/llm.ts:209-228``,
|
||||||
|
which discovered and codified the workaround: inject a no-op tool *only*
|
||||||
|
on those provider shapes so the request validates without ever being
|
||||||
|
called.
|
||||||
|
|
||||||
|
Operation: a :class:`NoopInjectionMiddleware` ``wrap_model_call`` checks
|
||||||
|
if the request has zero tools but the last AI message in history includes
|
||||||
|
``tool_calls``. If yes, it injects the ``_noop`` tool only — never
|
||||||
|
globally — mirroring OpenCode's gating exactly. The :func:`noop_tool`
|
||||||
|
returns empty content when called (which it should never be in
|
||||||
|
practice).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware.types import (
|
||||||
|
AgentMiddleware,
|
||||||
|
AgentState,
|
||||||
|
ContextT,
|
||||||
|
ModelRequest,
|
||||||
|
ModelResponse,
|
||||||
|
ResponseT,
|
||||||
|
)
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
NOOP_TOOL_NAME = "_noop"
|
||||||
|
NOOP_TOOL_DESCRIPTION = "Do not call this tool. It exists only for API compatibility."
|
||||||
|
|
||||||
|
|
||||||
|
@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION)
def noop_tool() -> str:
    """Placeholder tool body: takes no input and yields an empty string.

    The tool exists only so a request can carry a non-empty ``tools`` list;
    it is not expected to be invoked.
    """
    empty_result = ""
    return empty_result
|
||||||
|
|
||||||
|
|
||||||
|
# Provider markers that benefit from ``_noop`` injection. These match
|
||||||
|
# OpenCode's gating list (``llm.ts:209-228``). We also accept any string
|
||||||
|
# containing one of these substrings so e.g. ``litellm`` matches
|
||||||
|
# ``ChatLiteLLM``.
|
||||||
|
_NOOP_NEEDED_PROVIDERS: tuple[str, ...] = (
|
||||||
|
"litellm",
|
||||||
|
"bedrock",
|
||||||
|
"copilot",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_needs_noop(model: Any) -> bool:
    """Heuristically decide whether ``model``'s provider needs ``_noop``.

    The provider label is taken from ``model._get_ls_params()["ls_provider"]``
    (lower-cased). If that lookup fails or yields an empty value, the
    lower-cased class name of ``model`` is used instead. The result is true
    when the label contains any entry of ``_NOOP_NEEDED_PROVIDERS``.
    """
    try:
        label = str(model._get_ls_params().get("ls_provider", "")).lower()
    except Exception:
        label = ""

    # Fall back to the class name only when no provider label is available.
    haystack = label or type(model).__name__.lower()

    for needle in _NOOP_NEEDED_PROVIDERS:
        if needle in haystack:
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _last_ai_has_tool_calls(messages: list[Any]) -> bool:
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if isinstance(msg, AIMessage):
|
||||||
|
return bool(msg.tool_calls)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class NoopInjectionMiddleware(
    AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
    """Add the ``_noop`` tool to a model call only when it is needed.

    The decision is made per model call rather than at agent build time,
    because the summarization path issues its tool-less sub-call at runtime.
    When the check passes, the request is re-issued with ``tools`` set to a
    single entry: the configured no-op tool instance.
    """

    def __init__(self, *, noop_tool_instance: Any | None = None) -> None:
        """Store the tool to inject, defaulting to the module's ``noop_tool``."""
        super().__init__()
        self._noop_tool = noop_tool_instance or noop_tool
        # This middleware registers no tools of its own with the agent.
        self.tools = []

    def _should_inject(self, request: ModelRequest[ContextT]) -> bool:
        """True when the request has no tools, the history check passes, and the provider matches."""
        return (
            not request.tools
            and _last_ai_has_tool_calls(request.messages)
            and _provider_needs_noop(request.model)
        )

    def _augmented(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
        """Return a copy of ``request`` whose only tool is the no-op tool."""
        injected_tools = [self._noop_tool]
        return request.override(tools=injected_tools)

    def wrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
    ) -> Any:
        """Synchronous hook: forward the request, augmented when required."""
        if not self._should_inject(request):
            return handler(request)
        logger.debug("Injecting _noop tool for provider compatibility")
        return handler(self._augmented(request))

    async def awrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[ContextT],
        handler: Callable[
            [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
        ],
    ) -> Any:
        """Asynchronous hook: forward the request, augmented when required."""
        if not self._should_inject(request):
            return await handler(request)
        logger.debug("Injecting _noop tool for provider compatibility")
        return await handler(self._augmented(request))
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"NOOP_TOOL_DESCRIPTION",
|
||||||
|
"NOOP_TOOL_NAME",
|
||||||
|
"NoopInjectionMiddleware",
|
||||||
|
"_provider_needs_noop",
|
||||||
|
"noop_tool",
|
||||||
|
]
|
||||||
202
surfsense_backend/app/agents/new_chat/middleware/otel_span.py
Normal file
202
surfsense_backend/app/agents/new_chat/middleware/otel_span.py
Normal file
|
|
@ -0,0 +1,202 @@
|
||||||
|
"""
|
||||||
|
OpenTelemetry span middleware for the SurfSense ``new_chat`` agent.
|
||||||
|
|
||||||
|
Wraps both ``model.call`` (LLM invocations) and ``tool.call`` (tool
|
||||||
|
executions) with OTel spans, attaching low-cardinality span names and
|
||||||
|
high-cardinality identifiers as attributes.
|
||||||
|
|
||||||
|
This middleware is intentionally a thin adapter over
|
||||||
|
:mod:`app.observability.otel`; when OTel is not configured all spans
|
||||||
|
collapse to no-ops and the wrapper adds <1µs overhead per call. When
|
||||||
|
OTel **is** configured (``OTEL_EXPORTER_OTLP_ENDPOINT`` set), every
|
||||||
|
model and tool call gets a span with the standard attributes our
|
||||||
|
dashboards expect.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.observability import otel as ot
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma: no cover — type-only
|
||||||
|
from langchain.agents.middleware.types import (
|
||||||
|
ModelRequest,
|
||||||
|
ModelResponse,
|
||||||
|
ToolCallRequest,
|
||||||
|
)
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OtelSpanMiddleware(AgentMiddleware):
    """Emit ``model.call`` and ``tool.call`` OTel spans for every invocation.

    Should be placed near the **outer** end of the middleware list so
    that the spans encompass retry/fallback wrapper effects (i.e. ``N``
    model.call spans for ``N`` retry attempts) but inside any concurrency/
    auth gate. Empirically this means **between** ``BusyMutex`` and
    ``RetryAfter``.

    Both the sync (``wrap_*``) and async (``awrap_*``) hooks are provided
    so spans are produced whichever way the agent is driven. Each hook
    short-circuits to a plain handler call when ``ot.is_enabled()`` is
    false.
    """

    def __init__(self, *, instrumentation_name: str = "surfsense.new_chat") -> None:
        """Remember the instrumentation name (stored only; not read in this class)."""
        super().__init__()
        self._instrumentation_name = instrumentation_name

    # ------------------------------------------------------------------
    # Model call spans
    # ------------------------------------------------------------------

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse | AIMessage | Any],
    ) -> ModelResponse | AIMessage | Any:
        """Sync counterpart of :meth:`awrap_model_call`.

        Args:
            request: The model request about to be executed.
            handler: Callable that performs the actual model call.

        Returns:
            Whatever ``handler`` returns, unchanged.
        """
        if not ot.is_enabled():
            return handler(request)

        model_id, provider = _resolve_model_attrs(request)
        with ot.model_call_span(model_id=model_id, provider=provider) as sp:
            # Exceptions propagate through the ``with`` block untouched, so
            # the span context manager sees them; annotation only runs on
            # success.
            result = handler(request)
            _annotate_model_response(sp, result)
            return result

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]],
    ) -> ModelResponse | AIMessage | Any:
        """Run the model call inside a ``model.call`` span.

        Args:
            request: The model request about to be executed.
            handler: Awaitable callable that performs the actual model call.

        Returns:
            Whatever ``handler`` returns, unchanged.
        """
        if not ot.is_enabled():
            return await handler(request)

        model_id, provider = _resolve_model_attrs(request)
        with ot.model_call_span(model_id=model_id, provider=provider) as sp:
            # Exceptions propagate through the ``with`` block untouched, so
            # the span context manager sees them; annotation only runs on
            # success.
            result = await handler(request)
            _annotate_model_response(sp, result)
            return result

    # ------------------------------------------------------------------
    # Tool call spans
    # ------------------------------------------------------------------

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
    ) -> ToolMessage | Command[Any]:
        """Sync counterpart of :meth:`awrap_tool_call`.

        Args:
            request: The tool-call request about to be executed.
            handler: Callable that performs the actual tool execution.

        Returns:
            Whatever ``handler`` returns, unchanged.
        """
        if not ot.is_enabled():
            return handler(request)

        tool_name = _resolve_tool_name(request)
        input_size = _resolve_input_size(request)

        with ot.tool_call_span(tool_name, input_size=input_size) as sp:
            result = handler(request)
            _annotate_tool_result(sp, result)
            return result

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
    ) -> ToolMessage | Command[Any]:
        """Run the tool execution inside a ``tool.call`` span.

        Args:
            request: The tool-call request about to be executed.
            handler: Awaitable callable that performs the actual tool execution.

        Returns:
            Whatever ``handler`` returns, unchanged.
        """
        if not ot.is_enabled():
            return await handler(request)

        tool_name = _resolve_tool_name(request)
        input_size = _resolve_input_size(request)

        with ot.tool_call_span(tool_name, input_size=input_size) as sp:
            result = await handler(request)
            _annotate_tool_result(sp, result)
            return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Attribute helpers (kept defensive; we never want OTel bookkeeping to break
|
||||||
|
# a real model/tool call).
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_model_attrs(request: Any) -> tuple[str | None, str | None]:
|
||||||
|
"""Extract ``model.id`` and ``model.provider`` from a ``ModelRequest``."""
|
||||||
|
model_id: str | None = None
|
||||||
|
provider: str | None = None
|
||||||
|
try:
|
||||||
|
model = getattr(request, "model", None)
|
||||||
|
if model is None:
|
||||||
|
return None, None
|
||||||
|
# langchain BaseChatModel exposes a few different identifiers
|
||||||
|
for attr in ("model_name", "model", "model_id"):
|
||||||
|
value = getattr(model, attr, None)
|
||||||
|
if value:
|
||||||
|
model_id = str(value)
|
||||||
|
break
|
||||||
|
# provider sometimes lives on ``_llm_type`` (legacy) or ``provider``
|
||||||
|
for attr in ("provider", "_llm_type"):
|
||||||
|
value = getattr(model, attr, None)
|
||||||
|
if value:
|
||||||
|
provider = str(value)
|
||||||
|
break
|
||||||
|
except Exception: # pragma: no cover — defensive
|
||||||
|
pass
|
||||||
|
return model_id, provider
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_tool_name(request: Any) -> str:
|
||||||
|
try:
|
||||||
|
tool = getattr(request, "tool", None)
|
||||||
|
if tool is not None:
|
||||||
|
name = getattr(tool, "name", None)
|
||||||
|
if isinstance(name, str) and name:
|
||||||
|
return name
|
||||||
|
# Fall back to the tool_call dict
|
||||||
|
call = getattr(request, "tool_call", None) or {}
|
||||||
|
name = call.get("name") if isinstance(call, dict) else None
|
||||||
|
if isinstance(name, str) and name:
|
||||||
|
return name
|
||||||
|
except Exception: # pragma: no cover — defensive
|
||||||
|
pass
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_input_size(request: Any) -> int | None:
|
||||||
|
try:
|
||||||
|
call = getattr(request, "tool_call", None)
|
||||||
|
if not isinstance(call, dict) or not call:
|
||||||
|
return None
|
||||||
|
args = call.get("args")
|
||||||
|
if args is None:
|
||||||
|
return None
|
||||||
|
return len(repr(args))
|
||||||
|
except Exception: # pragma: no cover — defensive
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _annotate_model_response(span: Any, result: Any) -> None:
    """Best-effort: copy token usage and tool-call count onto ``span``.

    ``result`` may be an ``AIMessage`` itself or an object whose ``result``
    attribute holds one (or a list, in which case the last element is
    used). Any failure is swallowed so span bookkeeping never breaks the
    model call.
    """
    # usage_metadata key -> span attribute, in the order they are written
    token_fields = (
        ("input_tokens", "tokens.prompt"),
        ("output_tokens", "tokens.completion"),
        ("total_tokens", "tokens.total"),
    )
    try:
        message: Any
        if isinstance(result, AIMessage):
            message = result
        else:
            # ModelResponse may be a dataclass with .result containing AIMessage
            payload = getattr(result, "result", None)
            message = payload[-1] if isinstance(payload, list) and payload else payload
        if message is None:
            return

        usage = getattr(message, "usage_metadata", None) or {}
        if isinstance(usage, dict):
            for usage_key, attr_name in token_fields:
                count = usage.get(usage_key)
                if count is not None:
                    span.set_attribute(attr_name, int(count))

        calls = getattr(message, "tool_calls", None) or []
        span.set_attribute("model.tool_calls", len(calls))
    except Exception:  # pragma: no cover — defensive
        pass
|
||||||
|
|
||||||
|
|
||||||
|
def _annotate_tool_result(span: Any, result: Any) -> None:
    """Best-effort: record output size, status and error flag on ``span``.

    Only ``ToolMessage`` results are annotated; anything else is ignored.
    Any failure is swallowed so span bookkeeping never breaks the tool call.
    """
    try:
        if not isinstance(result, ToolMessage):
            return

        body = result.content
        # Non-string content is measured through its repr()
        text = body if isinstance(body, str) else repr(body)
        span.set_attribute("tool.output.size", len(text))

        status = getattr(result, "status", None)
        if isinstance(status, str):
            span.set_attribute("tool.status", status)

        extra = getattr(result, "additional_kwargs", None) or {}
        if isinstance(extra, dict) and extra.get("error"):
            span.set_attribute("tool.error", True)
    except Exception:  # pragma: no cover — defensive
        pass
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["OtelSpanMiddleware"]
|
||||||
358
surfsense_backend/app/agents/new_chat/middleware/permission.py
Normal file
358
surfsense_backend/app/agents/new_chat/middleware/permission.py
Normal file
|
|
@ -0,0 +1,358 @@
|
||||||
|
"""
|
||||||
|
PermissionMiddleware — pattern-based allow/deny/ask with HITL fallback.
|
||||||
|
|
||||||
|
LangChain's :class:`HumanInTheLoopMiddleware` only supports a static
|
||||||
|
"this tool always asks" decision per tool. There's no rule-based
|
||||||
|
allow/deny/ask layered ruleset, no glob patterns, no per-search-space or
|
||||||
|
per-thread overrides, and no auto-deny synthesis.
|
||||||
|
|
||||||
|
This middleware ports OpenCode's ``packages/opencode/src/permission/index.ts``
|
||||||
|
ruleset model on top of SurfSense's existing ``interrupt({type, action,
|
||||||
|
context})`` payload shape (see ``app/agents/new_chat/tools/hitl.py``) so
|
||||||
|
the frontend keeps working unchanged.
|
||||||
|
|
||||||
|
Operation:
|
||||||
|
1. ``aafter_model`` inspects the latest ``AIMessage.tool_calls``.
|
||||||
|
2. For each call, the middleware builds a list of ``patterns`` (the
|
||||||
|
tool name plus any tool-specific patterns from the resolver). It
|
||||||
|
evaluates each pattern against the layered rulesets and aggregates
|
||||||
|
the results: ``deny`` > ``ask`` > ``allow``.
|
||||||
|
3. On ``deny``: replaces the call with a synthetic ``ToolMessage``
|
||||||
|
containing a :class:`StreamingError`.
|
||||||
|
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply
|
||||||
|
shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``.
|
||||||
|
- ``once``: proceed.
|
||||||
|
- ``always``: also persist allow rules for ``request.always`` patterns.
|
||||||
|
- ``reject`` w/o feedback: raise :class:`RejectedError`.
|
||||||
|
- ``reject`` w/ feedback: raise :class:`CorrectedError`.
|
||||||
|
5. On ``allow``: proceed unchanged.
|
||||||
|
|
||||||
|
Pre-model filtering of globally denied tools (the equivalent of
OpenCode's ``Permission.disabled``) is *not* done by this middleware:
it is left to the agent factory, because mutating ``request.tools`` on
every call would invalidate provider prompt-cache prefixes. This
middleware only blocks tool calls at execution time, through the rule
evaluator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware.types import (
|
||||||
|
AgentMiddleware,
|
||||||
|
AgentState,
|
||||||
|
ContextT,
|
||||||
|
)
|
||||||
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
from langgraph.types import interrupt
|
||||||
|
|
||||||
|
from app.agents.new_chat.errors import (
|
||||||
|
CorrectedError,
|
||||||
|
RejectedError,
|
||||||
|
StreamingError,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.permissions import (
|
||||||
|
Rule,
|
||||||
|
Ruleset,
|
||||||
|
aggregate_action,
|
||||||
|
evaluate_many,
|
||||||
|
)
|
||||||
|
from app.observability import otel as ot
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Mapping ``tool_name -> resolver`` that converts ``args`` to a list of
|
||||||
|
# patterns to evaluate. The first pattern is conventionally the bare
|
||||||
|
# tool name; later entries narrow down to specific resources.
|
||||||
|
PatternResolver = Callable[[dict[str, Any]], list[str]]
|
||||||
|
|
||||||
|
|
||||||
|
def _default_pattern_resolver(name: str) -> PatternResolver:
|
||||||
|
def _resolve(args: dict[str, Any]) -> list[str]:
|
||||||
|
# Bare name covers the default catch-all; primary-arg fallbacks
|
||||||
|
# are best added per-tool by callers.
|
||||||
|
del args
|
||||||
|
return [name]
|
||||||
|
|
||||||
|
return _resolve
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Allow/deny/ask layer over the agent's tool calls.

    Args:
        rulesets: Layered rulesets to evaluate. Earlier entries are
            overridden by later ones (last-match-wins). Typical layering:
            ``defaults < global < space < thread < runtime_approved``.
        pattern_resolvers: Optional per-tool callables that return a list
            of patterns to evaluate. When a tool isn't listed, the bare
            tool name is used as the only pattern.
        runtime_ruleset: Mutable :class:`Ruleset` that the middleware
            extends in-place when the user replies ``"always"`` to an
            ask interrupt. Reused across all calls in the same agent
            instance so newly-allowed rules apply to subsequent calls.
            A fresh ``runtime_approved`` ruleset is created only when
            ``None`` is passed.
        always_emit_interrupt_payload: If True (default), every ask uses
            the SurfSense interrupt wire format. If False, no interrupt
            is raised and each ``ask`` is answered as a ``reject`` without
            feedback, which surfaces as :class:`RejectedError` (intended
            for non-interactive deployments).
    """

    tools = ()

    def __init__(
        self,
        *,
        rulesets: list[Ruleset] | None = None,
        pattern_resolvers: dict[str, PatternResolver] | None = None,
        runtime_ruleset: Ruleset | None = None,
        always_emit_interrupt_payload: bool = True,
    ) -> None:
        super().__init__()
        self._static_rulesets: list[Ruleset] = list(rulesets or [])
        self._pattern_resolvers: dict[str, PatternResolver] = dict(
            pattern_resolvers or {}
        )
        # Compare against None rather than truthiness: a caller-supplied
        # ruleset must be kept (and mutated in place) even if it happens
        # to evaluate as falsy, e.g. because it is still empty.
        self._runtime_ruleset: Ruleset = (
            runtime_ruleset
            if runtime_ruleset is not None
            else Ruleset(origin="runtime_approved")
        )
        self._emit_interrupt = always_emit_interrupt_payload

    # ------------------------------------------------------------------
    # Tool-filter step (mirrors OpenCode's ``Permission.disabled``)
    # ------------------------------------------------------------------

    def _globally_denied(self, tool_name: str) -> bool:
        """Return True if a deny rule with no narrowing pattern matches."""
        rules = evaluate_many(tool_name, ["*"], *self._all_rulesets())
        return aggregate_action(rules) == "deny"

    def _all_rulesets(self) -> list[Ruleset]:
        """Return the static rulesets followed by the runtime ruleset."""
        return [*self._static_rulesets, self._runtime_ruleset]

    # NOTE: ``before_model`` filtering of the tools list is left to the
    # agent factory. This middleware only blocks at execution time — and
    # only via the rule-evaluator path, not by mutating ``request.tools``.
    # Mutating ``request.tools`` per-call would invalidate provider
    # prompt-cache prefixes (see Operational risks: prompt-cache regression).

    # ------------------------------------------------------------------
    # Tool-call evaluation
    # ------------------------------------------------------------------

    def _resolve_patterns(self, tool_name: str, args: dict[str, Any]) -> list[str]:
        """Return the patterns to evaluate for one tool call.

        Uses the tool's registered resolver, or the bare-name default.
        Falls back to ``[tool_name]`` when the resolver raises (logged) or
        returns nothing, so the result is never empty.
        """
        resolver = self._pattern_resolvers.get(
            tool_name, _default_pattern_resolver(tool_name)
        )
        try:
            patterns = resolver(args or {})
        except Exception:
            logger.exception(
                "Pattern resolver for %s raised; using bare name", tool_name
            )
            patterns = [tool_name]
        if not patterns:
            patterns = [tool_name]
        return patterns

    def _evaluate(
        self, tool_name: str, args: dict[str, Any]
    ) -> tuple[str, list[str], list[Rule]]:
        """Evaluate one call; return ``(action, patterns, matched_rules)``."""
        patterns = self._resolve_patterns(tool_name, args)
        rules = evaluate_many(tool_name, patterns, *self._all_rulesets())
        action = aggregate_action(rules)
        return action, patterns, rules

    # ------------------------------------------------------------------
    # HITL ask flow — SurfSense wire format
    # ------------------------------------------------------------------

    def _raise_interrupt(
        self,
        *,
        tool_name: str,
        args: dict[str, Any],
        patterns: list[str],
        rules: list[Rule],
    ) -> dict[str, Any]:
        """Block on user approval via SurfSense's ``interrupt`` shape.

        Returns:
            A decision dict with at least ``"decision_type"``. A dict reply
            is returned as-is, a plain string reply is wrapped, and any
            other reply type is treated as ``reject``. When interrupts are
            disabled the result is always ``{"decision_type": "reject"}``.
        """
        if not self._emit_interrupt:
            return {"decision_type": "reject"}

        # ``params`` (NOT ``args``) is what SurfSense's streaming
        # normalizer forwards. Other fields move into ``context``.
        payload = {
            "type": "permission_ask",
            "action": {"tool": tool_name, "params": args or {}},
            "context": {
                "patterns": patterns,
                "rules": [
                    {
                        "permission": r.permission,
                        "pattern": r.pattern,
                        "action": r.action,
                    }
                    for r in rules
                ],
                # Rules of thumb for the frontend: surface the patterns
                # the user can promote to "always" with a single reply.
                "always": patterns,
            },
        }
        # Open ``permission.asked`` + ``interrupt.raised`` OTel spans
        # (no-op when OTel is disabled) so dashboards can correlate
        # "we asked X" with "interrupt was actually delivered".
        with (
            ot.permission_asked_span(
                permission=tool_name,
                pattern=patterns[0] if patterns else None,
                extra={"permission.patterns": list(patterns)},
            ),
            ot.interrupt_span(interrupt_type="permission_ask"),
        ):
            decision = interrupt(payload)
        if isinstance(decision, dict):
            return decision
        # Tolerate a plain string reply ("once", "always", "reject")
        if isinstance(decision, str):
            return {"decision_type": decision}
        return {"decision_type": "reject"}

    def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
        """Promote ``always`` reply into runtime allow rules.

        Persistence to ``agent_permission_rules`` is done by the
        streaming layer (``stream_new_chat``) once it observes the
        ``always`` reply — the middleware just keeps an in-memory
        copy so subsequent calls in the same stream see the rule.
        """
        for pattern in patterns:
            self._runtime_ruleset.rules.append(
                Rule(permission=tool_name, pattern=pattern, action="allow")
            )

    # ------------------------------------------------------------------
    # Synthesizing deny -> ToolMessage
    # ------------------------------------------------------------------

    @staticmethod
    def _deny_message(
        tool_call: dict[str, Any],
        rule: Rule,
    ) -> ToolMessage:
        """Build the error ``ToolMessage`` that stands in for a denied call.

        The message carries a ``permission_denied`` :class:`StreamingError`
        under ``additional_kwargs["error"]`` and names the blocking rule.
        """
        err = StreamingError(
            code="permission_denied",
            retryable=False,
            suggestion=(
                f"rule permission={rule.permission!r} pattern={rule.pattern!r} "
                f"blocked this call"
            ),
        )
        return ToolMessage(
            content=(
                f"Permission denied: rule {rule.permission}/{rule.pattern} "
                f"blocked tool {tool_call.get('name')!r}."
            ),
            tool_call_id=tool_call.get("id") or "",
            name=tool_call.get("name"),
            status="error",
            additional_kwargs={"error": err.model_dump()},
        )

    # ------------------------------------------------------------------
    # The hook: aafter_model
    # ------------------------------------------------------------------

    def _process(
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Apply the rulesets to the tool calls of the latest ``AIMessage``.

        Returns:
            ``None`` when nothing changes (no tool calls, or every call was
            kept). Otherwise a state update whose ``messages`` holds the
            ``AIMessage`` with denied calls removed, followed by one
            synthetic error ``ToolMessage`` per denied call.

        Raises:
            CorrectedError: An ask was rejected with non-blank feedback.
            RejectedError: An ask was rejected without feedback, or the
                reply's ``decision_type`` was not recognised.
        """
        del runtime  # unused
        messages = state.get("messages") or []
        if not messages:
            return None
        last = messages[-1]
        if not isinstance(last, AIMessage) or not last.tool_calls:
            return None

        deny_messages: list[ToolMessage] = []
        kept_calls: list[dict[str, Any]] = []

        for raw in last.tool_calls:
            # Normalise to a plain dict so object-style calls are handled too.
            call = (
                dict(raw)
                if isinstance(raw, dict)
                else {
                    "name": getattr(raw, "name", None),
                    "args": getattr(raw, "args", {}),
                    "id": getattr(raw, "id", None),
                    "type": "tool_call",
                }
            )
            name = call.get("name") or ""
            args = call.get("args") or {}
            action, patterns, rules = self._evaluate(name, args)

            if action == "deny":
                # Find the deny rule for the suggestion text. The fallbacks
                # are resolved lazily so an empty ``rules`` list cannot
                # raise IndexError; a synthetic rule is used in that case.
                deny_rule = next((r for r in rules if r.action == "deny"), None)
                if deny_rule is None:
                    deny_rule = (
                        rules[0]
                        if rules
                        else Rule(
                            permission=name,
                            pattern=patterns[0] if patterns else "*",
                            action="deny",
                        )
                    )
                deny_messages.append(self._deny_message(call, deny_rule))
                continue

            if action == "ask":
                decision = self._raise_interrupt(
                    tool_name=name, args=args, patterns=patterns, rules=rules
                )
                kind = str(decision.get("decision_type") or "reject").lower()
                if kind == "once":
                    kept_calls.append(call)
                elif kind == "always":
                    self._persist_always(name, patterns)
                    kept_calls.append(call)
                elif kind == "reject":
                    feedback = decision.get("feedback")
                    if isinstance(feedback, str) and feedback.strip():
                        raise CorrectedError(feedback, tool=name)
                    raise RejectedError(
                        tool=name, pattern=patterns[0] if patterns else None
                    )
                else:
                    logger.warning(
                        "Unknown permission decision %r; treating as reject", kind
                    )
                    raise RejectedError(tool=name)
                continue

            # allow
            kept_calls.append(call)

        # Every call that is not denied is either kept or raises, so the
        # message only needs rewriting when at least one call was denied.
        if not deny_messages:
            return None

        updated = last.model_copy(update={"tool_calls": kept_calls})
        result_messages: list[Any] = [updated]
        result_messages.extend(deny_messages)
        return {"messages": result_messages}

    def after_model(  # type: ignore[override]
        self, state: AgentState, runtime: Runtime[ContextT]
    ) -> dict[str, Any] | None:
        """Sync hook: delegate to :meth:`_process`."""
        return self._process(state, runtime)

    async def aafter_model(  # type: ignore[override]
        self, state: AgentState, runtime: Runtime[ContextT]
    ) -> dict[str, Any] | None:
        """Async hook: delegate to the (synchronous) :meth:`_process`."""
        return self._process(state, runtime)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PatternResolver",
|
||||||
|
"PermissionMiddleware",
|
||||||
|
]
|
||||||
257
surfsense_backend/app/agents/new_chat/middleware/retry_after.py
Normal file
257
surfsense_backend/app/agents/new_chat/middleware/retry_after.py
Normal file
|
|
@ -0,0 +1,257 @@
|
||||||
|
"""
|
||||||
|
RetryAfterMiddleware — Header-aware retry with custom backoff and SSE eventing.
|
||||||
|
|
||||||
|
LangChain's :class:`ModelRetryMiddleware` retries on exceptions but ignores
|
||||||
|
the ``Retry-After`` HTTP header — it just runs its own exponential backoff.
|
||||||
|
That wastes time when a provider has explicitly told us how long to wait.
|
||||||
|
This middleware honors the header (mirroring OpenCode's
|
||||||
|
``packages/opencode/src/session/llm.ts`` retry pathway) and emits an SSE
|
||||||
|
event so the UI can show "rate-limited, retrying in Ns".
|
||||||
|
|
||||||
|
We can't subclass ``ModelRetryMiddleware`` cleanly because its loop calls a
|
||||||
|
module-level ``calculate_delay`` inline (no overridable
|
||||||
|
``_calculate_delay`` hook), so this is a standalone implementation.
|
||||||
|
|
||||||
|
Behaviour:
|
||||||
|
- Extracts ``Retry-After`` / ``retry-after-ms`` from
|
||||||
|
``litellm.exceptions.RateLimitError.response.headers`` (or any exception
|
||||||
|
exposing a similar shape).
|
||||||
|
- Sleeps ``max(exponential_backoff, header_delay)`` between retries.
|
||||||
|
- Returns ``False`` from ``retry_on`` for ``ContextWindowExceededError`` /
|
||||||
|
``ContextOverflowError`` so :class:`SurfSenseCompactionMiddleware` (or
|
||||||
|
the LangChain summarization fallback path) handles those instead.
|
||||||
|
- Emits ``surfsense.retrying`` via ``adispatch_custom_event`` on each retry
|
||||||
|
so ``stream_new_chat`` can forward it to clients as an SSE event.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware.types import (
|
||||||
|
AgentMiddleware,
|
||||||
|
AgentState,
|
||||||
|
ContextT,
|
||||||
|
ModelRequest,
|
||||||
|
ModelResponse,
|
||||||
|
ResponseT,
|
||||||
|
)
|
||||||
|
from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Names of exception classes for which a retry would not help — context
|
||||||
|
# overflow needs compaction, auth needs human intervention, etc. Detected
|
||||||
|
# by class-name substring so we don't have to import LiteLLM/Anthropic
|
||||||
|
# here (which would tie this module to optional deps).
|
||||||
|
_NON_RETRYABLE_NAME_HINTS: tuple[str, ...] = (
|
||||||
|
"ContextWindowExceeded",
|
||||||
|
"ContextOverflow",
|
||||||
|
"AuthenticationError",
|
||||||
|
"InvalidRequestError",
|
||||||
|
"PermissionDenied",
|
||||||
|
"InvalidApiKey",
|
||||||
|
"ContextLimit",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_non_retryable(exc: BaseException) -> bool:
    """Return True when the exception's class name contains a non-retryable hint.

    Matching is by substring against ``_NON_RETRYABLE_NAME_HINTS``.
    """
    class_name = type(exc).__name__
    for hint in _NON_RETRYABLE_NAME_HINTS:
        if hint in class_name:
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_retry_after_seconds(exc: BaseException) -> float | None:
|
||||||
|
"""Return seconds-to-wait suggested by the provider, if any.
|
||||||
|
|
||||||
|
Looks at ``exc.response.headers`` or ``exc.headers`` for the standard
|
||||||
|
HTTP ``Retry-After`` header (in seconds) or its millisecond cousin
|
||||||
|
``retry-after-ms`` (sometimes used by Anthropic / OpenAI). Falls back
|
||||||
|
to a regex on the exception message for shapes like
|
||||||
|
``"Please retry after 30s"``.
|
||||||
|
"""
|
||||||
|
headers: dict[str, Any] | None = None
|
||||||
|
response = getattr(exc, "response", None)
|
||||||
|
if response is not None:
|
||||||
|
headers = getattr(response, "headers", None)
|
||||||
|
if headers is None:
|
||||||
|
headers = getattr(exc, "headers", None)
|
||||||
|
|
||||||
|
if isinstance(headers, dict):
|
||||||
|
# Normalize keys to lowercase for case-insensitive matching
|
||||||
|
norm = {str(k).lower(): v for k, v in headers.items()}
|
||||||
|
ms = norm.get("retry-after-ms")
|
||||||
|
if ms is not None:
|
||||||
|
try:
|
||||||
|
return float(ms) / 1000.0
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
seconds = norm.get("retry-after")
|
||||||
|
if seconds is not None:
|
||||||
|
try:
|
||||||
|
return float(seconds)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Last resort: scan the message for "retry after Xs" or "X seconds"
|
||||||
|
msg = str(exc)
|
||||||
|
match = re.search(r"retry\s+after\s+([0-9]+(?:\.[0-9]+)?)", msg, re.IGNORECASE)
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
return float(match.group(1))
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _exponential_delay(
|
||||||
|
attempt: int,
|
||||||
|
*,
|
||||||
|
initial_delay: float,
|
||||||
|
backoff_factor: float,
|
||||||
|
max_delay: float,
|
||||||
|
jitter: bool,
|
||||||
|
) -> float:
|
||||||
|
"""Compute an exponential-backoff delay with optional ±25% jitter."""
|
||||||
|
delay = (
|
||||||
|
initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
|
||||||
|
)
|
||||||
|
delay = min(delay, max_delay)
|
||||||
|
if jitter and delay > 0:
|
||||||
|
delay *= 1 + random.uniform(-0.25, 0.25)
|
||||||
|
return max(delay, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
    """Retry middleware that honors provider-issued Retry-After hints.

    Drop-in replacement for :class:`langchain.agents.middleware.ModelRetryMiddleware`
    when working with LiteLLM/Anthropic/OpenAI providers that surface
    rate-limit hints in headers. Emits ``surfsense.retrying`` SSE events
    before every retry so the UI can show a friendly "rate limited, retrying
    in Xs" indicator. Dispatching that event is best effort: a failure is
    logged at debug level and never aborts the retry.

    Args:
        max_retries: Maximum retries after the initial attempt (default 3).
            Must be ``>= 0``; ``0`` means the call is attempted exactly once.
        initial_delay: Initial backoff delay in seconds.
        backoff_factor: Exponential growth factor for backoff.
        max_delay: Cap on the exponential backoff component, in seconds.
            A larger provider ``Retry-After`` hint is not capped by this
            class (see :meth:`_delay_for_attempt`).
        jitter: Whether to add ±25% jitter.
        retry_on: Optional callable that returns True for retryable
            exceptions. The default retries everything except known
            non-retryable classes (context overflow, auth, etc.).

    Raises:
        ValueError: If ``max_retries`` is negative.
    """

    def __init__(
        self,
        *,
        max_retries: int = 3,
        initial_delay: float = 1.0,
        backoff_factor: float = 2.0,
        max_delay: float = 60.0,
        jitter: bool = True,
        retry_on: Callable[[BaseException], bool] | None = None,
    ) -> None:
        super().__init__()
        # A negative value would make ``range(max_retries + 1)`` empty, so the
        # model would never be called and the "unreachable" RuntimeError at the
        # end of the retry loops would fire. Fail fast at construction instead.
        if max_retries < 0:
            raise ValueError("max_retries must be >= 0")
        self.max_retries = max_retries
        self.initial_delay = initial_delay
        self.backoff_factor = backoff_factor
        self.max_delay = max_delay
        self.jitter = jitter
        self._retry_on: Callable[[BaseException], bool] = retry_on or (
            lambda exc: not _is_non_retryable(exc)
        )

    def _should_retry(self, exc: BaseException) -> bool:
        """Return whether ``exc`` is retryable according to the ``retry_on`` predicate.

        A predicate that itself raises is treated as "do not retry" so a buggy
        callback can never mask the original model error.
        """
        try:
            return bool(self._retry_on(exc))
        except Exception:
            logger.exception("retry_on callable raised; defaulting to False")
            return False

    def _delay_for_attempt(self, attempt: int, exc: BaseException) -> float:
        """Return the number of seconds to wait before retrying after ``exc``.

        The result is the larger of the exponential backoff for ``attempt``
        (0-based) and the Retry-After hint extracted from ``exc``; a missing
        hint counts as ``0.0``.
        """
        backoff = _exponential_delay(
            attempt,
            initial_delay=self.initial_delay,
            backoff_factor=self.backoff_factor,
            max_delay=self.max_delay,
            jitter=self.jitter,
        )
        header = _extract_retry_after_seconds(exc) or 0.0
        return max(backoff, header)

    def _retry_event_payload(
        self, attempt: int, delay: float, exc: BaseException
    ) -> dict[str, object]:
        """Build the ``surfsense.retrying`` event body shared by the sync and async paths."""
        return {
            "attempt": attempt + 1,  # 1-based for display
            "max_retries": self.max_retries,
            "delay_ms": int(delay * 1000),
            "reason": type(exc).__name__,
        }

    def wrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
    ) -> ModelResponse[ResponseT] | AIMessage:
        """Call ``handler`` and retry on retryable exceptions (blocking variant).

        Re-raises the original exception when it is not retryable or when the
        retry budget is exhausted. Sleeps with :func:`time.sleep` between
        attempts.
        """
        for attempt in range(self.max_retries + 1):
            try:
                return handler(request)
            except Exception as exc:
                if not self._should_retry(exc) or attempt >= self.max_retries:
                    raise
                delay = self._delay_for_attempt(attempt, exc)
                try:
                    # Payload construction stays inside the ``try`` so that a
                    # failure while building it is suppressed like a dispatch
                    # failure.
                    dispatch_custom_event(
                        "surfsense.retrying",
                        self._retry_event_payload(attempt, delay, exc),
                    )
                except Exception:
                    logger.debug(
                        "dispatch_custom_event failed; suppressed", exc_info=True
                    )
                if delay > 0:
                    time.sleep(delay)
        # Unreachable: the final iteration either returns or re-raises.
        raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution")

    async def awrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[ContextT],
        handler: Callable[
            [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
        ],
    ) -> ModelResponse[ResponseT] | AIMessage:
        """Async counterpart of :meth:`wrap_model_call`.

        Identical retry policy; waits with :func:`asyncio.sleep` so the event
        loop is not blocked between attempts.
        """
        for attempt in range(self.max_retries + 1):
            try:
                return await handler(request)
            except Exception as exc:
                if not self._should_retry(exc) or attempt >= self.max_retries:
                    raise
                delay = self._delay_for_attempt(attempt, exc)
                try:
                    await adispatch_custom_event(
                        "surfsense.retrying",
                        self._retry_event_payload(attempt, delay, exc),
                    )
                except Exception:
                    logger.debug(
                        "adispatch_custom_event failed; suppressed", exc_info=True
                    )
                if delay > 0:
                    await asyncio.sleep(delay)
        # Unreachable: the final iteration either returns or re-raises.
        raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution")
|
||||||
|
|
||||||
|
|
||||||
|
# Names exported by ``from <module> import *``. The two underscore-prefixed
# helpers are listed explicitly because a star-import would otherwise skip
# them.
__all__ = [
    "RetryAfterMiddleware",
    "_extract_retry_after_seconds",
    "_is_non_retryable",
]
|
||||||
|
|
@ -1,123 +0,0 @@
|
||||||
"""Safe wrapper around deepagents' SummarizationMiddleware.
|
|
||||||
|
|
||||||
Upstream issue
|
|
||||||
--------------
|
|
||||||
`deepagents.middleware.summarization.SummarizationMiddleware._aoffload_to_backend`
|
|
||||||
(and its sync counterpart) call
|
|
||||||
``get_buffer_string(filtered_messages)`` before writing the evicted history
|
|
||||||
to the backend file. In recent ``langchain-core`` versions, ``get_buffer_string``
|
|
||||||
accesses ``m.text`` which iterates ``self.content`` — this raises
|
|
||||||
``TypeError: 'NoneType' object is not iterable`` whenever an ``AIMessage``
|
|
||||||
has ``content=None`` (common when a model returns *only* tool_calls, seen
|
|
||||||
frequently with Azure OpenAI ``gpt-5.x`` responses streamed through
|
|
||||||
LiteLLM).
|
|
||||||
|
|
||||||
The exception aborts the whole agent turn, so the user just sees "Error during
|
|
||||||
chat" with no assistant response.
|
|
||||||
|
|
||||||
Fix
|
|
||||||
---
|
|
||||||
We subclass ``SummarizationMiddleware`` and override
|
|
||||||
``_filter_summary_messages`` — the only call site that feeds messages into
|
|
||||||
``get_buffer_string`` — to return *copies* of messages whose ``content`` is
|
|
||||||
``None`` with ``content=""``. The originals flowing through the rest of the
|
|
||||||
agent state are untouched.
|
|
||||||
|
|
||||||
We also expose a drop-in ``create_safe_summarization_middleware`` factory
|
|
||||||
that mirrors ``deepagents.middleware.summarization.create_summarization_middleware``
|
|
||||||
but instantiates our safe subclass.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from deepagents.middleware.summarization import (
|
|
||||||
SummarizationMiddleware,
|
|
||||||
compute_summarization_defaults,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from deepagents.backends.protocol import BACKEND_TYPES
|
|
||||||
from langchain_core.language_models import BaseChatModel
|
|
||||||
from langchain_core.messages import AnyMessage
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
|
|
||||||
"""Return ``msg`` with ``content`` coerced to a non-``None`` value.
|
|
||||||
|
|
||||||
``get_buffer_string`` reads ``m.text`` which iterates ``self.content``;
|
|
||||||
when a provider streams back an ``AIMessage`` with only tool_calls and
|
|
||||||
no text, ``content`` can be ``None`` and the iteration explodes. We
|
|
||||||
replace ``None`` with an empty string so downstream consumers that only
|
|
||||||
care about text see an empty body.
|
|
||||||
|
|
||||||
The original message is left untouched — we return a copy via
|
|
||||||
pydantic's ``model_copy`` when available, otherwise we fall back to
|
|
||||||
re-setting the attribute on a shallow copy.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if getattr(msg, "content", "not-missing") is not None:
|
|
||||||
return msg
|
|
||||||
|
|
||||||
try:
|
|
||||||
return msg.model_copy(update={"content": ""})
|
|
||||||
except AttributeError:
|
|
||||||
import copy
|
|
||||||
|
|
||||||
new_msg = copy.copy(msg)
|
|
||||||
try:
|
|
||||||
new_msg.content = ""
|
|
||||||
except Exception: # pragma: no cover - defensive
|
|
||||||
logger.debug(
|
|
||||||
"Could not sanitize content=None on message of type %s",
|
|
||||||
type(msg).__name__,
|
|
||||||
)
|
|
||||||
return msg
|
|
||||||
return new_msg
|
|
||||||
|
|
||||||
|
|
||||||
class SafeSummarizationMiddleware(SummarizationMiddleware):
    """``SummarizationMiddleware`` variant that copes with ``content=None`` messages.

    The only override is ``_filter_summary_messages``. Per the module notes,
    that helper feeds both the sync and the async offload paths right before
    ``get_buffer_string`` runs, so normalising its output covers both paths
    without duplicating the upstream offload implementations.
    """

    def _filter_summary_messages(self, messages: list[AnyMessage]) -> list[AnyMessage]:
        """Apply the upstream filter, then replace ``None`` content with ``""`` on copies."""
        kept = super()._filter_summary_messages(messages)
        return list(map(_sanitize_message_content, kept))
|
|
||||||
|
|
||||||
|
|
||||||
def create_safe_summarization_middleware(
    model: BaseChatModel,
    backend: BACKEND_TYPES,
) -> SafeSummarizationMiddleware:
    """Build a ``SafeSummarizationMiddleware`` configured with deepagents' defaults.

    Intended as a drop-in replacement for ``create_summarization_middleware``:
    the ``trigger``, ``keep`` and ``truncate_args_settings`` values come from
    ``compute_summarization_defaults(model)``, while the returned instance is
    the safe subclass so the ``content=None`` crash in ``get_buffer_string``
    is avoided.
    """
    computed = compute_summarization_defaults(model)
    settings = {
        "model": model,
        "backend": backend,
        "trigger": computed["trigger"],
        "keep": computed["keep"],
        "trim_tokens_to_summarize": None,
        "truncate_args_settings": computed["truncate_args_settings"],
    }
    return SafeSummarizationMiddleware(**settings)
|
|
||||||
|
|
||||||
|
|
||||||
# Names exported by ``from <module> import *``: the safe middleware subclass
# and the factory that builds it.
__all__ = [
    "SafeSummarizationMiddleware",
    "create_safe_summarization_middleware",
]
|
|
||||||
|
|
@ -0,0 +1,337 @@
|
||||||
|
"""Skills backends for SurfSense.
|
||||||
|
|
||||||
|
Implements two minimal :class:`deepagents.backends.protocol.BackendProtocol`
|
||||||
|
subclasses tailored for use with :class:`deepagents.middleware.skills.SkillsMiddleware`.
|
||||||
|
|
||||||
|
The middleware only needs four methods to load skills from a backend:
|
||||||
|
|
||||||
|
* ``ls_info`` / ``als_info`` — list directories under a source path.
|
||||||
|
* ``download_files`` / ``adownload_files`` — fetch ``SKILL.md`` bytes.
|
||||||
|
|
||||||
|
Other ``BackendProtocol`` methods (``read``/``write``/``edit``/``grep_raw`` …)
|
||||||
|
default to ``NotImplementedError`` from the base class. They are never reached
|
||||||
|
by the skills middleware because skill content is rendered into the system
|
||||||
|
prompt at agent build time, not edited at runtime.
|
||||||
|
|
||||||
|
Two backends are provided:
|
||||||
|
|
||||||
|
* :class:`BuiltinSkillsBackend` — disk-backed read of bundled skills from
|
||||||
|
``app/agents/new_chat/skills/builtin/``.
|
||||||
|
* :class:`SearchSpaceSkillsBackend` — a thin read-only wrapper over
|
||||||
|
:class:`KBPostgresBackend` that filters notes under the privileged folder
|
||||||
|
``/documents/_skills/``.
|
||||||
|
|
||||||
|
Both backends are intentionally read-only: skill authoring happens out of band
|
||||||
|
(via filesystem or a search-space-admin route), so we never expose
|
||||||
|
``write`` / ``edit`` / ``upload_files``. The base class' ``NotImplementedError``
|
||||||
|
gives a clean failure mode if anything tries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import replace
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from deepagents.backends.composite import CompositeBackend
|
||||||
|
from deepagents.backends.protocol import (
|
||||||
|
BackendProtocol,
|
||||||
|
FileDownloadResponse,
|
||||||
|
FileInfo,
|
||||||
|
)
|
||||||
|
from deepagents.backends.state import StateBackend
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain.tools import ToolRuntime
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Limit per Agent Skills spec; matches deepagents.middleware.skills.MAX_SKILL_FILE_SIZE.
|
||||||
|
_MAX_SKILL_FILE_SIZE = 10 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
def _default_builtin_root() -> Path:
|
||||||
|
"""Return the absolute path to the bundled builtin skills directory.
|
||||||
|
|
||||||
|
Located at ``app/agents/new_chat/skills/builtin/`` relative to this module.
|
||||||
|
"""
|
||||||
|
return (Path(__file__).resolve().parent.parent / "skills" / "builtin").resolve()
|
||||||
|
|
||||||
|
|
||||||
|
class BuiltinSkillsBackend(BackendProtocol):
    """Read-only disk-backed skills source.

    Maps a virtual ``/skills/builtin/`` namespace onto a directory on local disk,
    where each skill is its own subdirectory containing a ``SKILL.md`` file::

        <root>/<skill-name>/SKILL.md

    The middleware calls :meth:`als_info` with the source path and expects a
    ``list[FileInfo]`` whose ``is_dir=True`` entries are descended into. Then it
    calls :meth:`adownload_files` with the synthesized ``SKILL.md`` paths and
    parses YAML frontmatter from the returned ``content`` bytes.

    Mounting under :class:`~deepagents.backends.composite.CompositeBackend` at
    prefix ``/skills/builtin/`` means the middleware can issue paths like
    ``/skills/builtin/kb-research/SKILL.md`` which the composite strips down to
    ``/kb-research/SKILL.md`` before forwarding here. We treat any leading
    slash as anchoring at :attr:`root`.

    Only the synchronous ``ls_info`` / ``download_files`` are implemented in
    this class. Neither method raises for a bad path: ``ls_info`` returns an
    empty list and ``download_files`` returns a per-path error response.
    """

    def __init__(self, root: Path | str | None = None) -> None:
        """Anchor the backend at ``root`` (defaults to the bundled skills directory).

        A missing root is not an error; it is logged and listings are empty.
        """
        self.root: Path = Path(root).resolve() if root else _default_builtin_root()
        if not self.root.exists():
            logger.info(
                "BuiltinSkillsBackend root %s does not exist; skills will be empty.",
                self.root,
            )

    def _resolve(self, path: str) -> Path:
        """Resolve a virtual posix path under :attr:`root`, refusing escapes.

        Raises:
            ValueError: If the resolved path lies outside :attr:`root`.
        """
        bare = path.lstrip("/")
        candidate = (self.root / bare).resolve() if bare else self.root
        # Refuse symlink/.. traversal that escapes the root.
        try:
            candidate.relative_to(self.root)
        except ValueError as exc:
            raise ValueError(f"path {path!r} escapes builtin skills root") from exc
        return candidate

    def ls_info(self, path: str) -> list[FileInfo]:
        """List the direct children of virtual directory ``path``.

        Returns one ``FileInfo`` per child, sorted by path, with virtual paths
        anchored at ``/``. Files also carry ``size`` when it can be read.
        Returns ``[]`` when the path escapes the root, does not exist, is not
        a directory, or cannot be listed.
        """
        try:
            target = self._resolve(path)
        except ValueError as exc:
            logger.warning("BuiltinSkillsBackend.ls_info refused: %s", exc)
            return []
        if not target.exists() or not target.is_dir():
            return []

        # The directory can disappear or be unreadable between the checks
        # above and the listing; degrade to "no skills" instead of raising.
        try:
            children = sorted(target.iterdir())
        except OSError as exc:
            logger.warning(
                "BuiltinSkillsBackend.ls_info could not list %s: %s", target, exc
            )
            return []

        infos: list[FileInfo] = []
        # Build virtual paths anchored at "/" because CompositeBackend already
        # stripped the route prefix before calling us.
        target_virtual = (
            "/"
            if target == self.root
            else ("/" + str(target.relative_to(self.root)).replace("\\", "/"))
        )
        for child in children:
            child_virtual = (
                target_virtual.rstrip("/") + "/" + child.name
                if target_virtual != "/"
                else "/" + child.name
            )
            info: FileInfo = {
                "path": child_virtual,
                "is_dir": child.is_dir(),
            }
            if child.is_file():
                with contextlib.suppress(OSError):  # pragma: no cover - defensive
                    info["size"] = child.stat().st_size
            infos.append(info)
        return infos

    def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
        """Read each virtual path and return one response per input, in order.

        Failures are reported through the response's ``error`` field
        (``invalid_path``, ``file_not_found``, ``is_directory``,
        ``permission_denied``) rather than raised. Files larger than
        ``_MAX_SKILL_FILE_SIZE`` are truncated to that many bytes.
        """
        responses: list[FileDownloadResponse] = []
        for p in paths:
            try:
                target = self._resolve(p)
            except ValueError:
                responses.append(FileDownloadResponse(path=p, error="invalid_path"))
                continue
            if not target.exists():
                responses.append(FileDownloadResponse(path=p, error="file_not_found"))
                continue
            if target.is_dir():
                responses.append(FileDownloadResponse(path=p, error="is_directory"))
                continue
            try:
                # Hard cap to avoid loading rogue mega-files into memory.
                size = target.stat().st_size
                if size > _MAX_SKILL_FILE_SIZE:
                    logger.warning(
                        "Builtin skill file %s exceeds %d bytes; truncating.",
                        target,
                        _MAX_SKILL_FILE_SIZE,
                    )
                    with target.open("rb") as fh:
                        content = fh.read(_MAX_SKILL_FILE_SIZE)
                else:
                    content = target.read_bytes()
            except PermissionError:
                responses.append(
                    FileDownloadResponse(path=p, error="permission_denied")
                )
                continue
            except OSError as exc:  # pragma: no cover - defensive
                logger.warning("Builtin skill read failed %s: %s", target, exc)
                responses.append(FileDownloadResponse(path=p, error="file_not_found"))
                continue
            responses.append(FileDownloadResponse(path=p, content=content, error=None))
        return responses
|
||||||
|
|
||||||
|
|
||||||
|
class SearchSpaceSkillsBackend(BackendProtocol):
    """Read-only view of search-space-authored skills.

    Wraps a :class:`KBPostgresBackend` and only ever reads under the privileged
    folder ``/documents/_skills/`` (configurable). The folder is intended to be
    writable only by search-space admins; this backend never writes.

    The skills middleware expects a layout like::

        /<source_root>/<skill-name>/SKILL.md

    But the KB stores documents like ``/documents/_skills/<name>/SKILL.md``.
    We expose the inner namespace by remapping each path. When mounted under
    :class:`CompositeBackend` at prefix ``/skills/space/`` the paths the
    middleware sees become ``/skills/space/<name>/SKILL.md``; the composite
    strips ``/skills/space/`` and hands us ``/<name>/SKILL.md``, which we
    rewrite to ``/documents/_skills/<name>/SKILL.md`` before forwarding to the
    KB.

    No new database table is needed: the privileged folder convention is
    enforced server-side outside of this class. Write/edit methods are
    deliberately not implemented here, so such attempts hit the base class'
    ``NotImplementedError`` instead of being silently accepted.
    """

    DEFAULT_KB_ROOT: str = "/documents/_skills"

    def __init__(
        self,
        kb_backend: KBPostgresBackend,
        *,
        kb_root: str = DEFAULT_KB_ROOT,
    ) -> None:
        """Wrap ``kb_backend`` and expose only the subtree rooted at ``kb_root``."""
        self._kb = kb_backend
        # Normalize trailing slash off so we can join cleanly.
        self._kb_root = kb_root.rstrip("/") or "/"

    def _is_under_root(self, kb_path: str) -> bool:
        """Return True when ``kb_path`` is the KB root itself or lies beneath it.

        The comparison respects path-segment boundaries, so a sibling such as
        ``/documents/_skills_old/x`` is not treated as part of
        ``/documents/_skills``.
        """
        if self._kb_root == "/":
            return kb_path.startswith("/")
        return kb_path == self._kb_root or kb_path.startswith(self._kb_root + "/")

    def _to_kb(self, path: str) -> str:
        """Rewrite a virtual path into the underlying KB namespace."""
        bare = path.lstrip("/")
        if not bare:
            return self._kb_root
        if self._kb_root == "/":
            # Avoid producing "//name" when the KB root is the filesystem root.
            return "/" + bare
        return f"{self._kb_root}/{bare}"

    def _from_kb(self, kb_path: str) -> str:
        """Rewrite a KB path back into our virtual namespace.

        Paths outside the KB root are returned unchanged.
        """
        if not self._is_under_root(kb_path):
            return kb_path  # pragma: no cover - defensive
        rel = kb_path if self._kb_root == "/" else kb_path[len(self._kb_root) :]
        return rel if rel.startswith("/") else "/" + rel

    def ls_info(self, path: str) -> list[FileInfo]:
        """Unsupported: this backend only implements the async API.

        The stub exists so the method is explicitly defined on this class;
        the skills middleware calls :meth:`als_info` instead.
        """
        raise NotImplementedError("SearchSpaceSkillsBackend is async-only")

    async def als_info(self, path: str) -> list[FileInfo]:
        """List entries under virtual ``path`` with paths remapped out of the KB namespace.

        Entries the KB reports outside the configured root are dropped. Any
        failure from the KB listing is logged and yields an empty list.
        """
        kb_path = self._to_kb(path)
        try:
            infos = await self._kb.als_info(kb_path)
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning("SearchSpaceSkillsBackend.als_info failed: %s", exc)
            return []
        remapped: list[FileInfo] = []
        for info in infos:
            kb_p = info.get("path", "")
            if not self._is_under_root(kb_p):
                continue
            remapped.append({**info, "path": self._from_kb(kb_p)})
        return remapped

    def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
        """Unsupported: use :meth:`adownload_files`."""
        raise NotImplementedError("SearchSpaceSkillsBackend is async-only")

    async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
        """Download ``paths`` from the KB and report them under their virtual paths.

        Raises:
            ValueError: If the KB returns a different number of responses than
                paths requested (``zip(..., strict=True)``).
        """
        kb_paths = [self._to_kb(p) for p in paths]
        responses = await self._kb.adownload_files(kb_paths)
        # Re-map response paths back to the virtual namespace so the middleware
        # correlates them to the input list correctly.
        remapped: list[FileDownloadResponse] = []
        for original, resp in zip(paths, responses, strict=True):
            remapped.append(replace(resp, path=original))
        return remapped
|
||||||
|
|
||||||
|
|
||||||
|
# Route prefixes under which ``build_skills_backend_factory`` mounts each
# skills source on the ``CompositeBackend``: bundled skills first, then the
# search-space-authored ones.
SKILLS_BUILTIN_PREFIX = "/skills/builtin/"
SKILLS_SPACE_PREFIX = "/skills/space/"
|
||||||
|
|
||||||
|
|
||||||
|
def build_skills_backend_factory(
    *,
    builtin_root: Path | str | None = None,
    search_space_id: int | None = None,
) -> Callable[[ToolRuntime], BackendProtocol]:
    """Create a per-runtime factory for the skills :class:`CompositeBackend`.

    With a ``search_space_id``, every composite the factory produces routes
    ``/skills/space/`` to a :class:`SearchSpaceSkillsBackend` wrapping a
    freshly constructed :class:`KBPostgresBackend`, in addition to the
    builtin route. This mirrors how :func:`build_backend_resolver` builds the
    main filesystem backend.

    Without one (``None``, e.g. desktop-local mode or unit tests), only the
    bundled :class:`BuiltinSkillsBackend` route is mounted.

    A factory is returned instead of a ready-made backend on purpose: the KB
    backend depends on per-call ``ToolRuntime`` state (``staged_dirs``,
    ``files`` cache, runtime config), so one shared instance cannot serve
    several concurrent agent runs. The single ``BuiltinSkillsBackend`` is
    shared by all composites the factory produces.
    """
    builtin = BuiltinSkillsBackend(builtin_root)

    def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol:
        # Default StateBackend is intentionally inert: any path outside the
        # ``/skills/builtin/`` route resolves to an empty per-runtime state
        # so the SkillsMiddleware can iterate sources without raising.
        only_builtin = {SKILLS_BUILTIN_PREFIX: builtin}
        return CompositeBackend(default=StateBackend(runtime), routes=only_builtin)

    def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol:
        # Imported lazily to avoid a hard dependency at module import time:
        # ``KBPostgresBackend`` pulls in DB models, which are unnecessary for
        # the unit-tested builtin path.
        from app.agents.new_chat.middleware.kb_postgres_backend import (
            KBPostgresBackend,
        )

        space = SearchSpaceSkillsBackend(KBPostgresBackend(search_space_id, runtime))
        both_routes = {
            SKILLS_BUILTIN_PREFIX: builtin,
            SKILLS_SPACE_PREFIX: space,
        }
        return CompositeBackend(default=StateBackend(runtime), routes=both_routes)

    if search_space_id is None:
        return _factory_builtin_only
    return _factory_with_space
|
||||||
|
|
||||||
|
|
||||||
|
def default_skills_sources() -> list[str]:
    """List the SkillsMiddleware source prefixes in canonical order: built-in, then space."""
    ordered_sources = (SKILLS_BUILTIN_PREFIX, SKILLS_SPACE_PREFIX)
    return list(ordered_sources)
|
||||||
|
|
||||||
|
|
||||||
|
# Names exported by ``from <module> import *``: the two route-prefix
# constants, both backend classes, and the factory/source-list helpers.
__all__ = [
    "SKILLS_BUILTIN_PREFIX",
    "SKILLS_SPACE_PREFIX",
    "BuiltinSkillsBackend",
    "SearchSpaceSkillsBackend",
    "build_skills_backend_factory",
    "default_skills_sources",
]
|
||||||
|
|
@ -0,0 +1,193 @@
|
||||||
|
"""
|
||||||
|
ToolCallNameRepairMiddleware — two-stage tool-name repair.
|
||||||
|
|
||||||
|
Operation:
|
||||||
|
1. **Stage 1 — lowercase repair:** if a tool call's ``name`` is not in
|
||||||
|
the registry but ``name.lower()`` is, rewrite in place. Catches
|
||||||
|
models that emit ``Search`` instead of ``search``.
|
||||||
|
2. **Stage 2 — invalid fallback:** if still unmatched, rewrite the call
|
||||||
|
to ``invalid`` with ``args={"tool": original_name, "error": <error>}``
|
||||||
|
so the registered :func:`invalid_tool` returns the error to the model
|
||||||
|
for self-correction.
|
||||||
|
|
||||||
|
Ported from OpenCode's ``packages/opencode/src/session/llm.ts:339-358``
|
||||||
|
+ ``packages/opencode/src/tool/invalid.ts``. LangChain has no equivalent:
|
||||||
|
:class:`deepagents.middleware.PatchToolCallsMiddleware` patches
|
||||||
|
*dangling* tool calls (no matching ToolMessage) but does nothing about
|
||||||
|
wrong names, and the model framework's default behavior on an unknown
|
||||||
|
name is to crash the turn rather than route to a self-correction
|
||||||
|
fallback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import difflib
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware.types import (
|
||||||
|
AgentMiddleware,
|
||||||
|
AgentState,
|
||||||
|
ContextT,
|
||||||
|
ResponseT,
|
||||||
|
)
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_existing_tool_call(call: Any) -> dict[str, Any]:
|
||||||
|
"""Normalize a tool call entry to a mutable dict."""
|
||||||
|
if isinstance(call, dict):
|
||||||
|
return dict(call)
|
||||||
|
return {
|
||||||
|
"name": getattr(call, "name", None),
|
||||||
|
"args": getattr(call, "args", {}),
|
||||||
|
"id": getattr(call, "id", None),
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallNameRepairMiddleware(
    AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
    """Tool-name repair on the most recent ``AIMessage``.

    Each tool call whose ``name`` is not registered goes through, in order:
    a lowercase repair, an optional fuzzy match, and finally a rewrite to
    the ``invalid`` tool so the model can self-correct.

    Args:
        registered_tool_names: Set of canonically-registered tool names.
            ``invalid`` should be in this set so the fallback dispatches.
        fuzzy_match_threshold: ``difflib`` ratio (0-1) used as the cutoff
            for the fuzzy-match step that runs *between* lowercase and
            invalid. Defaults to ``0.85``. Pass ``None`` to disable fuzzy
            matching and avoid silent rewrites.
    """

    def __init__(
        self,
        *,
        registered_tool_names: set[str],
        fuzzy_match_threshold: float | None = 0.85,
    ) -> None:
        super().__init__()
        # Copy so later mutation of the caller's set cannot affect us.
        self._registered = set(registered_tool_names)
        self._registered_lower = {name.lower(): name for name in self._registered}
        self._fuzzy_threshold = fuzzy_match_threshold
        self.tools = []

    def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]:
        """Allow runtime overrides to expand the set (e.g. dynamic MCP tools).

        Reads ``runtime.context.registered_tool_names``; when it is a set,
        frozenset, list or tuple its members are added to the static set.
        Any other value (including a missing attribute) is ignored.
        """
        ctx_tools = getattr(runtime.context, "registered_tool_names", None)
        if isinstance(ctx_tools, (set, frozenset, list, tuple)):
            return self._registered | set(ctx_tools)
        return self._registered

    @staticmethod
    def _tag_repair(call: dict[str, Any], label: str) -> None:
        """Record ``label`` under ``call["response_metadata"]["repair"]``.

        An existing ``repair`` entry is kept (first label wins). The metadata
        dict is copied before being modified.
        """
        metadata = dict(call.get("response_metadata") or {})
        metadata.setdefault("repair", label)
        call["response_metadata"] = metadata

    def _repair_one(
        self,
        call: dict[str, Any],
        registered: set[str],
    ) -> dict[str, Any]:
        """Repair the ``name`` of one tool call in place and return it.

        ``call`` is returned untouched when its name is not a string or is
        already registered. If no repair applies and the ``invalid`` tool is
        not registered, a warning is logged and ``call`` is left unchanged.
        """
        name = call.get("name")
        if not isinstance(name, str):
            return call

        if name in registered:
            return call

        # Stage 1 — lowercase
        lowered = name.lower()
        if lowered in registered:
            call["name"] = lowered
            self._tag_repair(call, "lowercase")
            return call

        # Optional fuzzy step (enabled unless the threshold is None).
        if self._fuzzy_threshold is not None:
            # Never fuzzy-match onto the ``invalid`` sentinel itself: that
            # would forward the model's raw args to the invalid tool instead
            # of the {"tool", "error"} payload built in Stage 2.
            candidates = [
                candidate
                for candidate in registered
                if candidate != INVALID_TOOL_NAME
            ]
            close = difflib.get_close_matches(
                name, candidates, n=1, cutoff=self._fuzzy_threshold
            )
            if close:
                call["name"] = close[0]
                self._tag_repair(call, f"fuzzy:{name}->{close[0]}")
                return call

        # Stage 2 — invalid fallback
        if INVALID_TOOL_NAME in registered:
            original_args = call.get("args") or {}
            error_msg = (
                f"Tool name '{name}' is not registered. "
                f"Original arguments were: {original_args!r}."
            )
            call["name"] = INVALID_TOOL_NAME
            call["args"] = {"tool": name, "error": error_msg}
            self._tag_repair(call, f"invalid_fallback:{name}")
        else:
            logger.warning(
                "Could not repair unknown tool call %r; 'invalid' tool not registered",
                name,
            )
        return call

    def _maybe_repair(
        self,
        message: AIMessage,
        registered: set[str],
    ) -> AIMessage | None:
        """Return a copy of ``message`` with repaired tool calls, or ``None``.

        ``None`` means nothing needed changing: either the message has no
        tool calls, or no call's name/args were altered by the repair.
        """
        if not message.tool_calls:
            return None

        new_calls: list[dict[str, Any]] = []
        any_changed = False
        for raw in message.tool_calls:
            # Work on a copy so the original message's calls stay untouched.
            call = _coerce_existing_tool_call(raw)
            before = (call.get("name"), call.get("args"))
            repaired = self._repair_one(call, registered)
            after = (repaired.get("name"), repaired.get("args"))
            if before != after:
                any_changed = True
            new_calls.append(repaired)

        if not any_changed:
            return None

        return message.model_copy(update={"tool_calls": new_calls})

    def after_model(  # type: ignore[override]
        self,
        state: AgentState[ResponseT],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Repair tool-call names on the last message if it is an ``AIMessage``.

        Returns ``{"messages": [repaired]}`` when a repair happened and
        ``None`` otherwise.
        """
        messages = state.get("messages") or []
        if not messages:
            return None
        last = messages[-1]
        if not isinstance(last, AIMessage):
            return None

        registered = self._registered_for_runtime(runtime)
        repaired = self._maybe_repair(last, registered)
        if repaired is None:
            return None
        return {"messages": [repaired]}

    async def aafter_model(  # type: ignore[override]
        self,
        state: AgentState[ResponseT],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Async entry point; delegates to :meth:`after_model`."""
        return self.after_model(state, runtime)
|
||||||
|
|
||||||
|
|
||||||
|
# Names exported by ``from <module> import *``: only the middleware class.
__all__ = [
    "ToolCallNameRepairMiddleware",
]
|
||||||
351
surfsense_backend/app/agents/new_chat/path_resolver.py
Normal file
351
surfsense_backend/app/agents/new_chat/path_resolver.py
Normal file
|
|
@ -0,0 +1,351 @@
|
||||||
|
"""Canonical virtual-path resolver for SurfSense knowledge-base documents.
|
||||||
|
|
||||||
|
This module is the single source of truth for mapping ``Document`` rows to
|
||||||
|
virtual paths under ``/documents/`` and back. It is used by:
|
||||||
|
|
||||||
|
* :class:`KnowledgeTreeMiddleware` (rendering the workspace tree)
|
||||||
|
* :class:`KnowledgePriorityMiddleware` (computing priority paths)
|
||||||
|
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / move operations)
|
||||||
|
* :class:`KnowledgeBasePersistenceMiddleware` (resolving moves and creates)
|
||||||
|
|
||||||
|
Centralising the logic ensures that title-collision suffixes, folder paths,
|
||||||
|
and ``unique_identifier_hash`` lookups never drift between renders and
|
||||||
|
commits.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import Document, DocumentType, Folder
|
||||||
|
from app.utils.document_converters import generate_unique_identifier_hash
|
||||||
|
|
||||||
|
DOCUMENTS_ROOT = "/documents"
|
||||||
|
"""Root virtual folder for all KB documents."""
|
||||||
|
|
||||||
|
# Runs of characters that are replaced by a single underscore.
_INVALID_FILENAME_CHARS = re.compile(r"[\\/:*?\"<>|]+")
# Any run of whitespace; collapsed to one space.
_WHITESPACE_RUN = re.compile(r"\s+")


def safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
    """Turn arbitrary text into a sanitized filename ending in ``.xml``.

    Steps, in order: replace runs of invalid characters with ``_``, strip the
    ends, collapse whitespace runs to one space, substitute ``fallback`` when
    nothing is left, cap the name at 180 characters, and finally append
    ``.xml`` unless the name already ends with it (case-insensitive).

    Args:
        value: Raw text, typically a document title.
        fallback: Name used when sanitizing leaves an empty string.

    Returns:
        The sanitized filename; at most 184 characters including the
        extension that may be appended after truncation.
    """
    stripped = _INVALID_FILENAME_CHARS.sub("_", value).strip()
    candidate = _WHITESPACE_RUN.sub(" ", stripped) or fallback
    if len(candidate) > 180:
        # Truncation can expose trailing whitespace; trim it again.
        candidate = candidate[:180].rstrip()
    already_xml = candidate.lower().endswith(".xml")
    return candidate if already_xml else f"{candidate}.xml"
|
||||||
|
|
||||||
|
|
||||||
|
def safe_folder_segment(value: str, *, fallback: str = "folder") -> str:
    """Sanitize one folder name so it can be used as a single path segment.

    Invalid-character runs become ``_``, the ends are stripped and whitespace
    runs collapse to one space. An empty result yields ``fallback`` unchanged;
    anything longer than 180 characters is truncated and right-stripped.
    """
    stripped = _INVALID_FILENAME_CHARS.sub("_", value).strip()
    segment = _WHITESPACE_RUN.sub(" ", stripped)
    if not segment:
        return fallback
    return segment[:180].rstrip() if len(segment) > 180 else segment
|
||||||
|
|
||||||
|
|
||||||
|
def _suffix_with_doc_id(filename: str, doc_id: int | None) -> str:
|
||||||
|
if doc_id is None:
|
||||||
|
return filename
|
||||||
|
if not filename.lower().endswith(".xml"):
|
||||||
|
return f"{filename} ({doc_id}).xml"
|
||||||
|
stem = filename[:-4]
|
||||||
|
return f"{stem} ({doc_id}).xml"
|
||||||
|
|
||||||
|
|
||||||
|
# Trailing " (<digits>).xml" disambiguation suffix, extension case-insensitive.
_SUFFIX_PATTERN = re.compile(r"\s\((\d+)\)\.xml$", re.IGNORECASE)


def parse_doc_id_suffix(filename: str) -> tuple[str, int | None]:
    """Split a basename into ``(stem, doc_id)``.

    When the name ends with ``" (<doc_id>).xml"`` the stem is everything
    before that suffix and ``doc_id`` is the parsed integer. Otherwise
    ``doc_id`` is ``None`` and the stem is the name with a trailing ``.xml``
    (any case) removed, or the name itself when it has no such extension.
    """
    found = _SUFFIX_PATTERN.search(filename)
    if found is not None:
        return filename[: found.start()], int(found.group(1))
    stem = filename[:-4] if filename.lower().endswith(".xml") else filename
    return stem, None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class PathIndex:
    """Occupancy snapshot consumed by :func:`doc_to_virtual_path`.

    One instance is built per call site and then reused, which keeps
    collision handling consistent within that render and avoids a folder
    lookup for every document.
    """

    # ``Folder.id`` -> absolute virtual folder path under ``/documents``.
    folder_paths: dict[int, str] = field(default_factory=dict)

    # Virtual path -> ``Document.id`` already occupying that path (this render).
    occupants: dict[str, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_folder_paths(
    session: AsyncSession,
    search_space_id: int,
) -> dict[int, str]:
    """Map every ``Folder.id`` in a search space to its virtual path.

    All folders of the search space are loaded in one query. For each folder
    the parent chain is walked upwards, sanitizing every name with
    :func:`safe_folder_segment`, until a folder has no parent, the parent is
    not part of the loaded set, or a cycle is detected.
    """
    folder_query = select(Folder.id, Folder.name, Folder.parent_id).where(
        Folder.search_space_id == search_space_id
    )
    records = (await session.execute(folder_query)).all()
    folders = {
        rec.id: {"name": rec.name, "parent_id": rec.parent_id} for rec in records
    }

    resolved: dict[int, str] = {}
    for folder_id in folders:
        chain: list[str] = []
        seen: set[int] = set()
        current: int | None = folder_id
        # ``seen`` guards against parent cycles in the stored hierarchy.
        while current is not None and current in folders and current not in seen:
            seen.add(current)
            node = folders[current]
            chain.append(safe_folder_segment(str(node["name"])))
            current = node["parent_id"]
        # Segments were collected leaf-first; flip them to root-first.
        chain.reverse()
        resolved[folder_id] = (
            f"{DOCUMENTS_ROOT}/" + "/".join(chain) if chain else DOCUMENTS_ROOT
        )
    return resolved
|
||||||
|
|
||||||
|
|
||||||
|
async def build_path_index(
    session: AsyncSession,
    search_space_id: int,
    *,
    populate_occupants: bool = True,
) -> PathIndex:
    """Build a :class:`PathIndex` for a search space.

    ``populate_occupants`` controls whether the occupancy map is pre-seeded
    from existing ``Document`` rows. Most callers want this so that
    :func:`doc_to_virtual_path` can detect collisions across the whole space;
    the persistence middleware sets this to ``False`` when it is iterating to
    decide where to place fresh documents.

    Args:
        session: Async SQLAlchemy session used for the folder and document
            queries.
        search_space_id: Search space whose folders/documents are indexed.
        populate_occupants: Seed ``occupants`` from existing documents.

    Returns:
        A :class:`PathIndex` with ``folder_paths`` always filled and
        ``occupants`` filled only when ``populate_occupants`` is true.
    """
    folder_paths = await _build_folder_paths(session, search_space_id)
    occupants: dict[str, int] = {}
    if populate_occupants:
        rows = await session.execute(
            select(Document.id, Document.title, Document.folder_id)
            .where(Document.search_space_id == search_space_id)
            # Which of several colliding documents keeps the plain
            # (un-suffixed) path depends on iteration order. Without an
            # explicit ORDER BY the database may return rows in any order, so
            # the same document could receive different paths on different
            # calls. Ordering by id makes the assignment stable.
            .order_by(Document.id)
        )
        for row in rows.all():
            # Documents whose folder is unknown are placed at the root.
            base = folder_paths.get(row.folder_id, DOCUMENTS_ROOT)
            filename = safe_filename(str(row.title or "untitled"))
            path = f"{base}/{filename}"
            if path in occupants and occupants[path] != row.id:
                path = f"{base}/{_suffix_with_doc_id(filename, row.id)}"
            occupants[path] = row.id
    return PathIndex(folder_paths=folder_paths, occupants=occupants)
|
||||||
|
|
||||||
|
|
||||||
|
def doc_to_virtual_path(
    *,
    doc_id: int | None,
    title: str,
    folder_id: int | None,
    index: PathIndex,
) -> str:
    """Compute the canonical virtual path of a document.

    The plain path is ``<folder path>/<safe_filename(title)>``. If
    ``index.occupants`` already maps that path to a different document, the
    ``" (<doc_id>)"``-suffixed filename is used instead. When ``doc_id`` is
    not ``None`` the chosen path is recorded in ``index.occupants`` so later
    calls with the same index see the assignment.
    """
    parent = index.folder_paths.get(folder_id, DOCUMENTS_ROOT)
    leaf = safe_filename(str(title or "untitled"))
    candidate = f"{parent}/{leaf}"

    holder = index.occupants.get(candidate)
    taken_by_other = holder is not None and holder != doc_id
    if taken_by_other:
        candidate = f"{parent}/{_suffix_with_doc_id(leaf, doc_id)}"

    if doc_id is not None:
        index.occupants[candidate] = doc_id
    return candidate
|
||||||
|
|
||||||
|
|
||||||
|
async def virtual_path_to_doc(
    session: AsyncSession,
    *,
    search_space_id: int,
    virtual_path: str,
) -> Document | None:
    """Resolve a virtual path back to a ``Document`` row.

    Resolution order:

    1. ``Document.unique_identifier_hash`` lookup (fast path for paths created
       by SurfSense itself — every NOTE write goes through this hash).
    2. If the basename carries a ``" (<doc_id>).xml"`` disambiguation suffix,
       try a direct id lookup constrained to the search space.
    3. Title-from-basename lookup inside the resolved folder.
    4. Scan of the resolved folder, matching ``safe_filename(title)`` against
       the basename, for titles that sanitizing changed.

    Args:
        session: Async SQLAlchemy session used for all lookups.
        search_space_id: Search space every lookup is constrained to.
        virtual_path: Path expected to be ``/documents`` or below it.

    Returns:
        The matching ``Document``, or ``None`` when the path is empty, lies
        outside the documents root, names a folder chain that does not exist,
        or matches no document.
    """
    # Accept the root itself or a real child of it. A bare ``startswith``
    # would also accept look-alike roots such as ``/documentsX/...``.
    if not virtual_path or not (
        virtual_path == DOCUMENTS_ROOT
        or virtual_path.startswith(f"{DOCUMENTS_ROOT}/")
    ):
        return None

    unique_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        virtual_path,
        search_space_id,
    )
    result = await session.execute(
        select(Document).where(
            Document.search_space_id == search_space_id,
            Document.unique_identifier_hash == unique_hash,
        )
    )
    document = result.scalar_one_or_none()
    if document is not None:
        return document

    rel = virtual_path[len(DOCUMENTS_ROOT) :].lstrip("/")
    parts = [p for p in rel.split("/") if p]
    if not parts:
        return None
    basename = parts[-1]
    folder_parts = parts[:-1]

    stem, suffix_doc_id = parse_doc_id_suffix(basename)
    if suffix_doc_id is not None:
        result = await session.execute(
            select(Document).where(
                Document.search_space_id == search_space_id,
                Document.id == suffix_doc_id,
            )
        )
        document = result.scalar_one_or_none()
        if document is not None:
            return document

    folder_id = await _resolve_folder_id(
        session, search_space_id=search_space_id, folder_parts=folder_parts
    )
    # ``_resolve_folder_id`` returns ``None`` both for "no folder parts" and
    # for "folder chain not found". Only the former means the root folder; a
    # missing chain must not fall through to a root-level title match.
    if folder_parts and folder_id is None:
        return None

    title_candidates: list[str] = [stem]
    if stem.endswith(".xml"):
        title_candidates.append(stem[:-4])

    for candidate in dict.fromkeys(title_candidates):
        if not candidate:
            continue
        query = select(Document).where(
            Document.search_space_id == search_space_id,
            Document.title == candidate,
        )
        if folder_id is None:
            query = query.where(Document.folder_id.is_(None))
        else:
            query = query.where(Document.folder_id == folder_id)
        result = await session.execute(query)
        document = result.scalars().first()
        if document is not None:
            return document

    # Fallback: title-as-string lookup misses when the real DB title contains
    # characters that ``safe_filename`` lossily replaces (``:``, ``/``, ``*``,
    # etc.) — common for connector-imported docs (Google Calendar/Drive etc.).
    # The workspace tree shows the lossy filename, so the agent passes that
    # filename back here. Scan all documents in the resolved folder and match
    # by ``safe_filename(title)`` to recover the original document.
    folder_scan = select(Document).where(
        Document.search_space_id == search_space_id,
    )
    if folder_id is None:
        folder_scan = folder_scan.where(Document.folder_id.is_(None))
    else:
        folder_scan = folder_scan.where(Document.folder_id == folder_id)
    result = await session.execute(folder_scan)
    for candidate_doc in result.scalars().all():
        encoded = safe_filename(str(candidate_doc.title or "untitled"))
        if encoded == basename:
            return candidate_doc
    return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_folder_id(
    session: AsyncSession,
    *,
    search_space_id: int,
    folder_parts: list[str],
) -> int | None:
    """Look up the leaf folder id for a chain of folder names.

    Each segment is matched among the children of the previously resolved
    folder (top-level folders for the first segment). An exact ``Folder.name``
    match is tried first. If that fails, the siblings are scanned and compared
    by ``safe_folder_segment(name)``, because virtual paths are built from
    sanitized names and a folder whose real name contains replaced characters
    or whitespace runs would otherwise never resolve.

    Args:
        session: Async SQLAlchemy session used for the lookups.
        search_space_id: Search space the folders must belong to.
        folder_parts: Folder names from the root downwards.

    Returns:
        The id of the last folder in the chain, or ``None`` when
        ``folder_parts`` is empty or any segment cannot be found.
    """
    if not folder_parts:
        return None
    parent_id: int | None = None
    for raw in folder_parts:
        name = safe_folder_segment(raw)
        siblings = select(Folder.id, Folder.name).where(
            Folder.search_space_id == search_space_id,
        )
        if parent_id is None:
            siblings = siblings.where(Folder.parent_id.is_(None))
        else:
            siblings = siblings.where(Folder.parent_id == parent_id)

        # Fast path: the stored name equals the path segment.
        exact = (await session.execute(siblings.where(Folder.name == name))).first()
        if exact is not None:
            parent_id = exact[0]
            continue

        # Slow path: compare sanitized names. Ordered by id so the choice is
        # stable if two siblings sanitize to the same segment.
        rows = (await session.execute(siblings.order_by(Folder.id))).all()
        matched_id = next(
            (row[0] for row in rows if safe_folder_segment(str(row[1])) == name),
            None,
        )
        if matched_id is None:
            return None
        parent_id = matched_id
    return parent_id
|
||||||
|
|
||||||
|
|
||||||
|
def parse_documents_path(virtual_path: str) -> tuple[list[str], str]:
    """Parse a ``/documents/...`` path into ``(folder_parts, document_title)``.

    The title has any ``.xml`` extension and trailing ``" (<doc_id>)"``
    disambiguation suffix stripped.

    Args:
        virtual_path: Path expected to be ``/documents`` or below it.

    Returns:
        ``(folder_parts, title)``. ``([], "")`` is returned for an empty path,
        a path outside the documents root, or the root itself.
    """
    # Accept the root itself or a real child of it. A bare ``startswith``
    # would also accept look-alike roots such as ``/documentsX/...``.
    if not virtual_path or not (
        virtual_path == DOCUMENTS_ROOT
        or virtual_path.startswith(f"{DOCUMENTS_ROOT}/")
    ):
        return [], ""
    rel = virtual_path[len(DOCUMENTS_ROOT) :].strip("/")
    parts = [p for p in rel.split("/") if p]
    if not parts:
        return [], ""
    folder_parts = parts[:-1]
    basename = parts[-1]
    stem, _ = parse_doc_id_suffix(basename)
    title = stem
    # ``parse_doc_id_suffix`` already removed one extension; this handles a
    # doubled one such as ``name.xml.xml``.
    if title.endswith(".xml"):
        title = title[:-4]
    return folder_parts, title
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DOCUMENTS_ROOT",
|
||||||
|
"PathIndex",
|
||||||
|
"build_path_index",
|
||||||
|
"doc_to_virtual_path",
|
||||||
|
"parse_doc_id_suffix",
|
||||||
|
"parse_documents_path",
|
||||||
|
"safe_filename",
|
||||||
|
"safe_folder_segment",
|
||||||
|
"virtual_path_to_doc",
|
||||||
|
]
|
||||||
203
surfsense_backend/app/agents/new_chat/permissions.py
Normal file
203
surfsense_backend/app/agents/new_chat/permissions.py
Normal file
|
|
@ -0,0 +1,203 @@
|
||||||
|
"""
|
||||||
|
Wildcard pattern matching + rule evaluation for the SurfSense permission system.
|
||||||
|
|
||||||
|
Ported from OpenCode's ``packages/opencode/src/permission/evaluate.ts`` and
|
||||||
|
``packages/opencode/src/util/wildcard.ts``. LangChain has no rule-based
|
||||||
|
permission evaluator, so we keep OpenCode's semantics intact:
|
||||||
|
|
||||||
|
- ``Wildcard.match`` matches both the ``permission`` and the ``pattern``
|
||||||
|
fields of a rule against the requested ``(permission, pattern)`` pair.
|
||||||
|
``*`` matches any segment, ``**`` matches across separators.
|
||||||
|
- The evaluator runs ``findLast`` over the **flattened** list of rules
|
||||||
|
from all rulesets — last matching rule wins.
|
||||||
|
- The default fallback is ``ask`` (NOT deny), matching OpenCode.
|
||||||
|
- Multi-pattern requests AND together: if ANY pattern resolves to
|
||||||
|
``deny``, the whole request is denied; if ANY needs ``ask``, an
|
||||||
|
interrupt is raised; only when all patterns ``allow`` does the
|
||||||
|
request proceed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
# The three outcomes a permission rule can resolve to.
RuleAction = Literal["allow", "deny", "ask"]


@dataclass(frozen=True)
class Rule:
    """One immutable permission rule: ``(permission, pattern) -> action``.

    Attributes:
        permission: Wildcard-matched permission identifier, anchored at both
            ends of the input (e.g. ``"edit"``, ``"linear_*"``, ``"mcp:*"``,
            ``"doom_loop"``).
        pattern: Wildcard-matched pattern over the request payload, anchored
            at both ends (e.g. ``"/documents/secrets/**"``,
            ``"page_id=123"``, ``"*"``).
        action: ``"allow"``, ``"deny"`` or ``"ask"``.
    """

    permission: str
    pattern: str
    action: RuleAction
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Ruleset:
    """An ordered list of rules plus an ``origin`` label kept for debugging."""

    # Rules in declaration order.
    rules: list[Rule] = field(default_factory=list)
    # e.g. "defaults", "global", "space", "thread", "runtime"
    origin: str = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Wildcard matcher
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
_GLOB_TOKEN = re.compile(r"\*\*|\*|[^*]+")
|
||||||
|
|
||||||
|
|
||||||
|
def _wildcard_to_regex(pattern: str) -> re.Pattern[str]:
|
||||||
|
"""Translate an opencode-style wildcard pattern to a compiled regex.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- ``**`` matches any sequence of any characters (including separators).
|
||||||
|
- ``*`` matches any sequence of characters that does **not** include
|
||||||
|
the path separator ``/`` — same as glob.
|
||||||
|
- All other characters match literally.
|
||||||
|
- The pattern is anchored at both ends (``^...$``).
|
||||||
|
"""
|
||||||
|
parts: list[str] = ["^"]
|
||||||
|
for token in _GLOB_TOKEN.findall(pattern):
|
||||||
|
if token == "**":
|
||||||
|
parts.append(r".*")
|
||||||
|
elif token == "*":
|
||||||
|
parts.append(r"[^/]*")
|
||||||
|
else:
|
||||||
|
parts.append(re.escape(token))
|
||||||
|
parts.append("$")
|
||||||
|
return re.compile("".join(parts))
|
||||||
|
|
||||||
|
|
||||||
|
# Compiled regex per wildcard pattern, filled lazily by ``wildcard_match``.
_REGEX_CACHE: dict[str, re.Pattern[str]] = {}


def wildcard_match(value: str, pattern: str) -> bool:
    """Tell whether ``value`` satisfies the wildcard ``pattern``.

    A bare ``"*"`` pattern short-circuits to ``True`` for every value, even
    ones containing ``/``. Any other pattern is compiled through
    ``_wildcard_to_regex`` on first use and served from ``_REGEX_CACHE``
    afterwards.
    """
    if pattern == "*":
        return True
    try:
        matcher = _REGEX_CACHE[pattern]
    except KeyError:
        matcher = _REGEX_CACHE[pattern] = _wildcard_to_regex(pattern)
    return matcher.match(value) is not None
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Evaluator
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(
    permission: str,
    pattern: str,
    *rulesets: Ruleset | Iterable[Rule],
) -> Rule:
    """Return the last rule that matches ``(permission, pattern)``.

    The rulesets are flattened in argument order and searched from the end,
    so a rule specified later overrides an earlier one. A rule matches when
    both its ``permission`` and its ``pattern`` wildcard-match the request.

    Args:
        permission: Permission identifier being requested (e.g. a tool name,
            ``"edit"``, ``"doom_loop"``).
        pattern: Request-specific pattern (e.g. a file path); use ``"*"``
            when no specific pattern applies.
        *rulesets: Layered rulesets, earliest first. Each is either a
            :class:`Ruleset` or a plain iterable of :class:`Rule`.

    Returns:
        The matching :class:`Rule`, or ``Rule(permission, "*", "ask")`` when
        nothing matches.
    """
    flattened: list[Rule] = []
    for layer in rulesets:
        flattened.extend(layer.rules if isinstance(layer, Ruleset) else layer)

    winner = next(
        (
            candidate
            for candidate in reversed(flattened)
            if wildcard_match(permission, candidate.permission)
            and wildcard_match(pattern, candidate.pattern)
        ),
        None,
    )
    if winner is not None:
        return winner
    return Rule(permission=permission, pattern="*", action="ask")
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_many(
    permission: str,
    patterns: Iterable[str],
    *rulesets: Ruleset | Iterable[Rule],
) -> list[Rule]:
    """Run :func:`evaluate` once per pattern and collect the winning rules.

    The result has one rule per entry of ``patterns``, in the same order.
    Combining them into a single decision is left to the caller (see
    :func:`aggregate_action`).
    """
    resolved: list[Rule] = []
    for requested in patterns:
        resolved.append(evaluate(permission, requested, *rulesets))
    return resolved
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_action(rules: Iterable[Rule]) -> RuleAction:
    """Reduce per-pattern rules to a single action.

    Precedence:
    1. Any ``deny`` -> ``deny`` (returned as soon as it is seen).
    2. Otherwise any ``ask`` -> ``ask``.
    3. Otherwise at least one ``allow`` -> ``allow``.
    4. Otherwise (no rules, or no recognised action) -> ``ask``.
    """
    needs_ask = False
    has_allow = False
    for entry in rules:
        verdict = entry.action
        if verdict == "deny":
            return "deny"
        needs_ask = needs_ask or verdict == "ask"
        has_allow = has_allow or verdict == "allow"
    # "ask" is the safe default whenever nothing explicitly allowed the call.
    return "ask" if needs_ask or not has_allow else "allow"
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Rule",
|
||||||
|
"RuleAction",
|
||||||
|
"Ruleset",
|
||||||
|
"aggregate_action",
|
||||||
|
"evaluate",
|
||||||
|
"evaluate_many",
|
||||||
|
"wildcard_match",
|
||||||
|
]
|
||||||
158
surfsense_backend/app/agents/new_chat/plugin_loader.py
Normal file
158
surfsense_backend/app/agents/new_chat/plugin_loader.py
Normal file
|
|
@ -0,0 +1,158 @@
|
||||||
|
"""Entry-point based plugin loader for SurfSense agent middleware.
|
||||||
|
|
||||||
|
LangChain's :class:`AgentMiddleware` ABC already covers the practical
|
||||||
|
surface most plugins need (``before_agent`` / ``before_model`` /
|
||||||
|
``wrap_tool_call`` / their async counterparts), so a SurfSense-specific
|
||||||
|
plugin protocol would be redundant. We just need a way to discover and
|
||||||
|
admit third-party middleware safely.
|
||||||
|
|
||||||
|
A plugin is therefore just an installable Python package that registers a
|
||||||
|
factory callable under the ``surfsense.plugins`` entry-point group:
|
||||||
|
|
||||||
|
.. code-block:: toml
|
||||||
|
|
||||||
|
# in a plugin package's pyproject.toml
|
||||||
|
[project.entry-points."surfsense.plugins"]
|
||||||
|
year_substituter = "my_plugin:make_middleware"
|
||||||
|
|
||||||
|
The factory has the signature ``Callable[[PluginContext], AgentMiddleware]``.
|
||||||
|
It receives a small, sanitized :class:`PluginContext` with the IDs and the
|
||||||
|
LLM the plugin is allowed to talk to — and **never** raw secrets, DB
|
||||||
|
sessions, or other connectors.
|
||||||
|
|
||||||
|
## Trust model
|
||||||
|
|
||||||
|
Plugins are loaded **only if** their entry-point ``name`` appears in
|
||||||
|
``allowed_plugins`` (admin-controlled, sourced from
|
||||||
|
``global_llm_config.yaml`` or :func:`load_allowed_plugin_names_from_env`).
|
||||||
|
There is **no env-driven auto-load**. A plugin failure is logged and
|
||||||
|
isolated; it does not break agent construction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from importlib.metadata import entry_points
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma: no cover - type-only
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
|
from app.db import ChatVisibility
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
PLUGIN_ENTRY_POINT_GROUP = "surfsense.plugins"
|
||||||
|
|
||||||
|
|
||||||
|
class PluginContext(dict):
    """Sanitized dependency bag passed to every plugin factory.

    It is a plain ``dict`` subclass so plugins read only the keys they need
    instead of depending on a fixed dataclass shape. :meth:`build` fills
    these keys:

    * ``search_space_id`` (int)
    * ``user_id`` (str | None)
    * ``thread_visibility`` (:class:`app.db.ChatVisibility`)
    * ``llm`` (:class:`langchain_core.language_models.BaseChatModel`)

    By design the context **never** carries DB sessions, raw secrets, or
    other connectors. A future plugin that genuinely needs DB access is meant
    to go through a rate-limited service interface, not through this bag.
    """

    @classmethod
    def build(
        cls,
        *,
        search_space_id: int,
        user_id: str | None,
        thread_visibility: ChatVisibility,
        llm: BaseChatModel,
    ) -> PluginContext:
        """Create a context holding exactly the four documented keys."""
        payload = {
            "search_space_id": search_space_id,
            "user_id": user_id,
            "thread_visibility": thread_visibility,
            "llm": llm,
        }
        return cls(payload)
|
||||||
|
|
||||||
|
|
||||||
|
def load_plugin_middlewares(
    ctx: PluginContext,
    allowed_plugin_names: Iterable[str],
) -> list[AgentMiddleware]:
    """Discover allowlisted plugins and build their middleware instances.

    Every entry point in :data:`PLUGIN_ENTRY_POINT_GROUP` whose name is in
    ``allowed_plugin_names`` is loaded and its factory is called with ``ctx``.
    Results that are not :class:`AgentMiddleware` instances are logged and
    skipped. A plugin that raises while loading or inside its factory is
    logged and skipped too, so the remaining plugins still load.

    Args:
        ctx: Context handed to each plugin factory.
        allowed_plugin_names: Entry-point names permitted to load; falsy
            entries are ignored.

    Returns:
        The middleware built by the plugins that loaded successfully; an
        empty list when the allowlist is empty or enumeration fails.
    """
    allowlist: set[str] = set()
    for requested in allowed_plugin_names:
        if requested:
            allowlist.add(requested)
    if not allowlist:
        return []

    try:
        discovered = entry_points(group=PLUGIN_ENTRY_POINT_GROUP)
    except Exception:  # pragma: no cover - defensive (entry_points is robust)
        logger.exception("Failed to enumerate plugin entry points")
        return []

    loaded: list[AgentMiddleware] = []
    for entry in discovered:
        plugin_name = entry.name
        if plugin_name not in allowlist:
            logger.info("Skipping non-allowlisted plugin %s", plugin_name)
            continue

        # Importing and invoking the factory are guarded separately so the
        # log message says which stage failed.
        try:
            factory = entry.load()
        except Exception:
            logger.exception("Failed to load plugin %s", plugin_name)
            continue
        try:
            middleware = factory(ctx)
        except Exception:
            logger.exception("Plugin %s factory raised", plugin_name)
            continue

        if isinstance(middleware, AgentMiddleware):
            loaded.append(middleware)
            logger.info(
                "Loaded plugin %s as %s", plugin_name, type(middleware).__name__
            )
        else:
            logger.warning(
                "Plugin %s returned %s, expected AgentMiddleware; skipping",
                plugin_name,
                type(middleware).__name__,
            )
    return loaded
|
||||||
|
|
||||||
|
|
||||||
|
def load_allowed_plugin_names_from_env() -> set[str]:
    """Parse the comma-separated ``SURFSENSE_ALLOWED_PLUGINS`` variable.

    A convenience for deployments that do not list plugins in
    ``global_llm_config.yaml``. Each entry is stripped of surrounding
    whitespace and empty entries are discarded; an unset or blank variable
    yields an empty set.
    """
    configured = os.environ.get("SURFSENSE_ALLOWED_PLUGINS", "").strip()
    if not configured:
        return set()
    names: set[str] = set()
    for piece in configured.split(","):
        trimmed = piece.strip()
        if trimmed:
            names.add(trimmed)
    return names
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PLUGIN_ENTRY_POINT_GROUP",
|
||||||
|
"PluginContext",
|
||||||
|
"load_allowed_plugin_names_from_env",
|
||||||
|
"load_plugin_middlewares",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
"""Reference plugins bundled with SurfSense.
|
||||||
|
|
||||||
|
These plugins are intentionally small and demonstrative. They are NOT
|
||||||
|
auto-loaded — they ship as examples that a deployment can opt into via
|
||||||
|
``global_llm_config.yaml`` or ``SURFSENSE_ALLOWED_PLUGINS``.
|
||||||
|
"""
|
||||||
|
|
@ -0,0 +1,88 @@
|
||||||
|
"""Reference plugin: substitute ``{{year}}`` in tool descriptions.
|
||||||
|
|
||||||
|
Demonstrates the :meth:`AgentMiddleware.awrap_tool_call` hook -- the
|
||||||
|
plugin sees every tool invocation and can rewrite the request *or* the
|
||||||
|
result. This particular plugin never mutates the request; it only rewrites
``{{year}}`` placeholders in the text content of the returned tool message.
|
||||||
|
|
||||||
|
The plugin is built as a factory function so the entry-point loader can
|
||||||
|
inject :class:`PluginContext` (containing the agent's LLM, search-space
|
||||||
|
ID, etc.). The factory signature
|
||||||
|
``Callable[[PluginContext], AgentMiddleware]`` is the only contract --
|
||||||
|
SurfSense doesn't define a custom plugin protocol on top of LangChain's
|
||||||
|
:class:`AgentMiddleware`.
|
||||||
|
|
||||||
|
Wire-up in ``pyproject.toml`` (illustrative; the in-repo plugin doesn't
|
||||||
|
need this -- it's already on the import path)::
|
||||||
|
|
||||||
|
[project.entry-points."surfsense.plugins"]
|
||||||
|
year_substituter = "app.agents.new_chat.plugins.year_substituter:make_middleware"
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma: no cover - type-only
|
||||||
|
from langchain.agents.middleware.types import ToolCallRequest
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
from app.agents.new_chat.plugin_loader import PluginContext
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _YearSubstituterMiddleware(AgentMiddleware):
    """Replace ``{{year}}`` in the result text with the current UTC year."""

    tools = ()

    def __init__(self, year: int | None = None) -> None:
        """Fix the replacement text at construction time.

        Args:
            year: Year to substitute; when ``None`` the current UTC year is
                read once, here, and reused for every later tool call.
        """
        super().__init__()
        self._year = str(year if year is not None else datetime.now(UTC).year)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
    ) -> ToolMessage | Command[Any]:
        """Run the tool, then substitute ``{{year}}`` in a string result.

        Only a ``ToolMessage`` whose ``content`` is a ``str`` containing
        ``{{year}}`` is rewritten; any other result is returned exactly as the
        handler produced it. A failure during the rewrite is logged and the
        original result is returned.
        """
        result = await handler(request)
        try:
            # Imported here because the module-level import is type-only.
            from langchain_core.messages import ToolMessage

            if (
                isinstance(result, ToolMessage)
                and isinstance(result.content, str)
                and "{{year}}" in result.content
            ):
                new_text = result.content.replace("{{year}}", self._year)
                # Copy instead of rebuilding from hand-picked fields, so
                # every other field of the handler's message carries over.
                result = result.model_copy(update={"content": new_text})
        except Exception:  # pragma: no cover - defensive
            logger.exception("year_substituter plugin failed; passing original result")
        return result
|
||||||
|
|
||||||
|
|
||||||
|
def make_middleware(ctx: PluginContext) -> AgentMiddleware:
    """Build the year-substituter middleware for :func:`load_plugin_middlewares`.

    Args:
        ctx: Plugin context supplied by the loader. It is accepted only to
            show that the loader passes it in; this plugin does not read it.

    Returns:
        A fresh ``_YearSubstituterMiddleware`` using its default year.
    """
    # The plugin is deliberately tiny: it holds no shared state that would
    # need thread protection, and ``ctx`` is bound only to mark it as
    # intentionally unused.
    _ = ctx
    middleware = _YearSubstituterMiddleware()
    return middleware
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["make_middleware"]
|
||||||
166
surfsense_backend/app/agents/new_chat/prompt_caching.py
Normal file
166
surfsense_backend/app/agents/new_chat/prompt_caching.py
Normal file
|
|
@ -0,0 +1,166 @@
|
||||||
|
"""LiteLLM-native prompt caching configuration for SurfSense agents.
|
||||||
|
|
||||||
|
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
|
||||||
|
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
|
||||||
|
gate always failed) with LiteLLM's universal caching mechanism.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
|
||||||
|
- Marker-based providers (need ``cache_control`` injection, which LiteLLM
|
||||||
|
performs automatically when ``cache_control_injection_points`` is set):
|
||||||
|
``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
|
||||||
|
``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
|
||||||
|
(Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
|
||||||
|
- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
|
||||||
|
``deepseek/``, ``xai/`` — these cache automatically for prompts ≥1024
|
||||||
|
tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
|
||||||
|
|
||||||
|
We inject **two** breakpoints per request:
|
||||||
|
|
||||||
|
- ``role: system`` — pins the SurfSense system prompt (provider variant,
|
||||||
|
citation rules, tool catalog, KB tree, skills metadata) into the cache.
|
||||||
|
- ``index: -1`` — pins the latest message so multi-turn savings compound:
|
||||||
|
Anthropic-family providers use longest-matching-prefix lookup, so turn
|
||||||
|
N+1 still reads turn N's cache up to the shared prefix.
|
||||||
|
|
||||||
|
For OpenAI-family configs we additionally pass:
|
||||||
|
|
||||||
|
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
|
||||||
|
raises hit rate by sending requests with a shared prefix to the same
|
||||||
|
backend.
|
||||||
|
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
|
||||||
|
5-10 min in-memory cache.
|
||||||
|
|
||||||
|
Safety net: ``litellm.drop_params=True`` is set globally in
|
||||||
|
``app.services.llm_service`` at module-load time. Any kwarg the destination
|
||||||
|
provider doesn't recognise is auto-stripped at the provider transformer
|
||||||
|
layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
|
||||||
|
``prompt_cache_key`` etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.agents.new_chat.llm_config import AgentConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Two-breakpoint policy: system + latest message. See module docstring for
|
||||||
|
# rationale. Anthropic limits requests to 4 ``cache_control`` blocks; we
|
||||||
|
# use 2 here, leaving headroom for Phase-2 tool caching.
|
||||||
|
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
||||||
|
{"location": "message", "role": "system"},
|
||||||
|
{"location": "message", "index": -1},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Providers (uppercase ``AgentConfig.provider`` values) that natively expose
|
||||||
|
# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and
|
||||||
|
# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers
|
||||||
|
# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without
|
||||||
|
# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU,
|
||||||
|
# MINIMAX), so we can't infer family from the litellm prefix alone.
|
||||||
|
_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"})
|
||||||
|
|
||||||
|
|
||||||
|
def _is_router_llm(llm: BaseChatModel) -> bool:
|
||||||
|
"""Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import.
|
||||||
|
|
||||||
|
Importing ``app.services.llm_router_service`` at module-load time would
|
||||||
|
create a cycle via ``llm_config -> prompt_caching -> llm_router_service``.
|
||||||
|
Class-name comparison is sufficient since the class is defined in a
|
||||||
|
single place.
|
||||||
|
"""
|
||||||
|
return type(llm).__name__ == "ChatLiteLLMRouter"
|
||||||
|
|
||||||
|
|
||||||
|
def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
    """Decide whether ``agent_config`` targets an OpenAI-style prompt cache.

    The check is deliberately strict: it succeeds only when the provider was
    explicitly chosen as OPENAI, DEEPSEEK, or XAI in the ``NewLLMConfig`` /
    ``YAMLConfig``. Auto-mode and custom providers are rejected because the
    destination cannot be known statically.

    Args:
        agent_config: The agent configuration, or ``None`` when unavailable.

    Returns:
        ``True`` when the upper-cased provider is in
        ``_OPENAI_FAMILY_PROVIDERS`` and the config is neither auto-mode nor
        a custom provider; ``False`` otherwise (including a missing config
        or an empty provider).
    """
    if agent_config is None:
        return False

    provider = agent_config.provider
    if not provider or agent_config.is_auto_mode or agent_config.custom_provider:
        return False

    return provider.upper() in _OPENAI_FAMILY_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
|
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
|
||||||
|
"""Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail.
|
||||||
|
|
||||||
|
Initialises the field to ``{}`` when present-but-None on a Pydantic v2
|
||||||
|
model. Returns ``None`` if the LLM type doesn't expose a writable
|
||||||
|
``model_kwargs`` attribute (caller should treat as no-op).
|
||||||
|
"""
|
||||||
|
model_kwargs = getattr(llm, "model_kwargs", None)
|
||||||
|
if isinstance(model_kwargs, dict):
|
||||||
|
return model_kwargs
|
||||||
|
try:
|
||||||
|
llm.model_kwargs = {} # type: ignore[attr-defined]
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
refreshed = getattr(llm, "model_kwargs", None)
|
||||||
|
return refreshed if isinstance(refreshed, dict) else None
|
||||||
|
|
||||||
|
|
||||||
|
def apply_litellm_prompt_caching(
    llm: BaseChatModel,
    *,
    agent_config: AgentConfig | None = None,
    thread_id: int | None = None,
) -> None:
    """Set up LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.

    The function is idempotent: any key already present in
    ``llm.model_kwargs`` (for instance from ``agent_config.litellm_params``
    overrides) is left untouched. ``llm.model_kwargs`` is mutated in place;
    those kwargs reach ``litellm.completion`` via
    ``ChatLiteLLM._default_params`` and via the ``self.model_kwargs`` merge
    in our custom ``ChatLiteLLMRouter``.

    Args:
        llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
        agent_config: Optional ``AgentConfig`` driving provider-specific
            behaviour. When omitted (or auto-mode), only the universal
            ``cache_control_injection_points`` are set.
        thread_id: Optional thread id used to build a per-thread
            ``prompt_cache_key`` for OpenAI-family providers. Caching still
            works without it (server-side automatic), but the key improves
            backend routing affinity and therefore hit rate.
    """
    kwargs = _get_or_init_model_kwargs(llm)
    if kwargs is None:
        logger.debug(
            "apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping",
            type(llm).__name__,
        )
        return

    # Universal part: install the default breakpoints unless the caller has
    # already supplied their own. Each entry is copied so the shared default
    # tuple is never aliased into a model's kwargs.
    if "cache_control_injection_points" not in kwargs:
        kwargs["cache_control_injection_points"] = list(
            map(dict, _DEFAULT_INJECTION_POINTS)
        )

    # OpenAI-family extras apply only when the destination is statically
    # known to be OpenAI / DeepSeek / xAI. The auto-mode router fans out
    # across providers, so OpenAI-only kwargs are not set there (drop_params
    # would strip them, but setting them in the first place is wasteful).
    if _is_router_llm(llm) or not _is_openai_family_config(agent_config):
        return

    if thread_id is not None:
        kwargs.setdefault("prompt_cache_key", f"surfsense-thread-{thread_id}")
    kwargs.setdefault("prompt_cache_retention", "24h")
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
"""SurfSense agent prompt fragments.
|
||||||
|
|
||||||
|
The prompt is composed at runtime by :mod:`composer` from the markdown
|
||||||
|
fragments under ``base/``, ``providers/``, ``tools/``, ``examples/``, and
|
||||||
|
``routing/``. ``system_prompt.py`` is now a thin wrapper that delegates
|
||||||
|
to :func:`composer.compose_system_prompt`.
|
||||||
|
"""
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base.
|
||||||
|
|
||||||
|
Today's date (UTC): {resolved_today}
|
||||||
|
|
||||||
|
When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVER use backtick code spans or Unicode symbols for math.
|
||||||
|
|
||||||
|
NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead.
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
You are SurfSense, a reasoning and acting AI agent designed to answer questions in this team space using the team's shared knowledge base.
|
||||||
|
|
||||||
|
In this team thread, each message is prefixed with **[DisplayName of the author]**. Use this to attribute and reference the author of anything in the discussion (who asked a question, made a suggestion, or contributed an idea) and to cite who said what in your answers.
|
||||||
|
|
||||||
|
Today's date (UTC): {resolved_today}
|
||||||
|
|
||||||
|
When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVER use backtick code spans or Unicode symbols for math.
|
||||||
|
|
||||||
|
NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead.
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
<citation_instructions>
|
||||||
|
IMPORTANT: Citations are DISABLED for this configuration.
|
||||||
|
|
||||||
|
DO NOT include any citations in your responses. Specifically:
|
||||||
|
1. Do NOT use the [citation:chunk_id] format anywhere in your response.
|
||||||
|
2. Do NOT reference document IDs, chunk IDs, or source IDs.
|
||||||
|
3. Simply provide the information naturally without any citation markers.
|
||||||
|
4. Write your response as if you're having a normal conversation, incorporating the information from your knowledge seamlessly.
|
||||||
|
|
||||||
|
When answering questions based on documents from the knowledge base:
|
||||||
|
- Present the information directly and confidently
|
||||||
|
- Do not mention that information comes from specific documents or chunks
|
||||||
|
- Integrate facts naturally into your response without attribution markers
|
||||||
|
|
||||||
|
Your goal is to provide helpful, informative answers in a clean, readable format without any citation notation.
|
||||||
|
</citation_instructions>
|
||||||
|
|
@ -0,0 +1,90 @@
|
||||||
|
<citation_instructions>
|
||||||
|
CRITICAL CITATION REQUIREMENTS:
|
||||||
|
|
||||||
|
1. For EVERY piece of information you include from the documents, add a citation in the format [citation:chunk_id] where chunk_id is the exact value from the `<chunk id='...'>` tag inside `<document_content>`.
|
||||||
|
2. Make sure ALL factual statements from the documents have proper citations.
|
||||||
|
3. If multiple chunks support the same point, include all relevant citations [citation:chunk_id1], [citation:chunk_id2].
|
||||||
|
4. You MUST use the exact chunk_id values from the `<chunk id='...'>` attributes. Do not create your own citation numbers.
|
||||||
|
5. Every citation MUST be in the format [citation:chunk_id] where chunk_id is the exact chunk id value.
|
||||||
|
6. Never modify or change the chunk_id - always use the original values exactly as provided in the chunk tags.
|
||||||
|
7. Do not return citations as clickable links.
|
||||||
|
8. Never format citations as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only.
|
||||||
|
9. Citations must ONLY appear as [citation:chunk_id] or [citation:chunk_id1], [citation:chunk_id2] format - never with parentheses, hyperlinks, or other formatting.
|
||||||
|
10. Never make up chunk IDs. Only use chunk_id values that are explicitly provided in the `<chunk id='...'>` tags.
|
||||||
|
11. If you are unsure about a chunk_id, do not include a citation rather than guessing or making one up.
|
||||||
|
|
||||||
|
<document_structure_example>
|
||||||
|
The documents you receive are structured like this:
|
||||||
|
|
||||||
|
**Knowledge base documents (numeric chunk IDs):**
|
||||||
|
<document>
|
||||||
|
<document_metadata>
|
||||||
|
<document_id>42</document_id>
|
||||||
|
<document_type>GITHUB_CONNECTOR</document_type>
|
||||||
|
<title><![CDATA[Some repo / file / issue title]]></title>
|
||||||
|
<url><![CDATA[https://example.com]]></url>
|
||||||
|
<metadata_json><![CDATA[{{"any":"other metadata"}}]]></metadata_json>
|
||||||
|
</document_metadata>
|
||||||
|
|
||||||
|
<document_content>
|
||||||
|
<chunk id='123'><![CDATA[First chunk text...]]></chunk>
|
||||||
|
<chunk id='124'><![CDATA[Second chunk text...]]></chunk>
|
||||||
|
</document_content>
|
||||||
|
</document>
|
||||||
|
|
||||||
|
**Web search results (URL chunk IDs):**
|
||||||
|
<document>
|
||||||
|
<document_metadata>
|
||||||
|
<document_type>WEB_SEARCH</document_type>
|
||||||
|
<title><![CDATA[Some web search result]]></title>
|
||||||
|
<url><![CDATA[https://example.com/article]]></url>
|
||||||
|
</document_metadata>
|
||||||
|
|
||||||
|
<document_content>
|
||||||
|
<chunk id='https://example.com/article'><![CDATA[Content from web search...]]></chunk>
|
||||||
|
</document_content>
|
||||||
|
</document>
|
||||||
|
|
||||||
|
IMPORTANT: You MUST cite using the EXACT chunk ids from the `<chunk id='...'>` tags.
|
||||||
|
- For knowledge base documents, chunk ids are numeric (e.g. 123, 124) or prefixed (e.g. doc-45).
|
||||||
|
- For live web search results, chunk ids are URLs (e.g. https://example.com/article).
|
||||||
|
Do NOT cite document_id. Always use the chunk id.
|
||||||
|
</document_structure_example>
|
||||||
|
|
||||||
|
<citation_format>
|
||||||
|
- Every fact from the documents must have a citation in the format [citation:chunk_id] where chunk_id is the EXACT id value from a `<chunk id='...'>` tag
|
||||||
|
- Citations should appear at the end of the sentence containing the information they support
|
||||||
|
- Multiple citations should be separated by commas: [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3]
|
||||||
|
- No need to return references section. Just citations in answer.
|
||||||
|
- NEVER create your own citation format - use the exact chunk_id values from the documents in the [citation:chunk_id] format
|
||||||
|
- NEVER format citations as clickable links or as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only
|
||||||
|
- NEVER make up chunk IDs if you are unsure about the chunk_id. It is better to omit the citation than to guess
|
||||||
|
- Copy the EXACT chunk id from the XML - if it says `<chunk id='doc-123'>`, use [citation:doc-123]
|
||||||
|
- If the chunk id is a URL like `<chunk id='https://example.com/page'>`, use [citation:https://example.com/page]
|
||||||
|
</citation_format>
|
||||||
|
|
||||||
|
<citation_examples>
|
||||||
|
CORRECT citation formats:
|
||||||
|
- [citation:5] (numeric chunk ID from knowledge base)
|
||||||
|
- [citation:doc-123] (for SurfSense documentation chunks)
|
||||||
|
- [citation:https://example.com/article] (URL chunk ID from web search results)
|
||||||
|
- [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3] (multiple citations)
|
||||||
|
|
||||||
|
INCORRECT citation formats (DO NOT use):
|
||||||
|
- Using parentheses and markdown links: ([citation:5](https://github.com/MODSetter/SurfSense))
|
||||||
|
- Using parentheses around brackets: ([citation:5])
|
||||||
|
- Using hyperlinked text: [link to source 5](https://example.com)
|
||||||
|
- Using footnote style: ... library¹
|
||||||
|
- Making up chunk IDs when the chunk_id is unknown
|
||||||
|
- Using old IEEE format: [1], [2], [3]
|
||||||
|
- Using source types instead of IDs: [citation:GITHUB_CONNECTOR] instead of [citation:5]
|
||||||
|
</citation_examples>
|
||||||
|
|
||||||
|
<citation_output_example>
|
||||||
|
Based on your GitHub repositories and video content, Python's asyncio library provides tools for writing concurrent code using the async/await syntax [citation:5]. It's particularly useful for I/O-bound and high-level structured network code [citation:5].
|
||||||
|
|
||||||
|
According to web search results, the key advantage of asyncio is that it can improve performance by allowing other code to run while waiting for I/O operations to complete [citation:https://docs.python.org/3/library/asyncio.html]. This makes it excellent for scenarios like web scraping, API calls, database operations, or any situation where your program spends time waiting for external resources.
|
||||||
|
|
||||||
|
However, from your video learning, it's important to note that asyncio is not suitable for CPU-bound tasks as it runs on a single thread [citation:12]. For computationally intensive work, you'd want to use multiprocessing instead.
|
||||||
|
</citation_output_example>
|
||||||
|
</citation_instructions>
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
<knowledge_base_only_policy>
|
||||||
|
CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
||||||
|
- You MUST answer questions ONLY using information retrieved from the user's knowledge base, web search results, scraped webpages, or other tool outputs.
|
||||||
|
- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless the user explicitly grants permission.
|
||||||
|
- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST:
|
||||||
|
1. Inform the user that you could not find relevant information in their knowledge base.
|
||||||
|
2. Ask the user: "Would you like me to answer from my general knowledge instead?"
|
||||||
|
3. ONLY provide a general-knowledge answer AFTER the user explicitly says yes.
|
||||||
|
- This policy does NOT apply to:
|
||||||
|
* Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?")
|
||||||
|
* Formatting, summarization, or analysis of content already present in the conversation
|
||||||
|
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
|
||||||
|
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||||
|
* Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see <tool_routing> below
|
||||||
|
</knowledge_base_only_policy>
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
<knowledge_base_only_policy>
|
||||||
|
CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
||||||
|
- You MUST answer questions ONLY using information retrieved from the team's shared knowledge base, web search results, scraped webpages, or other tool outputs.
|
||||||
|
- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless a team member explicitly grants permission.
|
||||||
|
- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST:
|
||||||
|
1. Inform the team that you could not find relevant information in the shared knowledge base.
|
||||||
|
2. Ask: "Would you like me to answer from my general knowledge instead?"
|
||||||
|
3. ONLY provide a general-knowledge answer AFTER a team member explicitly says yes.
|
||||||
|
- This policy does NOT apply to:
|
||||||
|
* Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?")
|
||||||
|
* Formatting, summarization, or analysis of content already present in the conversation
|
||||||
|
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
|
||||||
|
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||||
|
* Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see <tool_routing> below
|
||||||
|
</knowledge_base_only_policy>
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
<memory_protocol>
|
||||||
|
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||||
|
reveal durable facts about the user (role, interests, preferences, projects,
|
||||||
|
background, or standing instructions)? If yes, you MUST call update_memory
|
||||||
|
alongside your normal response — do not defer this to a later turn.
|
||||||
|
</memory_protocol>
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
<memory_protocol>
|
||||||
|
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||||
|
reveal durable facts about the team (decisions, conventions, architecture, processes,
|
||||||
|
or key facts)? If yes, you MUST call update_memory alongside your normal response —
|
||||||
|
do not defer this to a later turn.
|
||||||
|
</memory_protocol>
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
<parameter_resolution>
|
||||||
|
Some service tools require identifiers or context you do not have (account IDs,
|
||||||
|
workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw
|
||||||
|
IDs or technical identifiers — they cannot memorise them.
|
||||||
|
|
||||||
|
Instead, follow this discovery pattern:
|
||||||
|
1. Call a listing/discovery tool to find available options.
|
||||||
|
2. ONE result → use it silently, no question to the user.
|
||||||
|
3. MULTIPLE results → present the options by their display names and let the
|
||||||
|
user choose. Never show raw UUIDs — always use friendly names.
|
||||||
|
|
||||||
|
Discovery tools by level:
|
||||||
|
- Which account/workspace? → get_connected_accounts("<service>")
|
||||||
|
- Which Jira site (cloudId)? → getAccessibleAtlassianResources
|
||||||
|
- Which Jira project? → getVisibleJiraProjects (after resolving cloudId)
|
||||||
|
- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project)
|
||||||
|
- Which channel? → slack_search_channels
|
||||||
|
- Which base? → list_bases
|
||||||
|
- Which table? → list_tables_for_base (after resolving baseId)
|
||||||
|
- Which task? → clickup_search
|
||||||
|
- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira)
|
||||||
|
|
||||||
|
For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to
|
||||||
|
obtain the cloudId, then pass it to other Jira tools. When creating an issue,
|
||||||
|
chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue.
|
||||||
|
If there is only one option at each step, use it silently. If multiple, present
|
||||||
|
friendly names.
|
||||||
|
|
||||||
|
Chain discovery when needed — e.g. for Airtable records: list_bases → pick
|
||||||
|
base → list_tables_for_base → pick table → list_records_for_table.
|
||||||
|
|
||||||
|
MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for
|
||||||
|
the same service, tool names are prefixed to avoid collisions — e.g.
|
||||||
|
linear_25_list_issues and linear_30_list_issues instead of two list_issues.
|
||||||
|
Each prefixed tool's description starts with [Account: <display_name>] so you
|
||||||
|
know which account it targets. Use get_connected_accounts("<service>") to see
|
||||||
|
the full list of accounts with their connector IDs and display names.
|
||||||
|
When only one account is connected, tools have their normal unprefixed names.
|
||||||
|
</parameter_resolution>
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
<tool_routing>
|
||||||
|
CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable.
|
||||||
|
Their data is NEVER in the knowledge base. You MUST call their tools immediately — never
|
||||||
|
say "I don't see it in the knowledge base" or ask the user if they want you to check.
|
||||||
|
Ignore any knowledge base results for these services.
|
||||||
|
|
||||||
|
When to use which tool:
|
||||||
|
- Linear (issues) → list_issues, get_issue, save_issue (create/update)
|
||||||
|
- ClickUp (tasks) → clickup_search, clickup_get_task
|
||||||
|
- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue
|
||||||
|
- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread
|
||||||
|
- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table
|
||||||
|
- Knowledge base content (Notion, GitHub, files, notes) → automatically searched
|
||||||
|
- Real-time public web data → call web_search
|
||||||
|
- Reading a specific webpage → call scrape_webpage
|
||||||
|
</tool_routing>
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
<tool_routing>
|
||||||
|
CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable.
|
||||||
|
Their data is NEVER in the knowledge base. You MUST call their tools immediately — never
|
||||||
|
say "I don't see it in the knowledge base" or ask if they want you to check.
|
||||||
|
Ignore any knowledge base results for these services.
|
||||||
|
|
||||||
|
When to use which tool:
|
||||||
|
- Linear (issues) → list_issues, get_issue, save_issue (create/update)
|
||||||
|
- ClickUp (tasks) → clickup_search, clickup_get_task
|
||||||
|
- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue
|
||||||
|
- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread
|
||||||
|
- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table
|
||||||
|
- Knowledge base content (Notion, GitHub, files, notes) → automatically searched
|
||||||
|
- Real-time public web data → call web_search
|
||||||
|
- Reading a specific webpage → call scrape_webpage
|
||||||
|
</tool_routing>
|
||||||
405
surfsense_backend/app/agents/new_chat/prompts/composer.py
Normal file
405
surfsense_backend/app/agents/new_chat/prompts/composer.py
Normal file
|
|
@ -0,0 +1,405 @@
|
||||||
|
"""
|
||||||
|
Prompt composer for the SurfSense ``new_chat`` agent.
|
||||||
|
|
||||||
|
This module assembles the agent's system prompt from the markdown fragments
|
||||||
|
under :mod:`app.agents.new_chat.prompts`. It replaces the monolithic
|
||||||
|
``system_prompt.py`` with a clean, fragment-based composition:
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
prompts/
|
||||||
|
base/ # agent identity, KB policy, tool routing, …
|
||||||
|
providers/ # provider-specific tweaks (anthropic, gpt5, …)
|
||||||
|
tools/ # one ``<name>.md`` per tool
|
||||||
|
examples/ # one ``<name>.md`` per tool with call examples
|
||||||
|
routing/ # connector-specific routing notes (linear, slack, …)
|
||||||
|
|
||||||
|
The model-family dispatch step (see :func:`detect_provider_variant`)
|
||||||
|
mirrors OpenCode's ``packages/opencode/src/session/system.ts`` — different
|
||||||
|
model families respond best to differently-styled prompts (Claude likes
|
||||||
|
XML/narrative, GPT-5 wants channel-aware pragmatic, Codex needs
|
||||||
|
terse/file:line, Gemini wants formal numbered steps, etc.). LangChain's
|
||||||
|
``dynamic_prompt`` helper supports per-call prompt swaps but ships no
|
||||||
|
out-of-the-box family classifier, so we keep our own.
|
||||||
|
|
||||||
|
Backwards compatibility
|
||||||
|
=======================
|
||||||
|
|
||||||
|
``system_prompt.py`` re-exports :func:`compose_system_prompt` and wraps it
|
||||||
|
in functions with the same signatures as the legacy
|
||||||
|
``build_surfsense_system_prompt`` / ``build_configurable_system_prompt`` so
|
||||||
|
existing call sites do not change.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from importlib import resources
|
||||||
|
|
||||||
|
from app.db import ChatVisibility
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Provider variant detection
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# String literal alias for the supported provider-specific prompt variants.
|
||||||
|
# When adding a new variant, also drop a matching ``providers/<variant>.md``
|
||||||
|
# file in this package and (if appropriate) extend the regex matchers below.
|
||||||
|
#
|
||||||
|
# Stylistic clusters: each variant is a focused style nudge, NOT a full
|
||||||
|
# system prompt — the main prompt is already assembled from base/ +
|
||||||
|
# tools/ + routing/. The clustering itself (which models map to which
|
||||||
|
# style) follows OpenCode's ``system.ts`` family table; see the module
|
||||||
|
# docstring for credits.
|
||||||
|
ProviderVariant = str
|
||||||
|
# Known values:
|
||||||
|
# "anthropic" — Claude family (XML-friendly, narrative todos)
|
||||||
|
# "openai_reasoning" — GPT-5 / o-series (channel-aware pragmatic)
|
||||||
|
# "openai_classic" — GPT-4 family (autonomous persistence)
|
||||||
|
# "openai_codex" — gpt-*-codex (code-purist, terse, file:line refs)
|
||||||
|
# "google" — Gemini (formal, <3-line, numbered workflow)
|
||||||
|
# "kimi" — Moonshot Kimi-K* (action-bias, parallel tools)
|
||||||
|
# "grok" — xAI Grok (extreme-terse, one-word ok)
|
||||||
|
# "deepseek" — DeepSeek V3 / R1 (terse, R1-aware reasoning)
|
||||||
|
# "default" — fallback, no provider-specific block emitted
|
||||||
|
|
||||||
|
# IMPORTANT: order of evaluation matters in :func:`detect_provider_variant`.
|
||||||
|
# More specific patterns must come first (e.g. ``codex`` before
|
||||||
|
# ``openai_reasoning`` because codex model ids contain ``gpt``).
|
||||||
|
|
||||||
|
_OPENAI_CODEX_RE = re.compile(
|
||||||
|
r"\b(gpt-codex|codex-mini|gpt-[\d.]+-codex)\b", re.IGNORECASE
|
||||||
|
)
|
||||||
|
_OPENAI_REASONING_RE = re.compile(r"\b(gpt-5|o\d|o-)", re.IGNORECASE)
|
||||||
|
_OPENAI_CLASSIC_RE = re.compile(r"\bgpt-4", re.IGNORECASE)
|
||||||
|
_ANTHROPIC_RE = re.compile(r"\bclaude\b", re.IGNORECASE)
|
||||||
|
_GOOGLE_RE = re.compile(r"\bgemini\b", re.IGNORECASE)
|
||||||
|
_KIMI_RE = re.compile(r"\b(kimi[-\d.]*|moonshot)\b", re.IGNORECASE)
|
||||||
|
_GROK_RE = re.compile(r"\bgrok\b", re.IGNORECASE)
|
||||||
|
_DEEPSEEK_RE = re.compile(r"\bdeepseek\b", re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_provider_variant(model_name: str | None) -> ProviderVariant:
    """Map a model id string to the provider-specific prompt variant.

    The match is heuristic: each regex is searched against the stripped
    model id and the first hit wins. ``"default"`` is returned for an
    empty/``None`` id or when no pattern matches, so the composer can fall
    back to the empty placeholder file.

    The table order is significant: more-specific patterns come first so
    ``gpt-5-codex`` resolves to ``"openai_codex"`` rather than
    ``"openai_reasoning"`` — same dispatch order as OpenCode's
    ``packages/opencode/src/session/system.ts``.
    """
    if not model_name:
        return "default"

    candidate = model_name.strip()
    # (matcher, variant) pairs, evaluated top to bottom.
    ordered_matchers = (
        (_OPENAI_CODEX_RE, "openai_codex"),
        (_OPENAI_REASONING_RE, "openai_reasoning"),
        (_OPENAI_CLASSIC_RE, "openai_classic"),
        (_ANTHROPIC_RE, "anthropic"),
        (_GOOGLE_RE, "google"),
        (_KIMI_RE, "kimi"),
        (_GROK_RE, "grok"),
        (_DEEPSEEK_RE, "deepseek"),
    )
    for matcher, variant in ordered_matchers:
        if matcher.search(candidate):
            return variant
    return "default"
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Fragment loading
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
# Dotted name of the package whose resource tree ``_read_fragment`` reads from.
_PROMPTS_PACKAGE = "app.agents.new_chat.prompts"
|
||||||
|
|
||||||
|
|
||||||
|
def _read_fragment(subpath: str) -> str:
    """Load one fragment file from the ``prompts/`` resource tree.

    Args:
        subpath: ``/``-separated path of the fragment, relative to the
            prompts package.

    Returns:
        The file contents decoded as UTF-8 with at most one trailing
        newline removed, so callers can add their own separators without
        stacking blank lines. An empty string is returned when the path is
        not a file or when the lookup raises ``FileNotFoundError`` /
        ``ModuleNotFoundError``, which makes optional fragments (e.g.
        provider hints) no-ops.
    """
    try:
        resource = resources.files(_PROMPTS_PACKAGE).joinpath(*subpath.split("/"))
        content = resource.read_text(encoding="utf-8") if resource.is_file() else None
    except (FileNotFoundError, ModuleNotFoundError):
        return ""

    if content is None:
        return ""
    # Drop a single trailing newline only; any further ones are preserved.
    return content.removesuffix("\n")
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Tool ordering + memory variant resolution
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
# Prompt order of the built-in tools, chosen for reading flow: search/lookup
# first, then the artifact generators, then page scraping, with memory last
# (mirrors the legacy ``_ALL_TOOL_NAMES_ORDERED``).
ALL_TOOL_NAMES_ORDERED: tuple[str, ...] = (
    # search / lookup
    "search_surfsense_docs",
    "web_search",
    # artifact generators
    "generate_podcast",
    "generate_video_presentation",
    "generate_report",
    "generate_resume",
    "generate_image",
    # page scraping
    "scrape_webpage",
    # memory
    "update_memory",
)
|
||||||
|
|
||||||
|
|
||||||
|
_MEMORY_VARIANT_TOOLS: frozenset[str] = frozenset({"update_memory"})
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_fragment_path(tool_name: str, variant: str) -> str:
|
||||||
|
"""Resolve a tool's instruction fragment path.
|
||||||
|
|
||||||
|
Tools listed in :data:`_MEMORY_VARIANT_TOOLS` switch on the conversation
|
||||||
|
visibility and load ``tools/<name>_<variant>.md``; everything else
|
||||||
|
falls back to ``tools/<name>.md``.
|
||||||
|
"""
|
||||||
|
if tool_name in _MEMORY_VARIANT_TOOLS:
|
||||||
|
return f"tools/{tool_name}_{variant}.md"
|
||||||
|
return f"tools/{tool_name}.md"
|
||||||
|
|
||||||
|
|
||||||
|
def _example_fragment_path(tool_name: str, variant: str) -> str:
|
||||||
|
if tool_name in _MEMORY_VARIANT_TOOLS:
|
||||||
|
return f"examples/{tool_name}_{variant}.md"
|
||||||
|
return f"examples/{tool_name}.md"
|
||||||
|
|
||||||
|
|
||||||
|
def _format_tool_label(tool_name: str) -> str:
|
||||||
|
return tool_name.replace("_", " ").title()
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Section builders
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _build_system_instructions(
    *,
    visibility: ChatVisibility,
    resolved_today: str,
) -> str:
    """Assemble the legacy ``<system_instruction>`` block from base fragments.

    Args:
        visibility: ``ChatVisibility.SEARCH_SPACE`` selects the ``team``
            fragment variants; anything else selects ``private``.
        resolved_today: Value substituted for ``{resolved_today}`` in the
            assembled block via ``str.format``.

    Returns:
        The non-empty fragments joined by blank lines and wrapped in
        ``<system_instruction>`` tags.
    """
    suffix = "private"
    if visibility == ChatVisibility.SEARCH_SPACE:
        suffix = "team"

    # Fragment order defines the order of sections in the final block.
    fragment_paths = (
        f"base/agent_{suffix}.md",
        f"base/kb_only_policy_{suffix}.md",
        f"base/tool_routing_{suffix}.md",
        "base/parameter_resolution.md",
        f"base/memory_protocol_{suffix}.md",
    )
    loaded = [_read_fragment(path) for path in fragment_paths]
    body = "\n\n".join(text for text in loaded if text)

    template = "\n<system_instruction>\n" + body + "\n\n</system_instruction>\n"
    # Formatting runs over the whole block, fragment text included.
    return template.format(resolved_today=resolved_today)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_mcp_routing_block(
|
||||||
|
mcp_connector_tools: dict[str, list[str]] | None,
|
||||||
|
) -> str:
|
||||||
|
"""Emit the ``<mcp_tool_routing>`` block when at least one MCP server is wired."""
|
||||||
|
if not mcp_connector_tools:
|
||||||
|
return ""
|
||||||
|
lines: list[str] = [
|
||||||
|
"\n<mcp_tool_routing>",
|
||||||
|
"You also have direct tools from these user-connected MCP servers.",
|
||||||
|
"Their data is NEVER in the knowledge base — call their tools directly.",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
for server_name, tool_names in mcp_connector_tools.items():
|
||||||
|
lines.append(f"- {server_name} → {', '.join(tool_names)}")
|
||||||
|
lines.append("</mcp_tool_routing>\n")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_tools_section(
    *,
    visibility: ChatVisibility,
    enabled_tool_names: set[str] | None,
    disabled_tool_names: set[str] | None,
) -> str:
    """Assemble the tools section and, if any, the ``<tool_call_examples>`` block.

    Args:
        visibility: ``ChatVisibility.SEARCH_SPACE`` selects the ``team``
            fragment variant; anything else selects ``private``.
        enabled_tool_names: When not ``None``, only tools in this set get
            their instruction and example fragments included.
        disabled_tool_names: Tools the user disabled; those that are also in
            :data:`ALL_TOOL_NAMES_ORDERED` are listed in a notice.

    Returns:
        The concatenation of: the preamble fragment, per-tool instruction
        fragments, the disabled-tools notice, the closing ``</tools>`` tag,
        and the examples block when at least one example fragment exists.
    """
    variant = "private"
    if visibility == ChatVisibility.SEARCH_SPACE:
        variant = "team"

    pieces: list[str] = []
    example_chunks: list[str] = []

    preamble = _read_fragment("tools/_preamble.md")
    if preamble:
        pieces.append(preamble + "\n")

    for tool_name in ALL_TOOL_NAMES_ORDERED:
        # ``None`` means "no filter"; an explicit set restricts the output.
        if enabled_tool_names is not None and tool_name not in enabled_tool_names:
            continue

        instruction = _read_fragment(_tool_fragment_path(tool_name, variant))
        if instruction:
            pieces.append(instruction + "\n")

        example = _read_fragment(_example_fragment_path(tool_name, variant))
        if example:
            example_chunks.append(example + "\n")

    # Only known tools are reported, in canonical order.
    disabled = set(disabled_tool_names or ())
    disabled_labels = [
        _format_tool_label(name) for name in ALL_TOOL_NAMES_ORDERED if name in disabled
    ]
    if disabled_labels:
        disabled_list = ", ".join(disabled_labels)
        pieces.append(
            "\n"
            "DISABLED TOOLS (by user):\n"
            f"The following tools are available in SurfSense but have been disabled by the user for this session: {disabled_list}.\n"
            "You do NOT have access to these tools and MUST NOT claim you can use them.\n"
            "If the user asks about a capability provided by a disabled tool, let them know the relevant tool\n"
            "is currently disabled and they can re-enable it.\n"
        )

    pieces.append("\n</tools>\n")

    if example_chunks:
        pieces += ["<tool_call_examples>", *example_chunks, "</tool_call_examples>\n"]

    return "".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_provider_block(provider_variant: ProviderVariant) -> str:
    """Return the provider-specific hint fragment wrapped in newlines.

    An empty string is returned for a falsy variant, for ``"default"``, and
    when ``providers/<variant>.md`` yields no text.
    """
    if provider_variant == "default" or not provider_variant:
        return ""

    hints = _read_fragment(f"providers/{provider_variant}.md")
    if not hints:
        return ""
    return "\n" + hints + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_routing_block(connector_routing: Iterable[str] | None) -> str:
    """Concatenate the ``routing/<name>.md`` fragments for the given names.

    Names whose fragment is empty are skipped. The surviving fragments are
    separated by blank lines and wrapped in a leading and trailing newline;
    ``""`` is returned when nothing is left (or no names were supplied).
    """
    if not connector_routing:
        return ""

    found = [
        text
        for name in connector_routing
        if (text := _read_fragment(f"routing/{name}.md"))
    ]
    if not found:
        return ""
    return "\n" + "\n\n".join(found) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_citation_block(citations_enabled: bool) -> str:
    """Return the citations fragment (on/off flavour) wrapped in newlines.

    Reads ``base/citations_on.md`` when ``citations_enabled`` is true and
    ``base/citations_off.md`` otherwise; an empty fragment yields ``""``.
    """
    path = "base/citations_on.md" if citations_enabled else "base/citations_off.md"
    fragment = _read_fragment(path)
    if not fragment:
        return ""
    return "\n" + fragment + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _render_custom_instructions(template: str, resolved_today: str) -> str:
    """Substitute ``{resolved_today}`` into user-authored instructions.

    ``str.format`` is tried first so templates that worked before render
    exactly as before (including ``{{`` / ``}}`` escapes). Free-form text
    often contains stray braces (JSON samples, code snippets) that make
    ``str.format`` raise; in that case only the literal ``{resolved_today}``
    placeholder is replaced and the rest of the text is left untouched.
    """
    try:
        return template.format(resolved_today=resolved_today)
    except (KeyError, IndexError, ValueError, AttributeError, TypeError):
        return template.replace("{resolved_today}", resolved_today)


def compose_system_prompt(
    *,
    today: datetime | None = None,
    thread_visibility: ChatVisibility | None = None,
    enabled_tool_names: set[str] | None = None,
    disabled_tool_names: set[str] | None = None,
    mcp_connector_tools: dict[str, list[str]] | None = None,
    custom_system_instructions: str | None = None,
    use_default_system_instructions: bool = True,
    citations_enabled: bool = True,
    provider_variant: ProviderVariant | None = None,
    model_name: str | None = None,
    connector_routing: Iterable[str] | None = None,
) -> str:
    """Assemble the SurfSense system prompt from disk fragments.

    Args:
        today: Optional clock injection for tests. Defaults to the current
            time; the value is converted to UTC and only its ISO date is used.
        thread_visibility: Private vs shared (team) — drives memory wording
            and a few base block variants. ``None`` means private.
        enabled_tool_names: When provided, only these tools' instructions
            are included; ``None`` keeps the legacy "include everything"
            behavior.
        disabled_tool_names: User-disabled tools (note appended to prompt).
        mcp_connector_tools: ``{server_name: [tool_names...]}`` to inject
            an explicit MCP routing block.
        custom_system_instructions: Free-form instructions that override
            the default ``<system_instruction>`` block (legacy support
            for ``NewLLMConfig.system_instructions``). A ``{resolved_today}``
            placeholder is replaced with today's date; text containing
            other braces no longer raises.
        use_default_system_instructions: When ``custom_system_instructions``
            is empty/None, fall back to defaults (legacy semantics). When
            false and no custom text is given, the block is empty.
        citations_enabled: Include ``citations_on.md`` (true) or
            ``citations_off.md`` (false).
        provider_variant: Explicit provider variant override
            (``"anthropic" | "openai_reasoning" | "openai_classic" |
            "openai_codex" | "google" | "kimi" | "grok" | "deepseek" |
            "default"``). When ``None``, falls back to
            :func:`detect_provider_variant` on ``model_name``.
        model_name: Used to auto-detect ``provider_variant`` when not
            provided explicitly.
        connector_routing: Optional list of routing fragment names
            (``["linear", "slack", ...]``) to include from
            ``prompts/routing/``.

    Returns:
        The fully composed system prompt string: system block (plus MCP,
        provider and routing additions), then tools, then citations.
    """
    resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()
    visibility = thread_visibility or ChatVisibility.PRIVATE

    if custom_system_instructions and custom_system_instructions.strip():
        # User-authored text: must not crash on braces that are not placeholders.
        sys_block = _render_custom_instructions(
            custom_system_instructions, resolved_today
        )
    elif use_default_system_instructions:
        sys_block = _build_system_instructions(
            visibility=visibility, resolved_today=resolved_today
        )
    else:
        sys_block = ""

    sys_block += _build_mcp_routing_block(mcp_connector_tools)

    # An explicit variant wins; otherwise infer it from the model id.
    if provider_variant is None:
        provider_variant = detect_provider_variant(model_name)
    sys_block += _build_provider_block(provider_variant)
    sys_block += _build_routing_block(connector_routing)

    tools_block = _build_tools_section(
        visibility=visibility,
        enabled_tool_names=enabled_tool_names,
        disabled_tool_names=disabled_tool_names,
    )
    citation_block = _build_citation_block(citations_enabled)

    return sys_block + tools_block + citation_block
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ALL_TOOL_NAMES_ORDERED",
|
||||||
|
"ProviderVariant",
|
||||||
|
"compose_system_prompt",
|
||||||
|
"detect_provider_variant",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
|
||||||
|
- User: "Generate an image of a cat"
|
||||||
|
- Call: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")`
|
||||||
|
- The generated image will automatically be displayed in the chat.
|
||||||
|
- User: "Draw me a logo for a coffee shop called Bean Dream"
|
||||||
|
- Call: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")`
|
||||||
|
- The generated image will automatically be displayed in the chat.
|
||||||
|
- User: "Show me this image: https://example.com/image.png"
|
||||||
|
- Simply include it in your response using markdown: `![Image](https://example.com/image.png)`
|
||||||
|
- User uploads an image file and asks: "What is this image about?"
|
||||||
|
- The user's uploaded image is already visible in the chat.
|
||||||
|
- Simply analyze the image content and respond directly.
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
|
||||||
|
- User: "Give me a podcast about AI trends based on what we discussed"
|
||||||
|
- First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")`
|
||||||
|
- User: "Create a podcast summary of this conversation"
|
||||||
|
- Call: `generate_podcast(source_content="Complete conversation summary:\n\nUser asked about [topic 1]:\n[Your detailed response]\n\nUser then asked about [topic 2]:\n[Your detailed response]\n\n[Continue for all exchanges in the conversation]", podcast_title="Conversation Summary")`
|
||||||
|
- User: "Make a podcast about quantum computing"
|
||||||
|
- First explore `/documents/` (ls/glob/grep/read_file), then: `generate_podcast(source_content="Key insights about quantum computing from retrieved files:\n\n[Comprehensive summary of findings]", podcast_title="Quantum Computing Explained")`
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
|
||||||
|
- User: "Generate a report about AI trends"
|
||||||
|
- Call: `generate_report(topic="AI Trends Report", source_strategy="kb_search", search_queries=["AI trends recent developments", "artificial intelligence industry trends", "AI market growth and predictions"], report_style="detailed")`
|
||||||
|
- WHY: Has creation verb "generate" → call the tool. No prior discussion → use kb_search.
|
||||||
|
- User: "Write a research report from this conversation"
|
||||||
|
- Call: `generate_report(topic="Research Report", source_strategy="conversation", source_content="Complete conversation summary:\n\n...", report_style="deep_research")`
|
||||||
|
- WHY: Has creation verb "write" → call the tool. Conversation has the content → use source_strategy="conversation".
|
||||||
|
- User: (after a report on Climate Change was generated) "Add a section about carbon capture technologies"
|
||||||
|
- Call: `generate_report(topic="Climate Crisis: Causes, Impacts, and Solutions", source_strategy="conversation", source_content="[summary of conversation context if any]", parent_report_id=<previous_report_id>, user_instructions="Add a new section about carbon capture technologies")`
|
||||||
|
- WHY: Has modification verb "add" + specific deliverable target → call the tool with parent_report_id.
|
||||||
|
- User: (after a report was generated) "What else could we add to have more depth?"
|
||||||
|
- Do NOT call generate_report. Answer in chat with suggestions.
|
||||||
|
- WHY: No creation/modification verb directed at producing a deliverable.
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
|
||||||
|
- User: "Build me a resume. I'm John Doe, engineer at Acme Corp..."
|
||||||
|
- Call: `generate_resume(user_info="John Doe, engineer at Acme Corp...", max_pages=1)`
|
||||||
|
- WHY: Has creation verb "build" + resume → call the tool.
|
||||||
|
- User: "Create my CV with this info: [experience, education, skills]"
|
||||||
|
- Call: `generate_resume(user_info="[experience, education, skills]", max_pages=1)`
|
||||||
|
- User: "Build me a resume" (and there is a resume/CV document in the conversation context)
|
||||||
|
- Extract the FULL content from the document in context, then call:
|
||||||
|
`generate_resume(user_info="Name: John Doe\nEmail: john@example.com\n\nExperience:\n- Senior Engineer at Acme Corp (2020-2024)\n Led team of 5...\n\nEducation:\n- BS Computer Science, MIT (2016-2020)\n\nSkills: Python, TypeScript, AWS...", max_pages=1)`
|
||||||
|
- WHY: Document content is available in context — extract ALL of it into user_info. Do NOT ignore referenced documents.
|
||||||
|
- User: (after resume generated) "Change my title to Senior Engineer"
|
||||||
|
- Call: `generate_resume(user_info="", user_instructions="Change the job title to Senior Engineer", parent_report_id=<previous_report_id>, max_pages=1)`
|
||||||
|
- WHY: Modification verb "change" + refers to existing resume → set parent_report_id.
|
||||||
|
- User: (after resume generated) "Make this 2 pages and expand projects"
|
||||||
|
- Call: `generate_resume(user_info="", user_instructions="Expand projects and keep this to at most 2 pages", parent_report_id=<previous_report_id>, max_pages=2)`
|
||||||
|
- WHY: Explicit page increase request → set max_pages to 2.
|
||||||
|
- User: "How should I structure my resume?"
|
||||||
|
- Do NOT call generate_resume. Answer in chat with advice.
|
||||||
|
- WHY: No creation/modification verb.
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
|
||||||
|
- User: "Give me a presentation about AI trends based on what we discussed"
|
||||||
|
- First search for relevant content, then call: `generate_video_presentation(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", video_title="AI Trends Presentation")`
|
||||||
|
- User: "Create slides summarizing this conversation"
|
||||||
|
- Call: `generate_video_presentation(source_content="Complete conversation summary:\n\nUser asked about [topic 1]:\n[Your detailed response]\n\nUser then asked about [topic 2]:\n[Your detailed response]\n\n[Continue for all exchanges in the conversation]", video_title="Conversation Summary")`
|
||||||
|
- User: "Make a video presentation about quantum computing"
|
||||||
|
- First explore `/documents/` (ls/glob/grep/read_file), then: `generate_video_presentation(source_content="Key insights about quantum computing from retrieved files:\n\n[Comprehensive summary of findings]", video_title="Quantum Computing Explained")`
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
|
||||||
|
- User: "Check out https://dev.to/some-article"
|
||||||
|
- Call: `scrape_webpage(url="https://dev.to/some-article")`
|
||||||
|
- Respond with a structured analysis — key points, takeaways.
|
||||||
|
- User: "Read this article and summarize it for me: https://example.com/blog/ai-trends"
|
||||||
|
- Call: `scrape_webpage(url="https://example.com/blog/ai-trends")`
|
||||||
|
- Respond with a thorough summary using headings and bullet points.
|
||||||
|
- User: (after discussing https://example.com/stats) "Can you get the live data from that page?"
|
||||||
|
- Call: `scrape_webpage(url="https://example.com/stats")`
|
||||||
|
- IMPORTANT: Always attempt scraping first. Never refuse before trying the tool.
|
||||||
|
- User: "https://example.com/blog/weekend-recipes"
|
||||||
|
- Call: `scrape_webpage(url="https://example.com/blog/weekend-recipes")`
|
||||||
|
- When a user sends just a URL with no instructions, scrape it and provide a concise summary of the content.
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
|
||||||
|
- User: "How do I install SurfSense?"
|
||||||
|
- Call: `search_surfsense_docs(query="installation setup")`
|
||||||
|
- User: "What connectors does SurfSense support?"
|
||||||
|
- Call: `search_surfsense_docs(query="available connectors integrations")`
|
||||||
|
- User: "How do I set up the Notion connector?"
|
||||||
|
- Call: `search_surfsense_docs(query="Notion connector setup configuration")`
|
||||||
|
- User: "How do I use Docker to run SurfSense?"
|
||||||
|
- Call: `search_surfsense_docs(query="Docker installation setup")`
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
|
||||||
|
- <user_name>Alex</user_name>, <user_memory> is empty. User: "I'm a space enthusiast, explain astrophage to me"
|
||||||
|
- The user casually shared a durable fact. Use their first name in the entry, short neutral heading:
|
||||||
|
update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n")
|
||||||
|
- User: "Remember that I prefer concise answers over detailed explanations"
|
||||||
|
- Durable preference. Merge with existing memory, add a new heading:
|
||||||
|
update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n\n## Response style\n- (2025-03-15) [pref] Alex prefers concise answers over detailed explanations\n")
|
||||||
|
- User: "I actually moved to Tokyo last month"
|
||||||
|
- Updated fact, date prefix reflects when recorded:
|
||||||
|
update_memory(updated_memory="## Interests & background\n...\n\n## Personal context\n- (2025-03-15) [fact] Alex lives in Tokyo (previously London)\n...")
|
||||||
|
- User: "I'm a freelance photographer working on a nature documentary"
|
||||||
|
- Durable background info under a fitting heading:
|
||||||
|
update_memory(updated_memory="...\n\n## Current focus\n- (2025-03-15) [fact] Alex is a freelance photographer\n- (2025-03-15) [fact] Alex is working on a nature documentary\n")
|
||||||
|
- User: "Always respond in bullet points"
|
||||||
|
- Standing instruction:
|
||||||
|
update_memory(updated_memory="...\n\n## Response style\n- (2025-03-15) [instr] Always respond to Alex in bullet points\n")
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
|
||||||
|
- User: "Let's remember that we decided to do weekly standup meetings on Mondays"
|
||||||
|
- Durable team decision:
|
||||||
|
update_memory(updated_memory="- (2025-03-15) [fact] Weekly standup meetings on Mondays\n...")
|
||||||
|
- User: "Our office is in downtown Seattle, 5th floor"
|
||||||
|
- Durable team fact:
|
||||||
|
update_memory(updated_memory="- (2025-03-15) [fact] Office location: downtown Seattle, 5th floor\n...")
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
|
||||||
|
- User: "What's the current USD to INR exchange rate?"
|
||||||
|
- Call: `web_search(query="current USD to INR exchange rate")`
|
||||||
|
- Then answer using the returned web results with citations.
|
||||||
|
- User: "What's the latest news about AI?"
|
||||||
|
- Call: `web_search(query="latest AI news today")`
|
||||||
|
- User: "What's the weather in New York?"
|
||||||
|
- Call: `web_search(query="weather New York today")`
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
|
||||||
|
|
@ -0,0 +1,20 @@
|
||||||
|
<provider_hints>
|
||||||
|
You are running on an Anthropic Claude model.
|
||||||
|
|
||||||
|
Structured reasoning:
|
||||||
|
- Use XML tags liberally to organise intermediate reasoning when a task is non-trivial. `<thinking>...</thinking>` blocks are encouraged before tool calls or before producing a complex final answer.
|
||||||
|
- For multi-step requests, briefly outline a plan inside a `<plan>` block before issuing the first tool call.
|
||||||
|
|
||||||
|
Professional objectivity:
|
||||||
|
- Prioritise technical accuracy over validating the user's beliefs. Provide direct, factual guidance without unnecessary superlatives, praise, or emotional validation.
|
||||||
|
- When uncertain, investigate (search the KB, fetch the page) rather than confirming the user's assumption.
|
||||||
|
- Disagree with the user when the evidence warrants it; respectful correction beats false agreement.
|
||||||
|
|
||||||
|
Task management:
|
||||||
|
- For tasks with 3+ distinct steps use the todo / planning tool aggressively. Mark items in_progress before starting, completed immediately when finished — do not batch completions.
|
||||||
|
- Narrate progress through the todo list itself, not through chatty status lines.
|
||||||
|
|
||||||
|
Tool calls:
|
||||||
|
- Run independent tool calls in parallel within one response. Sequence them only when a later call genuinely needs an earlier one's output.
|
||||||
|
- Never chain bash-like commands with `;` or `&&` to "narrate" — use prose between tool calls instead.
|
||||||
|
</provider_hints>
|
||||||
|
|
@ -0,0 +1,18 @@
|
||||||
|
<provider_hints>
|
||||||
|
You are running on a DeepSeek model (DeepSeek-V3 chat / DeepSeek-R1 reasoning).
|
||||||
|
|
||||||
|
Reasoning hygiene (R1-aware):
|
||||||
|
- If the model surfaces explicit `<think>` blocks, keep that internal scratch focused — do NOT restate the user's question inside it; jump straight to the analysis.
|
||||||
|
- Never paste the contents of `<think>` into your final answer. Final answer should reflect only the conclusion, citations, and any user-facing rationale.
|
||||||
|
- Do not let chain-of-thought leak into tool-call arguments — keep tool inputs minimal and structural.
|
||||||
|
|
||||||
|
Output style:
|
||||||
|
- Be concise. Default to a one-paragraph answer; expand only when the user asks for detail.
|
||||||
|
- Don't open with sycophantic phrasing ("Great question", "Sure, here you go"). Lead with the answer or the next action.
|
||||||
|
- For factual answers, cite once with `[citation:chunk_id]` and stop.
|
||||||
|
|
||||||
|
Tool calls:
|
||||||
|
- Issue independent tool calls in parallel within a single turn.
|
||||||
|
- Prefer the knowledge-base search tools before any web-search; this model has strong recall but stale training data.
|
||||||
|
- Don't fabricate file paths, chunk ids, or URLs — only use values returned by tools or provided by the user.
|
||||||
|
</provider_hints>
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
|
||||||
|
|
@ -0,0 +1,20 @@
|
||||||
|
<provider_hints>
|
||||||
|
You are running on a Google Gemini model.
|
||||||
|
|
||||||
|
Output style:
|
||||||
|
- Concise & direct. Aim for fewer than 3 lines of prose (excluding tool output, citations, and code/snippets) when the task allows.
|
||||||
|
- No conversational filler — skip openers like "Okay, I will now…" and closers like "I have finished the changes…". Get straight to the action or answer.
|
||||||
|
- Format with GitHub-flavoured Markdown; assume monospace rendering.
|
||||||
|
- For one-line factual answers, just answer. No headers, no bullets.
|
||||||
|
|
||||||
|
Workflow for non-trivial tasks (Understand → Plan → Act → Verify):
|
||||||
|
1. **Understand:** read the user's request and the relevant KB / connector context. Use search and read tools (in parallel when independent) before assuming anything.
|
||||||
|
2. **Plan:** when the task touches multiple steps, share an extremely concise plan first.
|
||||||
|
3. **Act:** call the appropriate tools, strictly adhering to the prompts/routing already established for this agent.
|
||||||
|
4. **Verify:** confirm with a follow-up read or search where it materially de-risks the answer.
|
||||||
|
|
||||||
|
Discipline:
|
||||||
|
- Do not take significant actions beyond the clear scope of the user's request without confirming first.
|
||||||
|
- Do not assume a connector / tool / file exists — check (e.g. via `get_connected_accounts`) before referencing it.
|
||||||
|
- Path arguments must be the exact strings returned by tools; do not synthesise file paths.
|
||||||
|
</provider_hints>
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
<provider_hints>
|
||||||
|
You are running on an xAI Grok model.
|
||||||
|
|
||||||
|
Maximum terseness:
|
||||||
|
- Answer in fewer than 4 lines unless the user asks for detail. One-word answers are best when they suffice.
|
||||||
|
- No preamble ("The answer is", "Here's what I'll do"), no postamble ("Hope that helps", "Let me know"). Get straight to the answer.
|
||||||
|
- Avoid restating the user's question.
|
||||||
|
- For factual lookups inside the knowledge base, give the answer with a single `[citation:chunk_id]` and stop.
|
||||||
|
|
||||||
|
Tool discipline:
|
||||||
|
- Use exactly ONE tool per assistant turn when investigating; wait for the result before deciding the next call. Do not loop on the same tool with the same arguments — pick a result and act.
|
||||||
|
- For obviously parallelizable read-only batches (multiple independent searches), one turn with several tool calls is fine — but never chain into a fishing expedition.
|
||||||
|
|
||||||
|
Style:
|
||||||
|
- No emojis unless the user asked. No nested bullets, no headers for short answers.
|
||||||
|
- If you can't help, say so in 1-2 sentences without explaining "why this could lead to…".
|
||||||
|
</provider_hints>
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
<provider_hints>
|
||||||
|
You are running on a Moonshot Kimi model (Kimi-K1.5 / Kimi-K2 / Kimi-K2.5+).
|
||||||
|
|
||||||
|
Action bias:
|
||||||
|
- Default to taking action with tools rather than describing solutions in prose. If a tool can answer the question, call the tool.
|
||||||
|
- Don't narrate routine reads, searches, or obvious next steps. Combine related progress into one short status line.
|
||||||
|
- Be thorough in actions (test what you build, verify what you change). Be brief in explanations.
|
||||||
|
|
||||||
|
Tool calls:
|
||||||
|
- Output multiple non-interfering tool calls in a SINGLE response — parallelism is a major efficiency win on this model.
|
||||||
|
- When the `task` tool is available, delegate focused subtasks to a subagent with full context (subagents don't inherit yours).
|
||||||
|
- Don't apologise or pre-announce tool calls. The tool call itself is self-explanatory.
|
||||||
|
|
||||||
|
Language:
|
||||||
|
- Respond in the SAME language as the user's most recent turn unless explicitly instructed otherwise.
|
||||||
|
|
||||||
|
Discipline:
|
||||||
|
- Stay on track. Never give the user more than what they asked for.
|
||||||
|
- Fact-check before stating anything as factual; don't fabricate citations.
|
||||||
|
- Keep it stupidly simple. Don't overcomplicate.
|
||||||
|
</provider_hints>
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue