From 2e026549bebf31532bb18db6d1ccdde86acf0618 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 18 Apr 2026 16:22:20 +0530 Subject: [PATCH 001/299] chore: update alembic migration numbers --- ..._content_type.py => 127_add_report_content_type.py} | 10 +++++----- ...esume_prompt.py => 128_seed_build_resume_prompt.py} | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) rename surfsense_backend/alembic/versions/{126_add_report_content_type.py => 127_add_report_content_type.py} (87%) rename surfsense_backend/alembic/versions/{127_seed_build_resume_prompt.py => 128_seed_build_resume_prompt.py} (89%) diff --git a/surfsense_backend/alembic/versions/126_add_report_content_type.py b/surfsense_backend/alembic/versions/127_add_report_content_type.py similarity index 87% rename from surfsense_backend/alembic/versions/126_add_report_content_type.py rename to surfsense_backend/alembic/versions/127_add_report_content_type.py index 3d9e4860c..93bf471af 100644 --- a/surfsense_backend/alembic/versions/126_add_report_content_type.py +++ b/surfsense_backend/alembic/versions/127_add_report_content_type.py @@ -1,7 +1,7 @@ -"""126_add_report_content_type +"""127_add_report_content_type -Revision ID: 126 -Revises: 125 +Revision ID: 127 +Revises: 126 Create Date: 2026-04-15 Adds content_type column to reports table to distinguish between @@ -16,8 +16,8 @@ import sqlalchemy as sa from alembic import op -revision: str = "126" -down_revision: str | None = "125" +revision: str = "127" +down_revision: str | None = "126" branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None diff --git a/surfsense_backend/alembic/versions/127_seed_build_resume_prompt.py b/surfsense_backend/alembic/versions/128_seed_build_resume_prompt.py similarity index 89% rename from surfsense_backend/alembic/versions/127_seed_build_resume_prompt.py rename to surfsense_backend/alembic/versions/128_seed_build_resume_prompt.py index 9e05a0510..886879a7b 100644 --- a/surfsense_backend/alembic/versions/127_seed_build_resume_prompt.py +++ b/surfsense_backend/alembic/versions/128_seed_build_resume_prompt.py @@ -1,7 +1,7 @@ -"""127_seed_build_resume_prompt +"""128_seed_build_resume_prompt -Revision ID: 127 -Revises: 126 +Revision ID: 128 +Revises: 127 Create Date: 2026-04-15 Seeds the 'Build Resume' default prompt for all existing users. @@ -16,8 +16,8 @@ import sqlalchemy as sa from alembic import op -revision: str = "127" -down_revision: str | None = "126" +revision: str = "128" +down_revision: str | None = "127" branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None From 7fbd684b44201519a027a75398e8556941878029 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sun, 19 Apr 2026 23:48:18 +0530 Subject: [PATCH 002/299] feat: initialize Obsidian sample plugin with essential files and configurations --- surfsense_obsidian/.editorconfig | 10 + surfsense_obsidian/.github/workflows/lint.yml | 28 + surfsense_obsidian/.gitignore | 22 + surfsense_obsidian/.npmrc | 1 + surfsense_obsidian/AGENTS.md | 251 + surfsense_obsidian/LICENSE | 5 + surfsense_obsidian/README.md | 90 + surfsense_obsidian/esbuild.config.mjs | 49 + surfsense_obsidian/eslint.config.mts | 34 + surfsense_obsidian/manifest.json | 11 + surfsense_obsidian/package-lock.json | 5160 +++++++++++++++++ surfsense_obsidian/package.json | 29 + surfsense_obsidian/src/main.ts | 99 + surfsense_obsidian/src/settings.ts | 36 + surfsense_obsidian/styles.css | 8 + surfsense_obsidian/tsconfig.json | 30 + surfsense_obsidian/version-bump.mjs | 17 + surfsense_obsidian/versions.json | 3 + 18 files changed, 5883 insertions(+) create mode 100644 surfsense_obsidian/.editorconfig create mode 100644 surfsense_obsidian/.github/workflows/lint.yml create mode 100644 surfsense_obsidian/.gitignore create mode 100644 surfsense_obsidian/.npmrc create mode 100644 surfsense_obsidian/AGENTS.md create mode 100644 surfsense_obsidian/LICENSE create mode 100644 surfsense_obsidian/README.md create mode 100644 surfsense_obsidian/esbuild.config.mjs create mode 100644 surfsense_obsidian/eslint.config.mts create mode 100644 surfsense_obsidian/manifest.json create mode 100644 surfsense_obsidian/package-lock.json create mode 100644 surfsense_obsidian/package.json create mode 100644 surfsense_obsidian/src/main.ts create mode 100644 surfsense_obsidian/src/settings.ts create mode 100644 surfsense_obsidian/styles.css create mode 100644 surfsense_obsidian/tsconfig.json create mode 100644 surfsense_obsidian/version-bump.mjs create mode 100644 surfsense_obsidian/versions.json diff --git a/surfsense_obsidian/.editorconfig b/surfsense_obsidian/.editorconfig new file mode 100644 index 000000000..81f3ec354 --- /dev/null +++ b/surfsense_obsidian/.editorconfig @@ -0,0 +1,10 @@ +# top-most EditorConfig file +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +indent_style = tab +indent_size = 4 +tab_width = 4 diff --git a/surfsense_obsidian/.github/workflows/lint.yml b/surfsense_obsidian/.github/workflows/lint.yml new file mode 100644 index 000000000..7748ceb77 --- /dev/null +++ b/surfsense_obsidian/.github/workflows/lint.yml @@ -0,0 +1,28 @@ +name: Node.js build + +on: + push: + branches: ["**"] + pull_request: + branches: ["**"] + +jobs: + build: + runs-on: ubuntu-latest + + strategy: + matrix: + node-version: [20.x, 22.x] + # See supported Node.js release schedule at https://nodejs.org/en/about/releases/ + + steps: + - uses: actions/checkout@v4 + - name: Use Node.js ${{ matrix.node-version }} + uses: actions/setup-node@v4 + with: + node-version: ${{ matrix.node-version }} + cache: "npm" + - run: npm ci + - run: npm run build --if-present + - run: npm run lint + diff --git a/surfsense_obsidian/.gitignore b/surfsense_obsidian/.gitignore new file mode 100644 index 000000000..386ac2bdb --- /dev/null +++ b/surfsense_obsidian/.gitignore @@ -0,0 +1,22 @@ +# vscode +.vscode + +# Intellij +*.iml +.idea + +# npm +node_modules + +# Don't include the compiled main.js file in the repo. +# They should be uploaded to GitHub releases instead. +main.js + +# Exclude sourcemaps +*.map + +# obsidian +data.json + +# Exclude macOS Finder (System Explorer) View States +.DS_Store diff --git a/surfsense_obsidian/.npmrc b/surfsense_obsidian/.npmrc new file mode 100644 index 000000000..b9737525f --- /dev/null +++ b/surfsense_obsidian/.npmrc @@ -0,0 +1 @@ +tag-version-prefix="" \ No newline at end of file diff --git a/surfsense_obsidian/AGENTS.md b/surfsense_obsidian/AGENTS.md new file mode 100644 index 000000000..3f4274ac6 --- /dev/null +++ b/surfsense_obsidian/AGENTS.md @@ -0,0 +1,251 @@ +# Obsidian community plugin + +## Project overview + +- Target: Obsidian Community Plugin (TypeScript → bundled JavaScript). +- Entry point: `main.ts` compiled to `main.js` and loaded by Obsidian. +- Required release artifacts: `main.js`, `manifest.json`, and optional `styles.css`. + +## Environment & tooling + +- Node.js: use current LTS (Node 18+ recommended). +- **Package manager: npm** (required for this sample - `package.json` defines npm scripts and dependencies). +- **Bundler: esbuild** (required for this sample - `esbuild.config.mjs` and build scripts depend on it). Alternative bundlers like Rollup or webpack are acceptable for other projects if they bundle all external dependencies into `main.js`. +- Types: `obsidian` type definitions. + +**Note**: This sample project has specific technical dependencies on npm and esbuild. If you're creating a plugin from scratch, you can choose different tools, but you'll need to replace the build configuration accordingly. + +### Install + +```bash +npm install +``` + +### Dev (watch) + +```bash +npm run dev +``` + +### Production build + +```bash +npm run build +``` + +## Linting + +- To use eslint install eslint from terminal: `npm install -g eslint` +- To use eslint to analyze this project use this command: `eslint main.ts` +- eslint will then create a report with suggestions for code improvement by file and line number. +- If your source code is in a folder, such as `src`, you can use eslint with this command to analyze all files in that folder: `eslint ./src/` + +## File & folder conventions + +- **Organize code into multiple files**: Split functionality across separate modules rather than putting everything in `main.ts`. +- Source lives in `src/`. Keep `main.ts` small and focused on plugin lifecycle (loading, unloading, registering commands). +- **Example file structure**: + ``` + src/ + main.ts # Plugin entry point, lifecycle management + settings.ts # Settings interface and defaults + commands/ # Command implementations + command1.ts + command2.ts + ui/ # UI components, modals, views + modal.ts + view.ts + utils/ # Utility functions, helpers + helpers.ts + constants.ts + types.ts # TypeScript interfaces and types + ``` +- **Do not commit build artifacts**: Never commit `node_modules/`, `main.js`, or other generated files to version control. +- Keep the plugin small. Avoid large dependencies. Prefer browser-compatible packages. +- Generated output should be placed at the plugin root or `dist/` depending on your build setup. Release artifacts must end up at the top level of the plugin folder in the vault (`main.js`, `manifest.json`, `styles.css`). + +## Manifest rules (`manifest.json`) + +- Must include (non-exhaustive): + - `id` (plugin ID; for local dev it should match the folder name) + - `name` + - `version` (Semantic Versioning `x.y.z`) + - `minAppVersion` + - `description` + - `isDesktopOnly` (boolean) + - Optional: `author`, `authorUrl`, `fundingUrl` (string or map) +- Never change `id` after release. Treat it as stable API. +- Keep `minAppVersion` accurate when using newer APIs. +- Canonical requirements are coded here: https://github.com/obsidianmd/obsidian-releases/blob/master/.github/workflows/validate-plugin-entry.yml + +## Testing + +- Manual install for testing: copy `main.js`, `manifest.json`, `styles.css` (if any) to: + ``` + /.obsidian/plugins// + ``` +- Reload Obsidian and enable the plugin in **Settings → Community plugins**. + +## Commands & settings + +- Any user-facing commands should be added via `this.addCommand(...)`. +- If the plugin has configuration, provide a settings tab and sensible defaults. +- Persist settings using `this.loadData()` / `this.saveData()`. +- Use stable command IDs; avoid renaming once released. + +## Versioning & releases + +- Bump `version` in `manifest.json` (SemVer) and update `versions.json` to map plugin version → minimum app version. +- Create a GitHub release whose tag exactly matches `manifest.json`'s `version`. Do not use a leading `v`. +- Attach `manifest.json`, `main.js`, and `styles.css` (if present) to the release as individual assets. +- After the initial release, follow the process to add/update your plugin in the community catalog as required. + +## Security, privacy, and compliance + +Follow Obsidian's **Developer Policies** and **Plugin Guidelines**. In particular: + +- Default to local/offline operation. Only make network requests when essential to the feature. +- No hidden telemetry. If you collect optional analytics or call third-party services, require explicit opt-in and document clearly in `README.md` and in settings. +- Never execute remote code, fetch and eval scripts, or auto-update plugin code outside of normal releases. +- Minimize scope: read/write only what's necessary inside the vault. Do not access files outside the vault. +- Clearly disclose any external services used, data sent, and risks. +- Respect user privacy. Do not collect vault contents, filenames, or personal information unless absolutely necessary and explicitly consented. +- Avoid deceptive patterns, ads, or spammy notifications. +- Register and clean up all DOM, app, and interval listeners using the provided `register*` helpers so the plugin unloads safely. + +## UX & copy guidelines (for UI text, commands, settings) + +- Prefer sentence case for headings, buttons, and titles. +- Use clear, action-oriented imperatives in step-by-step copy. +- Use **bold** to indicate literal UI labels. Prefer "select" for interactions. +- Use arrow notation for navigation: **Settings → Community plugins**. +- Keep in-app strings short, consistent, and free of jargon. + +## Performance + +- Keep startup light. Defer heavy work until needed. +- Avoid long-running tasks during `onload`; use lazy initialization. +- Batch disk access and avoid excessive vault scans. +- Debounce/throttle expensive operations in response to file system events. + +## Coding conventions + +- TypeScript with `"strict": true` preferred. +- **Keep `main.ts` minimal**: Focus only on plugin lifecycle (onload, onunload, addCommand calls). Delegate all feature logic to separate modules. +- **Split large files**: If any file exceeds ~200-300 lines, consider breaking it into smaller, focused modules. +- **Use clear module boundaries**: Each file should have a single, well-defined responsibility. +- Bundle everything into `main.js` (no unbundled runtime deps). +- Avoid Node/Electron APIs if you want mobile compatibility; set `isDesktopOnly` accordingly. +- Prefer `async/await` over promise chains; handle errors gracefully. + +## Mobile + +- Where feasible, test on iOS and Android. +- Don't assume desktop-only behavior unless `isDesktopOnly` is `true`. +- Avoid large in-memory structures; be mindful of memory and storage constraints. + +## Agent do/don't + +**Do** +- Add commands with stable IDs (don't rename once released). +- Provide defaults and validation in settings. +- Write idempotent code paths so reload/unload doesn't leak listeners or intervals. +- Use `this.register*` helpers for everything that needs cleanup. + +**Don't** +- Introduce network calls without an obvious user-facing reason and documentation. +- Ship features that require cloud services without clear disclosure and explicit opt-in. +- Store or transmit vault contents unless essential and consented. + +## Common tasks + +### Organize code across multiple files + +**main.ts** (minimal, lifecycle only): +```ts +import { Plugin } from "obsidian"; +import { MySettings, DEFAULT_SETTINGS } from "./settings"; +import { registerCommands } from "./commands"; + +export default class MyPlugin extends Plugin { + settings: MySettings; + + async onload() { + this.settings = Object.assign({}, DEFAULT_SETTINGS, await this.loadData()); + registerCommands(this); + } +} +``` + +**settings.ts**: +```ts +export interface MySettings { + enabled: boolean; + apiKey: string; +} + +export const DEFAULT_SETTINGS: MySettings = { + enabled: true, + apiKey: "", +}; +``` + +**commands/index.ts**: +```ts +import { Plugin } from "obsidian"; +import { doSomething } from "./my-command"; + +export function registerCommands(plugin: Plugin) { + plugin.addCommand({ + id: "do-something", + name: "Do something", + callback: () => doSomething(plugin), + }); +} +``` + +### Add a command + +```ts +this.addCommand({ + id: "your-command-id", + name: "Do the thing", + callback: () => this.doTheThing(), +}); +``` + +### Persist settings + +```ts +interface MySettings { enabled: boolean } +const DEFAULT_SETTINGS: MySettings = { enabled: true }; + +async onload() { + this.settings = Object.assign({}, DEFAULT_SETTINGS, await this.loadData()); + await this.saveData(this.settings); +} +``` + +### Register listeners safely + +```ts +this.registerEvent(this.app.workspace.on("file-open", f => { /* ... */ })); +this.registerDomEvent(window, "resize", () => { /* ... */ }); +this.registerInterval(window.setInterval(() => { /* ... */ }, 1000)); +``` + +## Troubleshooting + +- Plugin doesn't load after build: ensure `main.js` and `manifest.json` are at the top level of the plugin folder under `/.obsidian/plugins//`. +- Build issues: if `main.js` is missing, run `npm run build` or `npm run dev` to compile your TypeScript source code. +- Commands not appearing: verify `addCommand` runs after `onload` and IDs are unique. +- Settings not persisting: ensure `loadData`/`saveData` are awaited and you re-render the UI after changes. +- Mobile-only issues: confirm you're not using desktop-only APIs; check `isDesktopOnly` and adjust. + +## References + +- Obsidian sample plugin: https://github.com/obsidianmd/obsidian-sample-plugin +- API documentation: https://docs.obsidian.md +- Developer policies: https://docs.obsidian.md/Developer+policies +- Plugin guidelines: https://docs.obsidian.md/Plugins/Releasing/Plugin+guidelines +- Style guide: https://help.obsidian.md/style-guide diff --git a/surfsense_obsidian/LICENSE b/surfsense_obsidian/LICENSE new file mode 100644 index 000000000..287f37a72 --- /dev/null +++ b/surfsense_obsidian/LICENSE @@ -0,0 +1,5 @@ +Copyright (C) 2020-2025 by Dynalist Inc. + +Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. \ No newline at end of file diff --git a/surfsense_obsidian/README.md b/surfsense_obsidian/README.md new file mode 100644 index 000000000..8ffa20efe --- /dev/null +++ b/surfsense_obsidian/README.md @@ -0,0 +1,90 @@ +# Obsidian Sample Plugin + +This is a sample plugin for Obsidian (https://obsidian.md). + +This project uses TypeScript to provide type checking and documentation. +The repo depends on the latest plugin API (obsidian.d.ts) in TypeScript Definition format, which contains TSDoc comments describing what it does. + +This sample plugin demonstrates some of the basic functionality the plugin API can do. +- Adds a ribbon icon, which shows a Notice when clicked. +- Adds a command "Open modal (simple)" which opens a Modal. +- Adds a plugin setting tab to the settings page. +- Registers a global click event and output 'click' to the console. +- Registers a global interval which logs 'setInterval' to the console. + +## First time developing plugins? + +Quick starting guide for new plugin devs: + +- Check if [someone already developed a plugin for what you want](https://obsidian.md/plugins)! There might be an existing plugin similar enough that you can partner up with. +- Make a copy of this repo as a template with the "Use this template" button (login to GitHub if you don't see it). +- Clone your repo to a local development folder. For convenience, you can place this folder in your `.obsidian/plugins/your-plugin-name` folder. +- Install NodeJS, then run `npm i` in the command line under your repo folder. +- Run `npm run dev` to compile your plugin from `main.ts` to `main.js`. +- Make changes to `main.ts` (or create new `.ts` files). Those changes should be automatically compiled into `main.js`. +- Reload Obsidian to load the new version of your plugin. +- Enable plugin in settings window. +- For updates to the Obsidian API run `npm update` in the command line under your repo folder. + +## Releasing new releases + +- Update your `manifest.json` with your new version number, such as `1.0.1`, and the minimum Obsidian version required for your latest release. +- Update your `versions.json` file with `"new-plugin-version": "minimum-obsidian-version"` so older versions of Obsidian can download an older version of your plugin that's compatible. +- Create new GitHub release using your new version number as the "Tag version". Use the exact version number, don't include a prefix `v`. See here for an example: https://github.com/obsidianmd/obsidian-sample-plugin/releases +- Upload the files `manifest.json`, `main.js`, `styles.css` as binary attachments. Note: The manifest.json file must be in two places, first the root path of your repository and also in the release. +- Publish the release. + +> You can simplify the version bump process by running `npm version patch`, `npm version minor` or `npm version major` after updating `minAppVersion` manually in `manifest.json`. +> The command will bump version in `manifest.json` and `package.json`, and add the entry for the new version to `versions.json` + +## Adding your plugin to the community plugin list + +- Check the [plugin guidelines](https://docs.obsidian.md/Plugins/Releasing/Plugin+guidelines). +- Publish an initial version. +- Make sure you have a `README.md` file in the root of your repo. +- Make a pull request at https://github.com/obsidianmd/obsidian-releases to add your plugin. + +## How to use + +- Clone this repo. +- Make sure your NodeJS is at least v16 (`node --version`). +- `npm i` or `yarn` to install dependencies. +- `npm run dev` to start compilation in watch mode. + +## Manually installing the plugin + +- Copy over `main.js`, `styles.css`, `manifest.json` to your vault `VaultFolder/.obsidian/plugins/your-plugin-id/`. + +## Improve code quality with eslint +- [ESLint](https://eslint.org/) is a tool that analyzes your code to quickly find problems. You can run ESLint against your plugin to find common bugs and ways to improve your code. +- This project already has eslint preconfigured, you can invoke a check by running`npm run lint` +- Together with a custom eslint [plugin](https://github.com/obsidianmd/eslint-plugin) for Obsidan specific code guidelines. +- A GitHub action is preconfigured to automatically lint every commit on all branches. + +## Funding URL + +You can include funding URLs where people who use your plugin can financially support it. + +The simple way is to set the `fundingUrl` field to your link in your `manifest.json` file: + +```json +{ + "fundingUrl": "https://buymeacoffee.com" +} +``` + +If you have multiple URLs, you can also do: + +```json +{ + "fundingUrl": { + "Buy Me a Coffee": "https://buymeacoffee.com", + "GitHub Sponsor": "https://github.com/sponsors", + "Patreon": "https://www.patreon.com/" + } +} +``` + +## API Documentation + +See https://docs.obsidian.md diff --git a/surfsense_obsidian/esbuild.config.mjs b/surfsense_obsidian/esbuild.config.mjs new file mode 100644 index 000000000..1c74a149e --- /dev/null +++ b/surfsense_obsidian/esbuild.config.mjs @@ -0,0 +1,49 @@ +import esbuild from "esbuild"; +import process from "process"; +import { builtinModules } from 'node:module'; + +const banner = +`/* +THIS IS A GENERATED/BUNDLED FILE BY ESBUILD +if you want to view the source, please visit the github repository of this plugin +*/ +`; + +const prod = (process.argv[2] === "production"); + +const context = await esbuild.context({ + banner: { + js: banner, + }, + entryPoints: ["src/main.ts"], + bundle: true, + external: [ + "obsidian", + "electron", + "@codemirror/autocomplete", + "@codemirror/collab", + "@codemirror/commands", + "@codemirror/language", + "@codemirror/lint", + "@codemirror/search", + "@codemirror/state", + "@codemirror/view", + "@lezer/common", + "@lezer/highlight", + "@lezer/lr", + ...builtinModules], + format: "cjs", + target: "es2018", + logLevel: "info", + sourcemap: prod ? false : "inline", + treeShaking: true, + outfile: "main.js", + minify: prod, +}); + +if (prod) { + await context.rebuild(); + process.exit(0); +} else { + await context.watch(); +} diff --git a/surfsense_obsidian/eslint.config.mts b/surfsense_obsidian/eslint.config.mts new file mode 100644 index 000000000..3062c4a07 --- /dev/null +++ b/surfsense_obsidian/eslint.config.mts @@ -0,0 +1,34 @@ +import tseslint from 'typescript-eslint'; +import obsidianmd from "eslint-plugin-obsidianmd"; +import globals from "globals"; +import { globalIgnores } from "eslint/config"; + +export default tseslint.config( + { + languageOptions: { + globals: { + ...globals.browser, + }, + parserOptions: { + projectService: { + allowDefaultProject: [ + 'eslint.config.js', + 'manifest.json' + ] + }, + tsconfigRootDir: import.meta.dirname, + extraFileExtensions: ['.json'] + }, + }, + }, + ...obsidianmd.configs.recommended, + globalIgnores([ + "node_modules", + "dist", + "esbuild.config.mjs", + "eslint.config.js", + "version-bump.mjs", + "versions.json", + "main.js", + ]), +); diff --git a/surfsense_obsidian/manifest.json b/surfsense_obsidian/manifest.json new file mode 100644 index 000000000..dfa940ed8 --- /dev/null +++ b/surfsense_obsidian/manifest.json @@ -0,0 +1,11 @@ +{ + "id": "sample-plugin", + "name": "Sample Plugin", + "version": "1.0.0", + "minAppVersion": "0.15.0", + "description": "Demonstrates some of the capabilities of the Obsidian API.", + "author": "Obsidian", + "authorUrl": "https://obsidian.md", + "fundingUrl": "https://obsidian.md/pricing", + "isDesktopOnly": false +} diff --git a/surfsense_obsidian/package-lock.json b/surfsense_obsidian/package-lock.json new file mode 100644 index 000000000..d0dac397c --- /dev/null +++ b/surfsense_obsidian/package-lock.json @@ -0,0 +1,5160 @@ +{ + "name": "obsidian-sample-plugin", + "version": "1.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "obsidian-sample-plugin", + "version": "1.0.0", + "license": "0-BSD", + "dependencies": { + "obsidian": "latest" + }, + "devDependencies": { + "@eslint/js": "9.30.1", + "@types/node": "^16.11.6", + "esbuild": "0.25.5", + "eslint-plugin-obsidianmd": "0.1.9", + "globals": "14.0.0", + "jiti": "2.6.1", + "tslib": "2.4.0", + "typescript": "^5.8.3", + "typescript-eslint": "8.35.1" + } + }, + "node_modules/@codemirror/state": { + "version": "6.5.0", + "resolved": "https://registry.npmjs.org/@codemirror/state/-/state-6.5.0.tgz", + "integrity": "sha512-MwBHVK60IiIHDcoMet78lxt6iw5gJOGSbNbOIVBHWVXIH4/Nq1+GQgLLGgI1KlnN86WDXsPudVaqYHKBIx7Eyw==", + "license": "MIT", + "peer": true, + "dependencies": { + "@marijn/find-cluster-break": "^1.0.0" + } + }, + "node_modules/@codemirror/view": { + "version": "6.38.6", + "resolved": "https://registry.npmjs.org/@codemirror/view/-/view-6.38.6.tgz", + "integrity": "sha512-qiS0z1bKs5WOvHIAC0Cybmv4AJSkAXgX5aD6Mqd2epSLlVJsQl8NG23jCVouIgkh4All/mrbdsf2UOLFnJw0tw==", + "license": "MIT", + "peer": true, + "dependencies": { + "@codemirror/state": "^6.5.0", + "crelt": "^1.0.6", + "style-mod": "^4.1.0", + "w3c-keyname": "^2.2.4" + } + }, + "node_modules/@esbuild/aix-ppc64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.5.tgz", + "integrity": "sha512-9o3TMmpmftaCMepOdA5k/yDw8SfInyzWWTjYTFCX3kPSDJMROQTb8jg+h9Cnwnmm1vOzvxN7gIfB5V2ewpjtGA==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "aix" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-arm": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.25.5.tgz", + "integrity": "sha512-AdJKSPeEHgi7/ZhuIPtcQKr5RQdo6OO2IL87JkianiMYMPbCtot9fxPbrMiBADOWWm3T2si9stAiVsGbTQFkbA==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-arm64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.25.5.tgz", + "integrity": "sha512-VGzGhj4lJO+TVGV1v8ntCZWJktV7SGCs3Pn1GRWI1SBFtRALoomm8k5E9Pmwg3HOAal2VDc2F9+PM/rEY6oIDg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-x64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.25.5.tgz", + "integrity": "sha512-D2GyJT1kjvO//drbRT3Hib9XPwQeWd9vZoBJn+bu/lVsOZ13cqNdDeqIF/xQ5/VmWvMduP6AmXvylO/PIc2isw==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/darwin-arm64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.25.5.tgz", + "integrity": "sha512-GtaBgammVvdF7aPIgH2jxMDdivezgFu6iKpmT+48+F8Hhg5J/sfnDieg0aeG/jfSvkYQU2/pceFPDKlqZzwnfQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/darwin-x64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.25.5.tgz", + "integrity": "sha512-1iT4FVL0dJ76/q1wd7XDsXrSW+oLoquptvh4CLR4kITDtqi2e/xwXwdCVH8hVHU43wgJdsq7Gxuzcs6Iq/7bxQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/freebsd-arm64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.25.5.tgz", + "integrity": "sha512-nk4tGP3JThz4La38Uy/gzyXtpkPW8zSAmoUhK9xKKXdBCzKODMc2adkB2+8om9BDYugz+uGV7sLmpTYzvmz6Sw==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/freebsd-x64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.25.5.tgz", + "integrity": "sha512-PrikaNjiXdR2laW6OIjlbeuCPrPaAl0IwPIaRv+SMV8CiM8i2LqVUHFC1+8eORgWyY7yhQY+2U2fA55mBzReaw==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-arm": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.25.5.tgz", + "integrity": "sha512-cPzojwW2okgh7ZlRpcBEtsX7WBuqbLrNXqLU89GxWbNt6uIg78ET82qifUy3W6OVww6ZWobWub5oqZOVtwolfw==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-arm64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.25.5.tgz", + "integrity": "sha512-Z9kfb1v6ZlGbWj8EJk9T6czVEjjq2ntSYLY2cw6pAZl4oKtfgQuS4HOq41M/BcoLPzrUbNd+R4BXFyH//nHxVg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-ia32": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.25.5.tgz", + "integrity": "sha512-sQ7l00M8bSv36GLV95BVAdhJ2QsIbCuCjh/uYrWiMQSUuV+LpXwIqhgJDcvMTj+VsQmqAHL2yYaasENvJ7CDKA==", + "cpu": [ + "ia32" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-loong64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.25.5.tgz", + "integrity": "sha512-0ur7ae16hDUC4OL5iEnDb0tZHDxYmuQyhKhsPBV8f99f6Z9KQM02g33f93rNH5A30agMS46u2HP6qTdEt6Q1kg==", + "cpu": [ + "loong64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-mips64el": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.25.5.tgz", + "integrity": "sha512-kB/66P1OsHO5zLz0i6X0RxlQ+3cu0mkxS3TKFvkb5lin6uwZ/ttOkP3Z8lfR9mJOBk14ZwZ9182SIIWFGNmqmg==", + "cpu": [ + "mips64el" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-ppc64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.25.5.tgz", + "integrity": "sha512-UZCmJ7r9X2fe2D6jBmkLBMQetXPXIsZjQJCjgwpVDz+YMcS6oFR27alkgGv3Oqkv07bxdvw7fyB71/olceJhkQ==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-riscv64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.25.5.tgz", + "integrity": "sha512-kTxwu4mLyeOlsVIFPfQo+fQJAV9mh24xL+y+Bm6ej067sYANjyEw1dNHmvoqxJUCMnkBdKpvOn0Ahql6+4VyeA==", + "cpu": [ + "riscv64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-s390x": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.25.5.tgz", + "integrity": "sha512-K2dSKTKfmdh78uJ3NcWFiqyRrimfdinS5ErLSn3vluHNeHVnBAFWC8a4X5N+7FgVE1EjXS1QDZbpqZBjfrqMTQ==", + "cpu": [ + "s390x" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-x64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.25.5.tgz", + "integrity": "sha512-uhj8N2obKTE6pSZ+aMUbqq+1nXxNjZIIjCjGLfsWvVpy7gKCOL6rsY1MhRh9zLtUtAI7vpgLMK6DxjO8Qm9lJw==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/netbsd-arm64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-arm64/-/netbsd-arm64-0.25.5.tgz", + "integrity": "sha512-pwHtMP9viAy1oHPvgxtOv+OkduK5ugofNTVDilIzBLpoWAM16r7b/mxBvfpuQDpRQFMfuVr5aLcn4yveGvBZvw==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/netbsd-x64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.25.5.tgz", + "integrity": "sha512-WOb5fKrvVTRMfWFNCroYWWklbnXH0Q5rZppjq0vQIdlsQKuw6mdSihwSo4RV/YdQ5UCKKvBy7/0ZZYLBZKIbwQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openbsd-arm64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-arm64/-/openbsd-arm64-0.25.5.tgz", + "integrity": "sha512-7A208+uQKgTxHd0G0uqZO8UjK2R0DDb4fDmERtARjSHWxqMTye4Erz4zZafx7Di9Cv+lNHYuncAkiGFySoD+Mw==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openbsd-x64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.25.5.tgz", + "integrity": "sha512-G4hE405ErTWraiZ8UiSoesH8DaCsMm0Cay4fsFWOOUcz8b8rC6uCvnagr+gnioEjWn0wC+o1/TAHt+It+MpIMg==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/sunos-x64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.25.5.tgz", + "integrity": "sha512-l+azKShMy7FxzY0Rj4RCt5VD/q8mG/e+mDivgspo+yL8zW7qEwctQ6YqKX34DTEleFAvCIUviCFX1SDZRSyMQA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "sunos" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-arm64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.25.5.tgz", + "integrity": "sha512-O2S7SNZzdcFG7eFKgvwUEZ2VG9D/sn/eIiz8XRZ1Q/DO5a3s76Xv0mdBzVM5j5R639lXQmPmSo0iRpHqUUrsxw==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-ia32": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.25.5.tgz", + "integrity": "sha512-onOJ02pqs9h1iMJ1PQphR+VZv8qBMQ77Klcsqv9CNW2w6yLqoURLcgERAIurY6QE63bbLuqgP9ATqajFLK5AMQ==", + "cpu": [ + "ia32" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-x64": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.25.5.tgz", + "integrity": "sha512-TXv6YnJ8ZMVdX+SXWVBo/0p8LTcrUYngpWjvm91TMjjBQii7Oz11Lw5lbDV5Y0TzuhSJHwiH4hEtC1I42mMS0g==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@eslint-community/eslint-utils": { + "version": "4.9.0", + "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.9.0.tgz", + "integrity": "sha512-ayVFHdtZ+hsq1t2Dy24wCmGXGe4q9Gu3smhLYALJrr473ZH27MsnSL+LKUlimp4BWJqMDMLmPpx/Q9R3OAlL4g==", + "dev": true, + "license": "MIT", + "dependencies": { + "eslint-visitor-keys": "^3.4.3" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + }, + "peerDependencies": { + "eslint": "^6.0.0 || ^7.0.0 || >=8.0.0" + } + }, + "node_modules/@eslint-community/eslint-utils/node_modules/eslint-visitor-keys": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.4.3.tgz", + "integrity": "sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/@eslint-community/regexpp": { + "version": "4.12.2", + "resolved": "https://registry.npmjs.org/@eslint-community/regexpp/-/regexpp-4.12.2.tgz", + "integrity": "sha512-EriSTlt5OC9/7SXkRSCAhfSxxoSUgBm33OH+IkwbdpgoqsSsUg7y3uh+IICI/Qg4BBWr3U2i39RpmycbxMq4ew==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^12.0.0 || ^14.0.0 || >=16.0.0" + } + }, + "node_modules/@eslint/config-array": { + "version": "0.21.1", + "resolved": "https://registry.npmjs.org/@eslint/config-array/-/config-array-0.21.1.tgz", + "integrity": "sha512-aw1gNayWpdI/jSYVgzN5pL0cfzU02GT3NBpeT/DXbx1/1x7ZKxFPd9bwrzygx/qiwIQiJ1sw/zD8qY/kRvlGHA==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@eslint/object-schema": "^2.1.7", + "debug": "^4.3.1", + "minimatch": "^3.1.2" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/config-helpers": { + "version": "0.4.2", + "resolved": "https://registry.npmjs.org/@eslint/config-helpers/-/config-helpers-0.4.2.tgz", + "integrity": "sha512-gBrxN88gOIf3R7ja5K9slwNayVcZgK6SOUORm2uBzTeIEfeVaIhOpCtTox3P6R7o2jLFwLFTLnC7kU/RGcYEgw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@eslint/core": "^0.17.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/core": { + "version": "0.17.0", + "resolved": "https://registry.npmjs.org/@eslint/core/-/core-0.17.0.tgz", + "integrity": "sha512-yL/sLrpmtDaFEiUj1osRP4TI2MDz1AddJL+jZ7KSqvBuliN4xqYY54IfdN8qD8Toa6g1iloph1fxQNkjOxrrpQ==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@types/json-schema": "^7.0.15" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/eslintrc": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-3.3.1.tgz", + "integrity": "sha512-gtF186CXhIl1p4pJNGZw8Yc6RlshoePRvE0X91oPGb3vZ8pM3qOS9W9NGPat9LziaBV7XrJWGylNQXkGcnM3IQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ajv": "^6.12.4", + "debug": "^4.3.2", + "espree": "^10.0.1", + "globals": "^14.0.0", + "ignore": "^5.2.0", + "import-fresh": "^3.2.1", + "js-yaml": "^4.1.0", + "minimatch": "^3.1.2", + "strip-json-comments": "^3.1.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/@eslint/js": { + "version": "9.30.1", + "resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.30.1.tgz", + "integrity": "sha512-zXhuECFlyep42KZUhWjfvsmXGX39W8K8LFb8AWXM9gSV9dQB+MrJGLKvW6Zw0Ggnbpw0VHTtrhFXYe3Gym18jg==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://eslint.org/donate" + } + }, + "node_modules/@eslint/json": { + "version": "0.14.0", + "resolved": "https://registry.npmjs.org/@eslint/json/-/json-0.14.0.tgz", + "integrity": "sha512-rvR/EZtvUG3p9uqrSmcDJPYSH7atmWr0RnFWN6m917MAPx82+zQgPUmDu0whPFG6XTyM0vB/hR6c1Q63OaYtCQ==", + "dev": true, + "license": "Apache-2.0", + "peer": true, + "dependencies": { + "@eslint/core": "^0.17.0", + "@eslint/plugin-kit": "^0.4.1", + "@humanwhocodes/momoa": "^3.3.10", + "natural-compare": "^1.4.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/object-schema": { + "version": "2.1.7", + "resolved": "https://registry.npmjs.org/@eslint/object-schema/-/object-schema-2.1.7.tgz", + "integrity": "sha512-VtAOaymWVfZcmZbp6E2mympDIHvyjXs/12LqWYjVw6qjrfF+VK+fyG33kChz3nnK+SU5/NeHOqrTEHS8sXO3OA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/plugin-kit": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/@eslint/plugin-kit/-/plugin-kit-0.4.1.tgz", + "integrity": "sha512-43/qtrDUokr7LJqoF2c3+RInu/t4zfrpYdoSDfYyhg52rwLV6TnOvdG4fXm7IkSB3wErkcmJS9iEhjVtOSEjjA==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@eslint/core": "^0.17.0", + "levn": "^0.4.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@humanfs/core": { + "version": "0.19.1", + "resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz", + "integrity": "sha512-5DyQ4+1JEUzejeK1JGICcideyfUbGixgS9jNgex5nqkW+cY7WZhxBigmieN5Qnw9ZosSNVC9KQKyb+GUaGyKUA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18.18.0" + } + }, + "node_modules/@humanfs/node": { + "version": "0.16.7", + "resolved": "https://registry.npmjs.org/@humanfs/node/-/node-0.16.7.tgz", + "integrity": "sha512-/zUx+yOsIrG4Y43Eh2peDeKCxlRt/gET6aHfaKpuq267qXdYDFViVHfMaLyygZOnl0kGWxFIgsBy8QFuTLUXEQ==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@humanfs/core": "^0.19.1", + "@humanwhocodes/retry": "^0.4.0" + }, + "engines": { + "node": ">=18.18.0" + } + }, + "node_modules/@humanwhocodes/module-importer": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@humanwhocodes/module-importer/-/module-importer-1.0.1.tgz", + "integrity": "sha512-bxveV4V8v5Yb4ncFTT3rPSgZBOpCkjfK0y4oVVVJwIuDVBRMDXrPyXRL988i5ap9m9bnyEEjWfm5WkBmtffLfA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=12.22" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/nzakas" + } + }, + "node_modules/@humanwhocodes/momoa": { + "version": "3.3.10", + "resolved": "https://registry.npmjs.org/@humanwhocodes/momoa/-/momoa-3.3.10.tgz", + "integrity": "sha512-KWiFQpSAqEIyrTXko3hFNLeQvSK8zXlJQzhhxsyVn58WFRYXST99b3Nqnu+ttOtjds2Pl2grUHGpe2NzhPynuQ==", + "dev": true, + "license": "Apache-2.0", + "peer": true, + "engines": { + "node": ">=18" + } + }, + "node_modules/@humanwhocodes/retry": { + "version": "0.4.3", + "resolved": "https://registry.npmjs.org/@humanwhocodes/retry/-/retry-0.4.3.tgz", + "integrity": "sha512-bV0Tgo9K4hfPCek+aMAn81RppFKv2ySDQeMoSZuvTASywNTnVJCArCZE2FWqpvIatKu7VMRLWlR1EazvVhDyhQ==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18.18" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/nzakas" + } + }, + "node_modules/@marijn/find-cluster-break": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/@marijn/find-cluster-break/-/find-cluster-break-1.0.2.tgz", + "integrity": "sha512-l0h88YhZFyKdXIFNfSWpyjStDjGHwZ/U7iobcK1cQQD8sejsONdQtTVU+1wVN1PBw40PiiHB1vA5S7VTfQiP9g==", + "license": "MIT", + "peer": true + }, + "node_modules/@microsoft/eslint-plugin-sdl": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@microsoft/eslint-plugin-sdl/-/eslint-plugin-sdl-1.1.0.tgz", + "integrity": "sha512-dxdNHOemLnBhfY3eByrujX9KyLigcNtW8sU+axzWv5nLGcsSBeKW2YYyTpfPo1hV8YPOmIGnfA4fZHyKVtWqBQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "eslint-plugin-n": "17.10.3", + "eslint-plugin-react": "7.37.3", + "eslint-plugin-security": "1.4.0" + }, + "engines": { + "node": ">=18.0.0" + }, + "peerDependencies": { + "eslint": "^9" + } + }, + "node_modules/@microsoft/eslint-plugin-sdl/node_modules/eslint-plugin-security": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/eslint-plugin-security/-/eslint-plugin-security-1.4.0.tgz", + "integrity": "sha512-xlS7P2PLMXeqfhyf3NpqbvbnW04kN8M9NtmhpR3XGyOvt/vNKS7XPXT5EDbwKW9vCjWH4PpfQvgD/+JgN0VJKA==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "safe-regex": "^1.1.0" + } + }, + "node_modules/@microsoft/eslint-plugin-sdl/node_modules/safe-regex": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/safe-regex/-/safe-regex-1.1.0.tgz", + "integrity": "sha512-aJXcif4xnaNUzvUuC5gcb46oTS7zvg4jpMTnuqtrEPlR3vFr4pxtdTwaF1Qs3Enjn9HK+ZlwQui+a7z0SywIzg==", + "dev": true, + "license": "MIT", + "dependencies": { + "ret": "~0.1.10" + } + }, + "node_modules/@nodelib/fs.scandir": { + "version": "2.1.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", + "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@nodelib/fs.stat": "2.0.5", + "run-parallel": "^1.1.9" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.stat": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", + "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.walk": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", + "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@nodelib/fs.scandir": "2.1.5", + "fastq": "^1.6.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@pkgr/core": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/@pkgr/core/-/core-0.1.2.tgz", + "integrity": "sha512-fdDH1LSGfZdTH2sxdpVMw31BanV28K/Gry0cVFxaNP77neJSkd82mM8ErPNYs9e+0O7SdHBLTDzDgwUuy18RnQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^12.20.0 || ^14.18.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/unts" + } + }, + "node_modules/@rtsao/scc": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@rtsao/scc/-/scc-1.1.0.tgz", + "integrity": "sha512-zt6OdqaDoOnJ1ZYsCYGt9YmWzDXl4vQdKTyJev62gFhRGKdx7mcT54V9KIjg+d2wi9EXsPvAPKe7i7WjfVWB8g==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/codemirror": { + "version": "5.60.8", + "resolved": "https://registry.npmjs.org/@types/codemirror/-/codemirror-5.60.8.tgz", + "integrity": "sha512-VjFgDF/eB+Aklcy15TtOTLQeMjTo07k7KAjql8OK5Dirr7a6sJY4T1uVBDuTVG9VEmn1uUsohOpYnVfgC6/jyw==", + "license": "MIT", + "dependencies": { + "@types/tern": "*" + } + }, + "node_modules/@types/eslint": { + "version": "8.56.2", + "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.56.2.tgz", + "integrity": "sha512-uQDwm1wFHmbBbCZCqAlq6Do9LYwByNZHWzXppSnay9SuwJ+VRbjkbLABer54kcPnMSlG6Fdiy2yaFXm/z9Z5gw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "*", + "@types/json-schema": "*" + } + }, + "node_modules/@types/estree": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", + "integrity": "sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==", + "license": "MIT" + }, + "node_modules/@types/json-schema": { + "version": "7.0.15", + "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", + "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/json5": { + "version": "0.0.29", + "resolved": "https://registry.npmjs.org/@types/json5/-/json5-0.0.29.tgz", + "integrity": "sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/node": { + "version": "16.18.126", + "resolved": "https://registry.npmjs.org/@types/node/-/node-16.18.126.tgz", + "integrity": "sha512-OTcgaiwfGFBKacvfwuHzzn1KLxH/er8mluiy8/uM3sGXHaRe73RrSIj01jow9t4kJEW633Ov+cOexXeiApTyAw==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/tern": { + "version": "0.23.9", + "resolved": "https://registry.npmjs.org/@types/tern/-/tern-0.23.9.tgz", + "integrity": "sha512-ypzHFE/wBzh+BlH6rrBgS5I/Z7RD21pGhZ2rltb/+ZrVM1awdZwjx7hE5XfuYgHWk9uvV5HLZN3SloevCAp3Bw==", + "license": "MIT", + "dependencies": { + "@types/estree": "*" + } + }, + "node_modules/@typescript-eslint/eslint-plugin": { + "version": "8.35.1", + "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.35.1.tgz", + "integrity": "sha512-9XNTlo7P7RJxbVeICaIIIEipqxLKguyh+3UbXuT2XQuFp6d8VOeDEGuz5IiX0dgZo8CiI6aOFLg4e8cF71SFVg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/regexpp": "^4.10.0", + "@typescript-eslint/scope-manager": "8.35.1", + "@typescript-eslint/type-utils": "8.35.1", + "@typescript-eslint/utils": "8.35.1", + "@typescript-eslint/visitor-keys": "8.35.1", + "graphemer": "^1.4.0", + "ignore": "^7.0.0", + "natural-compare": "^1.4.0", + "ts-api-utils": "^2.1.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "@typescript-eslint/parser": "^8.35.1", + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <5.9.0" + } + }, + "node_modules/@typescript-eslint/eslint-plugin/node_modules/ignore": { + "version": "7.0.5", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-7.0.5.tgz", + "integrity": "sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 4" + } + }, + "node_modules/@typescript-eslint/parser": { + "version": "8.35.1", + "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-8.35.1.tgz", + "integrity": "sha512-3MyiDfrfLeK06bi/g9DqJxP5pV74LNv4rFTyvGDmT3x2p1yp1lOd+qYZfiRPIOf/oON+WRZR5wxxuF85qOar+w==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/scope-manager": "8.35.1", + "@typescript-eslint/types": "8.35.1", + "@typescript-eslint/typescript-estree": "8.35.1", + "@typescript-eslint/visitor-keys": "8.35.1", + "debug": "^4.3.4" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <5.9.0" + } + }, + "node_modules/@typescript-eslint/project-service": { + "version": "8.35.1", + "resolved": "https://registry.npmjs.org/@typescript-eslint/project-service/-/project-service-8.35.1.tgz", + "integrity": "sha512-VYxn/5LOpVxADAuP3NrnxxHYfzVtQzLKeldIhDhzC8UHaiQvYlXvKuVho1qLduFbJjjy5U5bkGwa3rUGUb1Q6Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/tsconfig-utils": "^8.35.1", + "@typescript-eslint/types": "^8.35.1", + "debug": "^4.3.4" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <5.9.0" + } + }, + "node_modules/@typescript-eslint/scope-manager": { + "version": "8.35.1", + "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-8.35.1.tgz", + "integrity": "sha512-s/Bpd4i7ht2934nG+UoSPlYXd08KYz3bmjLEb7Ye1UVob0d1ENiT3lY8bsCmik4RqfSbPw9xJJHbugpPpP5JUg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.35.1", + "@typescript-eslint/visitor-keys": "8.35.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/tsconfig-utils": { + "version": "8.35.1", + "resolved": "https://registry.npmjs.org/@typescript-eslint/tsconfig-utils/-/tsconfig-utils-8.35.1.tgz", + "integrity": "sha512-K5/U9VmT9dTHoNowWZpz+/TObS3xqC5h0xAIjXPw+MNcKV9qg6eSatEnmeAwkjHijhACH0/N7bkhKvbt1+DXWQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <5.9.0" + } + }, + "node_modules/@typescript-eslint/type-utils": { + "version": "8.35.1", + "resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-8.35.1.tgz", + "integrity": "sha512-HOrUBlfVRz5W2LIKpXzZoy6VTZzMu2n8q9C2V/cFngIC5U1nStJgv0tMV4sZPzdf4wQm9/ToWUFPMN9Vq9VJQQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/typescript-estree": "8.35.1", + "@typescript-eslint/utils": "8.35.1", + "debug": "^4.3.4", + "ts-api-utils": "^2.1.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <5.9.0" + } + }, + "node_modules/@typescript-eslint/types": { + "version": "8.35.1", + "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-8.35.1.tgz", + "integrity": "sha512-q/O04vVnKHfrrhNAscndAn1tuQhIkwqnaW+eu5waD5IPts2eX1dgJxgqcPx5BX109/qAz7IG6VrEPTOYKCNfRQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/typescript-estree": { + "version": "8.35.1", + "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-8.35.1.tgz", + "integrity": "sha512-Vvpuvj4tBxIka7cPs6Y1uvM7gJgdF5Uu9F+mBJBPY4MhvjrjWGK4H0lVgLJd/8PWZ23FTqsaJaLEkBCFUk8Y9g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/project-service": "8.35.1", + "@typescript-eslint/tsconfig-utils": "8.35.1", + "@typescript-eslint/types": "8.35.1", + "@typescript-eslint/visitor-keys": "8.35.1", + "debug": "^4.3.4", + "fast-glob": "^3.3.2", + "is-glob": "^4.0.3", + "minimatch": "^9.0.4", + "semver": "^7.6.0", + "ts-api-utils": "^2.1.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <5.9.0" + } + }, + "node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/@typescript-eslint/typescript-estree/node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/@typescript-eslint/typescript-estree/node_modules/semver": { + "version": "7.7.3", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.3.tgz", + "integrity": "sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/@typescript-eslint/utils": { + "version": "8.35.1", + "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-8.35.1.tgz", + "integrity": "sha512-lhnwatFmOFcazAsUm3ZnZFpXSxiwoa1Lj50HphnDe1Et01NF4+hrdXONSUHIcbVu2eFb1bAf+5yjXkGVkXBKAQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.7.0", + "@typescript-eslint/scope-manager": "8.35.1", + "@typescript-eslint/types": "8.35.1", + "@typescript-eslint/typescript-estree": "8.35.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <5.9.0" + } + }, + "node_modules/@typescript-eslint/visitor-keys": { + "version": "8.35.1", + "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-8.35.1.tgz", + "integrity": "sha512-VRwixir4zBWCSTP/ljEo091lbpypz57PoeAQ9imjG+vbeof9LplljsL1mos4ccG6H9IjfrVGM359RozUnuFhpw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.35.1", + "eslint-visitor-keys": "^4.2.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/acorn": { + "version": "8.15.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", + "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", + "dev": true, + "license": "MIT", + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-jsx": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", + "integrity": "sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" + } + }, + "node_modules/ajv": { + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "license": "MIT", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/argparse": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", + "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", + "dev": true, + "license": "Python-2.0" + }, + "node_modules/array-buffer-byte-length": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/array-buffer-byte-length/-/array-buffer-byte-length-1.0.2.tgz", + "integrity": "sha512-LHE+8BuR7RYGDKvnrmcuSq3tDcKv9OFEXQt/HpbZhY7V6h0zlUXutnAD82GiFx9rdieCMjkvtcsPqBwgUl1Iiw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "is-array-buffer": "^3.0.5" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/array-includes": { + "version": "3.1.9", + "resolved": "https://registry.npmjs.org/array-includes/-/array-includes-3.1.9.tgz", + "integrity": "sha512-FmeCCAenzH0KH381SPT5FZmiA/TmpndpcaShhfgEN9eCVjnFBqq3l1xrI42y8+PPLI6hypzou4GXw00WHmPBLQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.4", + "define-properties": "^1.2.1", + "es-abstract": "^1.24.0", + "es-object-atoms": "^1.1.1", + "get-intrinsic": "^1.3.0", + "is-string": "^1.1.1", + "math-intrinsics": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/array.prototype.findlast": { + "version": "1.2.5", + "resolved": "https://registry.npmjs.org/array.prototype.findlast/-/array.prototype.findlast-1.2.5.tgz", + "integrity": "sha512-CVvd6FHg1Z3POpBLxO6E6zr+rSKEQ9L6rZHAaY7lLfhKsWYUBBOuMs0e9o24oopj6H+geRCX0YJ+TJLBK2eHyQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.7", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.2", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.0.0", + "es-shim-unscopables": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/array.prototype.findlastindex": { + "version": "1.2.6", + "resolved": "https://registry.npmjs.org/array.prototype.findlastindex/-/array.prototype.findlastindex-1.2.6.tgz", + "integrity": "sha512-F/TKATkzseUExPlfvmwQKGITM3DGTK+vkAsCZoDc5daVygbJBnjEUCbgkAvVFsgfXfX4YIqZ/27G3k3tdXrTxQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.4", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.9", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "es-shim-unscopables": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/array.prototype.flat": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/array.prototype.flat/-/array.prototype.flat-1.3.3.tgz", + "integrity": "sha512-rwG/ja1neyLqCuGZ5YYrznA62D4mZXg0i1cIskIUKSiqF3Cje9/wXAls9B9s1Wa2fomMsIv8czB8jZcPmxCXFg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.5", + "es-shim-unscopables": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/array.prototype.flatmap": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/array.prototype.flatmap/-/array.prototype.flatmap-1.3.3.tgz", + "integrity": "sha512-Y7Wt51eKJSyi80hFrJCePGGNo5ktJCslFuboqJsbf57CCPcm5zztluPlc4/aD8sWsKvlwatezpV4U1efk8kpjg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.5", + "es-shim-unscopables": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/array.prototype.tosorted": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/array.prototype.tosorted/-/array.prototype.tosorted-1.1.4.tgz", + "integrity": "sha512-p6Fx8B7b7ZhL/gmUsAy0D15WhvDccw3mnGNbZpi3pmeJdxtWsj2jEaI4Y6oo3XiHfzuSgPwKc04MYt6KgvC/wA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.7", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.3", + "es-errors": "^1.3.0", + "es-shim-unscopables": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/arraybuffer.prototype.slice": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/arraybuffer.prototype.slice/-/arraybuffer.prototype.slice-1.0.4.tgz", + "integrity": "sha512-BNoCY6SXXPQ7gF2opIP4GBE+Xw7U+pHMYKuzjgCN3GwiaIR09UUeKfheyIry77QtrCBlC0KK0q5/TER/tYh3PQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "array-buffer-byte-length": "^1.0.1", + "call-bind": "^1.0.8", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.5", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.6", + "is-array-buffer": "^3.0.4" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/async-function": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/async-function/-/async-function-1.0.0.tgz", + "integrity": "sha512-hsU18Ae8CDTR6Kgu9DYf0EbCr/a5iGL0rytQDobUcdpYOKokk8LEjVphnXkDkgpi0wYVsqrXuP0bZxJaTqdgoA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/available-typed-arrays": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/available-typed-arrays/-/available-typed-arrays-1.0.7.tgz", + "integrity": "sha512-wvUjBtSGN7+7SjNpq/9M2Tg350UZD3q62IFZLbRAR1bSMlCo1ZaeW+BJ+D090e4hIIZLBcTDWe4Mh4jvUDajzQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "possible-typed-array-names": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "dev": true, + "license": "MIT" + }, + "node_modules/brace-expansion": { + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/braces": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", + "dev": true, + "license": "MIT", + "dependencies": { + "fill-range": "^7.1.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/call-bind": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.8.tgz", + "integrity": "sha512-oKlSFMcMwpUg2ednkhQ454wfWiU/ul3CkJe/PEHcTKuiX6RpbehUiFMXu13HalGZxfUwCQzZG747YXBn1im9ww==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.0", + "es-define-property": "^1.0.0", + "get-intrinsic": "^1.2.4", + "set-function-length": "^1.2.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/call-bind-apply-helpers": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", + "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/call-bound": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/call-bound/-/call-bound-1.0.4.tgz", + "integrity": "sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "get-intrinsic": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/callsites": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz", + "integrity": "sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/chalk": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true, + "license": "MIT" + }, + "node_modules/concat-map": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", + "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", + "dev": true, + "license": "MIT" + }, + "node_modules/crelt": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz", + "integrity": "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==", + "license": "MIT", + "peer": true + }, + "node_modules/cross-spawn": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "dev": true, + "license": "MIT", + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/data-view-buffer": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/data-view-buffer/-/data-view-buffer-1.0.2.tgz", + "integrity": "sha512-EmKO5V3OLXh1rtK2wgXRansaK1/mtVdTUEiEI0W8RkvgT05kfxaH29PliLnpLP73yYO6142Q72QNa8Wx/A5CqQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "es-errors": "^1.3.0", + "is-data-view": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/data-view-byte-length": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/data-view-byte-length/-/data-view-byte-length-1.0.2.tgz", + "integrity": "sha512-tuhGbE6CfTM9+5ANGf+oQb72Ky/0+s3xKUpHvShfiz2RxMFgFPjsXuRLBVMtvMs15awe45SRb83D6wH4ew6wlQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "es-errors": "^1.3.0", + "is-data-view": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/inspect-js" + } + }, + "node_modules/data-view-byte-offset": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/data-view-byte-offset/-/data-view-byte-offset-1.0.1.tgz", + "integrity": "sha512-BS8PfmtDGnrgYdOonGZQdLZslWIeCGFP9tpan0hi1Co2Zr2NKADsvGYA8XxuG/4UWgJ6Cjtv+YJnB6MM69QGlQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "is-data-view": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/debug": { + "version": "4.4.3", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.3.tgz", + "integrity": "sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==", + "dev": true, + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/deep-is": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.4.tgz", + "integrity": "sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/define-data-property": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", + "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "gopd": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/define-properties": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.2.1.tgz", + "integrity": "sha512-8QmQKqEASLd5nx0U1B1okLElbUuuttJ/AnYmRXbbbGDWh6uS208EjD4Xqq/I9wK7u0v6O08XhTWnt5XtEbR6Dg==", + "dev": true, + "license": "MIT", + "dependencies": { + "define-data-property": "^1.0.1", + "has-property-descriptors": "^1.0.0", + "object-keys": "^1.1.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/doctrine": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/doctrine/-/doctrine-2.1.0.tgz", + "integrity": "sha512-35mSku4ZXK0vfCuHEDAwt55dg2jNajHZ1odvF+8SSr82EsZY4QmXfuWso8oEd8zRhVObSN18aM0CjSdoBX7zIw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "esutils": "^2.0.2" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/dunder-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", + "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/empathic": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/empathic/-/empathic-2.0.0.tgz", + "integrity": "sha512-i6UzDscO/XfAcNYD75CfICkmfLedpyPDdozrLMmQc5ORaQcdMoc21OnlEylMIqI7U8eniKrPMxxtj8k0vhmJhA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14" + } + }, + "node_modules/enhanced-resolve": { + "version": "5.18.3", + "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.18.3.tgz", + "integrity": "sha512-d4lC8xfavMeBjzGr2vECC3fsGXziXZQyJxD868h2M/mBI3PwAuODxAkLkq5HYuvrPYcUtiLzsTo8U3PgX3Ocww==", + "dev": true, + "license": "MIT", + "dependencies": { + "graceful-fs": "^4.2.4", + "tapable": "^2.2.0" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/es-abstract": { + "version": "1.24.0", + "resolved": "https://registry.npmjs.org/es-abstract/-/es-abstract-1.24.0.tgz", + "integrity": "sha512-WSzPgsdLtTcQwm4CROfS5ju2Wa1QQcVeT37jFjYzdFz1r9ahadC8B8/a4qxJxM+09F18iumCdRmlr96ZYkQvEg==", + "dev": true, + "license": "MIT", + "dependencies": { + "array-buffer-byte-length": "^1.0.2", + "arraybuffer.prototype.slice": "^1.0.4", + "available-typed-arrays": "^1.0.7", + "call-bind": "^1.0.8", + "call-bound": "^1.0.4", + "data-view-buffer": "^1.0.2", + "data-view-byte-length": "^1.0.2", + "data-view-byte-offset": "^1.0.1", + "es-define-property": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "es-set-tostringtag": "^2.1.0", + "es-to-primitive": "^1.3.0", + "function.prototype.name": "^1.1.8", + "get-intrinsic": "^1.3.0", + "get-proto": "^1.0.1", + "get-symbol-description": "^1.1.0", + "globalthis": "^1.0.4", + "gopd": "^1.2.0", + "has-property-descriptors": "^1.0.2", + "has-proto": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "internal-slot": "^1.1.0", + "is-array-buffer": "^3.0.5", + "is-callable": "^1.2.7", + "is-data-view": "^1.0.2", + "is-negative-zero": "^2.0.3", + "is-regex": "^1.2.1", + "is-set": "^2.0.3", + "is-shared-array-buffer": "^1.0.4", + "is-string": "^1.1.1", + "is-typed-array": "^1.1.15", + "is-weakref": "^1.1.1", + "math-intrinsics": "^1.1.0", + "object-inspect": "^1.13.4", + "object-keys": "^1.1.1", + "object.assign": "^4.1.7", + "own-keys": "^1.0.1", + "regexp.prototype.flags": "^1.5.4", + "safe-array-concat": "^1.1.3", + "safe-push-apply": "^1.0.0", + "safe-regex-test": "^1.1.0", + "set-proto": "^1.0.0", + "stop-iteration-iterator": "^1.1.0", + "string.prototype.trim": "^1.2.10", + "string.prototype.trimend": "^1.0.9", + "string.prototype.trimstart": "^1.0.8", + "typed-array-buffer": "^1.0.3", + "typed-array-byte-length": "^1.0.3", + "typed-array-byte-offset": "^1.0.4", + "typed-array-length": "^1.0.7", + "unbox-primitive": "^1.1.0", + "which-typed-array": "^1.1.19" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/es-define-property": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz", + "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-iterator-helpers": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/es-iterator-helpers/-/es-iterator-helpers-1.2.1.tgz", + "integrity": "sha512-uDn+FE1yrDzyC0pCo961B2IHbdM8y/ACZsKD4dG6WqrjV53BADjwa7D+1aom2rsNVfLyDgU/eigvlJGJ08OQ4w==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.3", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.6", + "es-errors": "^1.3.0", + "es-set-tostringtag": "^2.0.3", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.6", + "globalthis": "^1.0.4", + "gopd": "^1.2.0", + "has-property-descriptors": "^1.0.2", + "has-proto": "^1.2.0", + "has-symbols": "^1.1.0", + "internal-slot": "^1.1.0", + "iterator.prototype": "^1.1.4", + "safe-array-concat": "^1.1.3" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-object-atoms": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", + "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-set-tostringtag": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz", + "integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.6", + "has-tostringtag": "^1.0.2", + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-shim-unscopables": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/es-shim-unscopables/-/es-shim-unscopables-1.1.0.tgz", + "integrity": "sha512-d9T8ucsEhh8Bi1woXCf+TIKDIROLG5WCkxg8geBCbvk22kzwC5G2OnXVMO6FUsvQlgUUXQ2itephWDLqDzbeCw==", + "dev": true, + "license": "MIT", + "dependencies": { + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-to-primitive": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-to-primitive/-/es-to-primitive-1.3.0.tgz", + "integrity": "sha512-w+5mJ3GuFL+NjVtJlvydShqE1eN3h3PbI7/5LAsYJP/2qtuMXjfL2LpHSRqo4b4eSF5K/DH1JXKUAHSB2UW50g==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-callable": "^1.2.7", + "is-date-object": "^1.0.5", + "is-symbol": "^1.0.4" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/esbuild": { + "version": "0.25.5", + "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.25.5.tgz", + "integrity": "sha512-P8OtKZRv/5J5hhz0cUAdu/cLuPIKXpQl1R9pZtvmHWQvrAUVd0UNIPT4IB4W3rNOqVO0rlqHmCIbSwxh/c9yUQ==", + "dev": true, + "hasInstallScript": true, + "license": "MIT", + "bin": { + "esbuild": "bin/esbuild" + }, + "engines": { + "node": ">=18" + }, + "optionalDependencies": { + "@esbuild/aix-ppc64": "0.25.5", + "@esbuild/android-arm": "0.25.5", + "@esbuild/android-arm64": "0.25.5", + "@esbuild/android-x64": "0.25.5", + "@esbuild/darwin-arm64": "0.25.5", + "@esbuild/darwin-x64": "0.25.5", + "@esbuild/freebsd-arm64": "0.25.5", + "@esbuild/freebsd-x64": "0.25.5", + "@esbuild/linux-arm": "0.25.5", + "@esbuild/linux-arm64": "0.25.5", + "@esbuild/linux-ia32": "0.25.5", + "@esbuild/linux-loong64": "0.25.5", + "@esbuild/linux-mips64el": "0.25.5", + "@esbuild/linux-ppc64": "0.25.5", + "@esbuild/linux-riscv64": "0.25.5", + "@esbuild/linux-s390x": "0.25.5", + "@esbuild/linux-x64": "0.25.5", + "@esbuild/netbsd-arm64": "0.25.5", + "@esbuild/netbsd-x64": "0.25.5", + "@esbuild/openbsd-arm64": "0.25.5", + "@esbuild/openbsd-x64": "0.25.5", + "@esbuild/sunos-x64": "0.25.5", + "@esbuild/win32-arm64": "0.25.5", + "@esbuild/win32-ia32": "0.25.5", + "@esbuild/win32-x64": "0.25.5" + } + }, + "node_modules/escape-string-regexp": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", + "integrity": "sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/eslint": { + "version": "9.39.1", + "resolved": "https://registry.npmjs.org/eslint/-/eslint-9.39.1.tgz", + "integrity": "sha512-BhHmn2yNOFA9H9JmmIVKJmd288g9hrVRDkdoIgRCRuSySRUHH7r/DI6aAXW9T1WwUuY3DFgrcaqB+deURBLR5g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.8.0", + "@eslint-community/regexpp": "^4.12.1", + "@eslint/config-array": "^0.21.1", + "@eslint/config-helpers": "^0.4.2", + "@eslint/core": "^0.17.0", + "@eslint/eslintrc": "^3.3.1", + "@eslint/js": "9.39.1", + "@eslint/plugin-kit": "^0.4.1", + "@humanfs/node": "^0.16.6", + "@humanwhocodes/module-importer": "^1.0.1", + "@humanwhocodes/retry": "^0.4.2", + "@types/estree": "^1.0.6", + "ajv": "^6.12.4", + "chalk": "^4.0.0", + "cross-spawn": "^7.0.6", + "debug": "^4.3.2", + "escape-string-regexp": "^4.0.0", + "eslint-scope": "^8.4.0", + "eslint-visitor-keys": "^4.2.1", + "espree": "^10.4.0", + "esquery": "^1.5.0", + "esutils": "^2.0.2", + "fast-deep-equal": "^3.1.3", + "file-entry-cache": "^8.0.0", + "find-up": "^5.0.0", + "glob-parent": "^6.0.2", + "ignore": "^5.2.0", + "imurmurhash": "^0.1.4", + "is-glob": "^4.0.0", + "json-stable-stringify-without-jsonify": "^1.0.1", + "lodash.merge": "^4.6.2", + "minimatch": "^3.1.2", + "natural-compare": "^1.4.0", + "optionator": "^0.9.3" + }, + "bin": { + "eslint": "bin/eslint.js" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://eslint.org/donate" + }, + "peerDependencies": { + "jiti": "*" + }, + "peerDependenciesMeta": { + "jiti": { + "optional": true + } + } + }, + "node_modules/eslint-compat-utils": { + "version": "0.5.1", + "resolved": "https://registry.npmjs.org/eslint-compat-utils/-/eslint-compat-utils-0.5.1.tgz", + "integrity": "sha512-3z3vFexKIEnjHE3zCMRo6fn/e44U7T1khUjg+Hp0ZQMCigh28rALD0nPFBcGZuiLC5rLZa2ubQHDRln09JfU2Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "semver": "^7.5.4" + }, + "engines": { + "node": ">=12" + }, + "peerDependencies": { + "eslint": ">=6.0.0" + } + }, + "node_modules/eslint-compat-utils/node_modules/semver": { + "version": "7.7.3", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.3.tgz", + "integrity": "sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/eslint-import-resolver-node": { + "version": "0.3.9", + "resolved": "https://registry.npmjs.org/eslint-import-resolver-node/-/eslint-import-resolver-node-0.3.9.tgz", + "integrity": "sha512-WFj2isz22JahUv+B788TlO3N6zL3nNJGU8CcZbPZvVEkBPaJdCV4vy5wyghty5ROFbCRnm132v8BScu5/1BQ8g==", + "dev": true, + "license": "MIT", + "dependencies": { + "debug": "^3.2.7", + "is-core-module": "^2.13.0", + "resolve": "^1.22.4" + } + }, + "node_modules/eslint-import-resolver-node/node_modules/debug": { + "version": "3.2.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-3.2.7.tgz", + "integrity": "sha512-CFjzYYAi4ThfiQvizrFQevTTXHtnCqWfe7x1AhgEscTz6ZbLbfoLRLPugTQyBth6f8ZERVUSyWHFD/7Wu4t1XQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ms": "^2.1.1" + } + }, + "node_modules/eslint-module-utils": { + "version": "2.12.1", + "resolved": "https://registry.npmjs.org/eslint-module-utils/-/eslint-module-utils-2.12.1.tgz", + "integrity": "sha512-L8jSWTze7K2mTg0vos/RuLRS5soomksDPoJLXIslC7c8Wmut3bx7CPpJijDcBZtxQ5lrbUdM+s0OlNbz0DCDNw==", + "dev": true, + "license": "MIT", + "dependencies": { + "debug": "^3.2.7" + }, + "engines": { + "node": ">=4" + }, + "peerDependenciesMeta": { + "eslint": { + "optional": true + } + } + }, + "node_modules/eslint-module-utils/node_modules/debug": { + "version": "3.2.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-3.2.7.tgz", + "integrity": "sha512-CFjzYYAi4ThfiQvizrFQevTTXHtnCqWfe7x1AhgEscTz6ZbLbfoLRLPugTQyBth6f8ZERVUSyWHFD/7Wu4t1XQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ms": "^2.1.1" + } + }, + "node_modules/eslint-plugin-depend": { + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/eslint-plugin-depend/-/eslint-plugin-depend-1.3.1.tgz", + "integrity": "sha512-1uo2rFAr9vzNrCYdp7IBZRB54LiyVxfaIso0R6/QV3t6Dax6DTbW/EV2Hktf0f4UtmGHK8UyzJWI382pwW04jw==", + "dev": true, + "license": "MIT", + "dependencies": { + "empathic": "^2.0.0", + "module-replacements": "^2.8.0", + "semver": "^7.6.3" + } + }, + "node_modules/eslint-plugin-depend/node_modules/semver": { + "version": "7.7.3", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.3.tgz", + "integrity": "sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/eslint-plugin-es-x": { + "version": "7.8.0", + "resolved": "https://registry.npmjs.org/eslint-plugin-es-x/-/eslint-plugin-es-x-7.8.0.tgz", + "integrity": "sha512-7Ds8+wAAoV3T+LAKeu39Y5BzXCrGKrcISfgKEqTS4BDN8SFEDQd0S43jiQ8vIa3wUKD07qitZdfzlenSi8/0qQ==", + "dev": true, + "funding": [ + "https://github.com/sponsors/ota-meshi", + "https://opencollective.com/eslint" + ], + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.1.2", + "@eslint-community/regexpp": "^4.11.0", + "eslint-compat-utils": "^0.5.1" + }, + "engines": { + "node": "^14.18.0 || >=16.0.0" + }, + "peerDependencies": { + "eslint": ">=8" + } + }, + "node_modules/eslint-plugin-import": { + "version": "2.32.0", + "resolved": "https://registry.npmjs.org/eslint-plugin-import/-/eslint-plugin-import-2.32.0.tgz", + "integrity": "sha512-whOE1HFo/qJDyX4SnXzP4N6zOWn79WhnCUY/iDR0mPfQZO8wcYE4JClzI2oZrhBnnMUCBCHZhO6VQyoBU95mZA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@rtsao/scc": "^1.1.0", + "array-includes": "^3.1.9", + "array.prototype.findlastindex": "^1.2.6", + "array.prototype.flat": "^1.3.3", + "array.prototype.flatmap": "^1.3.3", + "debug": "^3.2.7", + "doctrine": "^2.1.0", + "eslint-import-resolver-node": "^0.3.9", + "eslint-module-utils": "^2.12.1", + "hasown": "^2.0.2", + "is-core-module": "^2.16.1", + "is-glob": "^4.0.3", + "minimatch": "^3.1.2", + "object.fromentries": "^2.0.8", + "object.groupby": "^1.0.3", + "object.values": "^1.2.1", + "semver": "^6.3.1", + "string.prototype.trimend": "^1.0.9", + "tsconfig-paths": "^3.15.0" + }, + "engines": { + "node": ">=4" + }, + "peerDependencies": { + "eslint": "^2 || ^3 || ^4 || ^5 || ^6 || ^7.2.0 || ^8 || ^9" + } + }, + "node_modules/eslint-plugin-import/node_modules/debug": { + "version": "3.2.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-3.2.7.tgz", + "integrity": "sha512-CFjzYYAi4ThfiQvizrFQevTTXHtnCqWfe7x1AhgEscTz6ZbLbfoLRLPugTQyBth6f8ZERVUSyWHFD/7Wu4t1XQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ms": "^2.1.1" + } + }, + "node_modules/eslint-plugin-json-schema-validator": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/eslint-plugin-json-schema-validator/-/eslint-plugin-json-schema-validator-5.1.0.tgz", + "integrity": "sha512-ZmVyxRIjm58oqe2kTuy90PpmZPrrKvOjRPXKzq8WCgRgAkidCgm5X8domL2KSfadZ3QFAmifMgGTcVNhZ5ez2g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.3.0", + "ajv": "^8.0.0", + "debug": "^4.3.1", + "eslint-compat-utils": "^0.5.0", + "json-schema-migrate": "^2.0.0", + "jsonc-eslint-parser": "^2.0.0", + "minimatch": "^8.0.0", + "synckit": "^0.9.0", + "toml-eslint-parser": "^0.9.0", + "tunnel-agent": "^0.6.0", + "yaml-eslint-parser": "^1.0.0" + }, + "engines": { + "node": "^14.18.0 || >=16.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/ota-meshi" + }, + "peerDependencies": { + "eslint": ">=6.0.0" + } + }, + "node_modules/eslint-plugin-json-schema-validator/node_modules/ajv": { + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/eslint-plugin-json-schema-validator/node_modules/brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/eslint-plugin-json-schema-validator/node_modules/json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "dev": true, + "license": "MIT" + }, + "node_modules/eslint-plugin-json-schema-validator/node_modules/minimatch": { + "version": "8.0.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-8.0.4.tgz", + "integrity": "sha512-W0Wvr9HyFXZRGIDgCicunpQ299OKXs9RgZfaukz4qAW/pJhcpUfupc9c+OObPOFueNy8VSrZgEmDtk6Kh4WzDA==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/eslint-plugin-n": { + "version": "17.10.3", + "resolved": "https://registry.npmjs.org/eslint-plugin-n/-/eslint-plugin-n-17.10.3.tgz", + "integrity": "sha512-ySZBfKe49nQZWR1yFaA0v/GsH6Fgp8ah6XV0WDz6CN8WO0ek4McMzb7A2xnf4DCYV43frjCygvb9f/wx7UUxRw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.4.0", + "enhanced-resolve": "^5.17.0", + "eslint-plugin-es-x": "^7.5.0", + "get-tsconfig": "^4.7.0", + "globals": "^15.8.0", + "ignore": "^5.2.4", + "minimatch": "^9.0.5", + "semver": "^7.5.3" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + }, + "peerDependencies": { + "eslint": ">=8.23.0" + } + }, + "node_modules/eslint-plugin-n/node_modules/brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/eslint-plugin-n/node_modules/globals": { + "version": "15.15.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-15.15.0.tgz", + "integrity": "sha512-7ACyT3wmyp3I61S4fG682L0VA2RGD9otkqGJIwNUMF1SWUombIIk+af1unuDYgMm082aHYwD+mzJvv9Iu8dsgg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/eslint-plugin-n/node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/eslint-plugin-n/node_modules/semver": { + "version": "7.7.3", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.3.tgz", + "integrity": "sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/eslint-plugin-obsidianmd": { + "version": "0.1.9", + "resolved": "https://registry.npmjs.org/eslint-plugin-obsidianmd/-/eslint-plugin-obsidianmd-0.1.9.tgz", + "integrity": "sha512-/gyo5vky3Y7re4BtT/8MQbHU5Wes4o6VRqas3YmXE7aTCnMsdV0kfzV1GDXJN9Hrsc9UQPoeKUMiapKL0aGE4g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@microsoft/eslint-plugin-sdl": "^1.1.0", + "@types/eslint": "8.56.2", + "@types/node": "20.12.12", + "eslint": ">=9.0.0 <10.0.0", + "eslint-plugin-depend": "1.3.1", + "eslint-plugin-import": "^2.31.0", + "eslint-plugin-json-schema-validator": "5.1.0", + "eslint-plugin-security": "2.1.1", + "globals": "14.0.0", + "obsidian": "1.8.7", + "typescript": "5.4.5" + }, + "bin": { + "eslint-plugin-obsidian": "dist/lib/index.js" + }, + "engines": { + "node": ">= 18" + }, + "peerDependencies": { + "@eslint/js": "^9.30.1", + "@eslint/json": "0.14.0", + "eslint": ">=9.0.0 <10.0.0", + "obsidian": "1.8.7", + "typescript-eslint": "^8.35.1" + } + }, + "node_modules/eslint-plugin-obsidianmd/node_modules/@types/node": { + "version": "20.12.12", + "resolved": "https://registry.npmjs.org/@types/node/-/node-20.12.12.tgz", + "integrity": "sha512-eWLDGF/FOSPtAvEqeRAQ4C8LSA7M1I7i0ky1I8U7kD1J5ITyW3AsRhQrKVoWf5pFKZ2kILsEGJhsI9r93PYnOw==", + "dev": true, + "license": "MIT", + "dependencies": { + "undici-types": "~5.26.4" + } + }, + "node_modules/eslint-plugin-obsidianmd/node_modules/obsidian": { + "version": "1.8.7", + "resolved": "https://registry.npmjs.org/obsidian/-/obsidian-1.8.7.tgz", + "integrity": "sha512-h4bWwNFAGRXlMlMAzdEiIM2ppTGlrh7uGOJS6w4gClrsjc+ei/3YAtU2VdFUlCiPuTHpY4aBpFJJW75S1Tl/JA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/codemirror": "5.60.8", + "moment": "2.29.4" + }, + "peerDependencies": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0" + } + }, + "node_modules/eslint-plugin-obsidianmd/node_modules/typescript": { + "version": "5.4.5", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.4.5.tgz", + "integrity": "sha512-vcI4UpRgg81oIRUFwR0WSIHKt11nJ7SAVlYNIu+QpqeyXP+gpQJy/Z4+F0aGxSE4MqwjyXvW/TzgkLAx2AGHwQ==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/eslint-plugin-react": { + "version": "7.37.3", + "resolved": "https://registry.npmjs.org/eslint-plugin-react/-/eslint-plugin-react-7.37.3.tgz", + "integrity": "sha512-DomWuTQPFYZwF/7c9W2fkKkStqZmBd3uugfqBYLdkZ3Hii23WzZuOLUskGxB8qkSKqftxEeGL1TB2kMhrce0jA==", + "dev": true, + "license": "MIT", + "dependencies": { + "array-includes": "^3.1.8", + "array.prototype.findlast": "^1.2.5", + "array.prototype.flatmap": "^1.3.3", + "array.prototype.tosorted": "^1.1.4", + "doctrine": "^2.1.0", + "es-iterator-helpers": "^1.2.1", + "estraverse": "^5.3.0", + "hasown": "^2.0.2", + "jsx-ast-utils": "^2.4.1 || ^3.0.0", + "minimatch": "^3.1.2", + "object.entries": "^1.1.8", + "object.fromentries": "^2.0.8", + "object.values": "^1.2.1", + "prop-types": "^15.8.1", + "resolve": "^2.0.0-next.5", + "semver": "^6.3.1", + "string.prototype.matchall": "^4.0.12", + "string.prototype.repeat": "^1.0.0" + }, + "engines": { + "node": ">=4" + }, + "peerDependencies": { + "eslint": "^3 || ^4 || ^5 || ^6 || ^7 || ^8 || ^9.7" + } + }, + "node_modules/eslint-plugin-react/node_modules/resolve": { + "version": "2.0.0-next.5", + "resolved": "https://registry.npmjs.org/resolve/-/resolve-2.0.0-next.5.tgz", + "integrity": "sha512-U7WjGVG9sH8tvjW5SmGbQuui75FiyjAX72HX15DwBBwF9dNiQZRQAg9nnPhYy+TUnE0+VcrttuvNI8oSxZcocA==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-core-module": "^2.13.0", + "path-parse": "^1.0.7", + "supports-preserve-symlinks-flag": "^1.0.0" + }, + "bin": { + "resolve": "bin/resolve" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/eslint-plugin-security": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/eslint-plugin-security/-/eslint-plugin-security-2.1.1.tgz", + "integrity": "sha512-7cspIGj7WTfR3EhaILzAPcfCo5R9FbeWvbgsPYWivSurTBKW88VQxtP3c4aWMG9Hz/GfJlJVdXEJ3c8LqS+u2w==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "safe-regex": "^2.1.1" + } + }, + "node_modules/eslint-scope": { + "version": "8.4.0", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-8.4.0.tgz", + "integrity": "sha512-sNXOfKCn74rt8RICKMvJS7XKV/Xk9kA7DyJr8mJik3S7Cwgy3qlkkmyS2uQB3jiJg6VNdZd/pDBJu0nvG2NlTg==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "esrecurse": "^4.3.0", + "estraverse": "^5.2.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint-visitor-keys": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-4.2.1.tgz", + "integrity": "sha512-Uhdk5sfqcee/9H/rCOJikYz67o0a2Tw2hGRPOG2Y1R2dg7brRe1uG0yaNQDHu+TO/uQPF/5eCapvYSmHUjt7JQ==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint/node_modules/@eslint/js": { + "version": "9.39.1", + "resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.39.1.tgz", + "integrity": "sha512-S26Stp4zCy88tH94QbBv3XCuzRQiZ9yXofEILmglYTh/Ug/a9/umqvgFtYBAo3Lp0nsI/5/qH1CCrbdK3AP1Tw==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://eslint.org/donate" + } + }, + "node_modules/espree": { + "version": "10.4.0", + "resolved": "https://registry.npmjs.org/espree/-/espree-10.4.0.tgz", + "integrity": "sha512-j6PAQ2uUr79PZhBjP5C5fhl8e39FmRnOjsD5lGnWrFU8i2G776tBK7+nP8KuQUTTyAZUwfQqXAgrVH5MbH9CYQ==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "acorn": "^8.15.0", + "acorn-jsx": "^5.3.2", + "eslint-visitor-keys": "^4.2.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/esquery": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/esquery/-/esquery-1.6.0.tgz", + "integrity": "sha512-ca9pw9fomFcKPvFLXhBKUK90ZvGibiGOvRJNbjljY7s7uq/5YO4BOzcYtJqExdx99rF6aAcnRxHmcUHcz6sQsg==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "estraverse": "^5.1.0" + }, + "engines": { + "node": ">=0.10" + } + }, + "node_modules/esrecurse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", + "integrity": "sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "estraverse": "^5.2.0" + }, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/estraverse": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", + "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=4.0" + } + }, + "node_modules/esutils": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", + "integrity": "sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/fast-deep-equal": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", + "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", + "dev": true, + "license": "MIT" + }, + "node_modules/fast-glob": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", + "integrity": "sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@nodelib/fs.stat": "^2.0.2", + "@nodelib/fs.walk": "^1.2.3", + "glob-parent": "^5.1.2", + "merge2": "^1.3.0", + "micromatch": "^4.0.8" + }, + "engines": { + "node": ">=8.6.0" + } + }, + "node_modules/fast-glob/node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dev": true, + "license": "ISC", + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/fast-json-stable-stringify": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", + "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", + "dev": true, + "license": "MIT" + }, + "node_modules/fast-levenshtein": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz", + "integrity": "sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==", + "dev": true, + "license": "MIT" + }, + "node_modules/fast-uri": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.1.0.tgz", + "integrity": "sha512-iPeeDKJSWf4IEOasVVrknXpaBV0IApz/gp7S2bb7Z4Lljbl2MGJRqInZiUrQwV16cpzw/D3S5j5Julj/gT52AA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/fastify" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fastify" + } + ], + "license": "BSD-3-Clause" + }, + "node_modules/fastq": { + "version": "1.19.1", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.19.1.tgz", + "integrity": "sha512-GwLTyxkCXjXbxqIhTsMI2Nui8huMPtnxg7krajPJAjnEG/iiOS7i+zCtWGZR9G0NBKbXKh6X9m9UIsYX/N6vvQ==", + "dev": true, + "license": "ISC", + "dependencies": { + "reusify": "^1.0.4" + } + }, + "node_modules/file-entry-cache": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-8.0.0.tgz", + "integrity": "sha512-XXTUwCvisa5oacNGRP9SfNtYBNAMi+RPwBFmblZEF7N7swHYQS6/Zfk7SRwx4D5j3CH211YNRco1DEMNVfZCnQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "flat-cache": "^4.0.0" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/fill-range": { + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", + "dev": true, + "license": "MIT", + "dependencies": { + "to-regex-range": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/find-up": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-5.0.0.tgz", + "integrity": "sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng==", + "dev": true, + "license": "MIT", + "dependencies": { + "locate-path": "^6.0.0", + "path-exists": "^4.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/flat-cache": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/flat-cache/-/flat-cache-4.0.1.tgz", + "integrity": "sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==", + "dev": true, + "license": "MIT", + "dependencies": { + "flatted": "^3.2.9", + "keyv": "^4.5.4" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/flatted": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", + "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", + "dev": true, + "license": "ISC" + }, + "node_modules/for-each": { + "version": "0.3.5", + "resolved": "https://registry.npmjs.org/for-each/-/for-each-0.3.5.tgz", + "integrity": "sha512-dKx12eRCVIzqCxFGplyFKJMPvLEWgmNtUrpTiJIR5u97zEhRG8ySrtboPHZXx7daLxQVrl643cTzbab2tkQjxg==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-callable": "^1.2.7" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/function-bind": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/function.prototype.name": { + "version": "1.1.8", + "resolved": "https://registry.npmjs.org/function.prototype.name/-/function.prototype.name-1.1.8.tgz", + "integrity": "sha512-e5iwyodOHhbMr/yNrc7fDYG4qlbIvI5gajyzPnb5TCwyhjApznQh1BMFou9b30SevY43gCJKXycoCBjMbsuW0Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.3", + "define-properties": "^1.2.1", + "functions-have-names": "^1.2.3", + "hasown": "^2.0.2", + "is-callable": "^1.2.7" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/functions-have-names": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/functions-have-names/-/functions-have-names-1.2.3.tgz", + "integrity": "sha512-xckBUXyTIqT97tq2x2AMb+g163b5JFysYk0x4qxNFwbfQkmNZoiRHb6sPzI9/QV33WeuvVYBUIiD4NzNIyqaRQ==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/generator-function": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/generator-function/-/generator-function-2.0.1.tgz", + "integrity": "sha512-SFdFmIJi+ybC0vjlHN0ZGVGHc3lgE0DxPAT0djjVg+kjOnSqclqmj0KQ7ykTOLP6YxoqOvuAODGdcHJn+43q3g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/get-intrinsic": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", + "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "function-bind": "^1.1.2", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", + "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", + "dev": true, + "license": "MIT", + "dependencies": { + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/get-symbol-description": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/get-symbol-description/-/get-symbol-description-1.1.0.tgz", + "integrity": "sha512-w9UMqWwJxHNOvoNzSJ2oPF5wvYcvP7jUvYzhp67yEhTi17ZDBBC1z9pTdGuzjD+EFIqLSYRweZjqfiPzQ06Ebg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.6" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-tsconfig": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.13.0.tgz", + "integrity": "sha512-1VKTZJCwBrvbd+Wn3AOgQP/2Av+TfTCOlE4AcRJE72W1ksZXbAx8PPBR9RzgTeSPzlPMHrbANMH3LbltH73wxQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "resolve-pkg-maps": "^1.0.0" + }, + "funding": { + "url": "https://github.com/privatenumber/get-tsconfig?sponsor=1" + } + }, + "node_modules/glob-parent": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", + "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", + "dev": true, + "license": "ISC", + "dependencies": { + "is-glob": "^4.0.3" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/globals": { + "version": "14.0.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-14.0.0.tgz", + "integrity": "sha512-oahGvuMGQlPw/ivIYBjVSrWAfWLBeku5tpPE2fOPLi+WHffIWbuh2tCjhyQhTBPMf5E9jDEH4FOmTYgYwbKwtQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/globalthis": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/globalthis/-/globalthis-1.0.4.tgz", + "integrity": "sha512-DpLKbNU4WylpxJykQujfCcwYWiV/Jhm50Goo0wrVILAv5jOr9d+H+UR3PhSCD2rCCEIg0uc+G+muBTwD54JhDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "define-properties": "^1.2.1", + "gopd": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/gopd": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz", + "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/graceful-fs": { + "version": "4.2.11", + "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", + "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", + "dev": true, + "license": "ISC" + }, + "node_modules/graphemer": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/graphemer/-/graphemer-1.4.0.tgz", + "integrity": "sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==", + "dev": true, + "license": "MIT" + }, + "node_modules/has-bigints": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-bigints/-/has-bigints-1.1.0.tgz", + "integrity": "sha512-R3pbpkcIqv2Pm3dUwgjclDRVmWpTJW2DcMzcIhEXEx1oh/CEMObMm3KLmRJOdvhM7o4uQBnwr8pzRK2sJWIqfg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/has-property-descriptors": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", + "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-define-property": "^1.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-proto": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.2.0.tgz", + "integrity": "sha512-KIL7eQPfHQRC8+XluaIw7BHUwwqL19bQn4hzNgdr+1wXoU0KKj6rufu47lhY7KbJR2C6T6+PfyN0Ea7wkSS+qQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "dunder-proto": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-symbols": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", + "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-tostringtag": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz", + "integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==", + "dev": true, + "license": "MIT", + "dependencies": { + "has-symbols": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/ignore": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz", + "integrity": "sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 4" + } + }, + "node_modules/import-fresh": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.1.tgz", + "integrity": "sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "parent-module": "^1.0.0", + "resolve-from": "^4.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/imurmurhash": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", + "integrity": "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.8.19" + } + }, + "node_modules/internal-slot": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/internal-slot/-/internal-slot-1.1.0.tgz", + "integrity": "sha512-4gd7VpWNQNB4UKKCFFVcp1AVv+FMOgs9NKzjHKusc8jTMhd5eL1NqQqOpE0KzMds804/yHlglp3uxgluOqAPLw==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "hasown": "^2.0.2", + "side-channel": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/is-array-buffer": { + "version": "3.0.5", + "resolved": "https://registry.npmjs.org/is-array-buffer/-/is-array-buffer-3.0.5.tgz", + "integrity": "sha512-DDfANUiiG2wC1qawP66qlTugJeL5HyzMpfr8lLK+jMQirGzNod0B12cFB/9q838Ru27sBwfw78/rdoU7RERz6A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.3", + "get-intrinsic": "^1.2.6" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-async-function": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-async-function/-/is-async-function-2.1.1.tgz", + "integrity": "sha512-9dgM/cZBnNvjzaMYHVoxxfPj2QXt22Ev7SuuPrs+xav0ukGB0S6d4ydZdEiM48kLx5kDV+QBPrpVnFyefL8kkQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "async-function": "^1.0.0", + "call-bound": "^1.0.3", + "get-proto": "^1.0.1", + "has-tostringtag": "^1.0.2", + "safe-regex-test": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-bigint": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/is-bigint/-/is-bigint-1.1.0.tgz", + "integrity": "sha512-n4ZT37wG78iz03xPRKJrHTdZbe3IicyucEtdRsV5yglwc3GyUfbAfpSeD0FJ41NbUNSt5wbhqfp1fS+BgnvDFQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "has-bigints": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-boolean-object": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/is-boolean-object/-/is-boolean-object-1.2.2.tgz", + "integrity": "sha512-wa56o2/ElJMYqjCjGkXri7it5FbebW5usLw/nPmCMs5DeZ7eziSYZhSmPRn0txqeW4LnAmQQU7FgqLpsEFKM4A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "has-tostringtag": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-callable": { + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/is-callable/-/is-callable-1.2.7.tgz", + "integrity": "sha512-1BC0BVFhS/p0qtw6enp8e+8OD0UrK0oFLztSjNzhcKA3WDuJxxAPXzPuPtKkjEY9UUoEWlX/8fgKeu2S8i9JTA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-core-module": { + "version": "2.16.1", + "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.16.1.tgz", + "integrity": "sha512-UfoeMA6fIJ8wTYFEUjelnaGI67v6+N7qXJEvQuIGa99l4xsCruSYOVSQ0uPANn4dAzm8lkYPaKLrrijLq7x23w==", + "dev": true, + "license": "MIT", + "dependencies": { + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-data-view": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/is-data-view/-/is-data-view-1.0.2.tgz", + "integrity": "sha512-RKtWF8pGmS87i2D6gqQu/l7EYRlVdfzemCJN/P3UOs//x1QE7mfhvzHIApBTRf7axvT6DMGwSwBXYCT0nfB9xw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "get-intrinsic": "^1.2.6", + "is-typed-array": "^1.1.13" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-date-object": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/is-date-object/-/is-date-object-1.1.0.tgz", + "integrity": "sha512-PwwhEakHVKTdRNVOw+/Gyh0+MzlCl4R6qKvkhuvLtPMggI1WAHt9sOwZxQLSGpUaDnrdyDsomoRgNnCfKNSXXg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "has-tostringtag": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-finalizationregistry": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/is-finalizationregistry/-/is-finalizationregistry-1.1.1.tgz", + "integrity": "sha512-1pC6N8qWJbWoPtEjgcL2xyhQOP491EQjeUo3qTKcmV8YSDDJrOepfG8pcC7h/QgnQHYSv0mJ3Z/ZWxmatVrysg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-generator-function": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/is-generator-function/-/is-generator-function-1.1.2.tgz", + "integrity": "sha512-upqt1SkGkODW9tsGNG5mtXTXtECizwtS2kA161M+gJPc1xdb/Ax629af6YrTwcOeQHbewrPNlE5Dx7kzvXTizA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.4", + "generator-function": "^2.0.0", + "get-proto": "^1.0.1", + "has-tostringtag": "^1.0.2", + "safe-regex-test": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-glob": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-extglob": "^2.1.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-map": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/is-map/-/is-map-2.0.3.tgz", + "integrity": "sha512-1Qed0/Hr2m+YqxnM09CjA2d/i6YZNfF6R2oRAOj36eUdS6qIV/huPJNSEpKbupewFs+ZsJlxsjjPbc0/afW6Lw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-negative-zero": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/is-negative-zero/-/is-negative-zero-2.0.3.tgz", + "integrity": "sha512-5KoIu2Ngpyek75jXodFvnafB6DJgr3u8uuK0LEZJjrU19DrMD3EVERaR8sjz8CCGgpZvxPl9SuE1GMVPFHx1mw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.12.0" + } + }, + "node_modules/is-number-object": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/is-number-object/-/is-number-object-1.1.1.tgz", + "integrity": "sha512-lZhclumE1G6VYD8VHe35wFaIif+CTy5SJIi5+3y4psDgWu4wPDoBhF8NxUOinEc7pHgiTsT6MaBb92rKhhD+Xw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "has-tostringtag": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-regex": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/is-regex/-/is-regex-1.2.1.tgz", + "integrity": "sha512-MjYsKHO5O7mCsmRGxWcLWheFqN9DJ/2TmngvjKXihe6efViPqc274+Fx/4fYj/r03+ESvBdTXK0V6tA3rgez1g==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "gopd": "^1.2.0", + "has-tostringtag": "^1.0.2", + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-set": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/is-set/-/is-set-2.0.3.tgz", + "integrity": "sha512-iPAjerrse27/ygGLxw+EBR9agv9Y6uLeYVJMu+QNCoouJ1/1ri0mGrcWpfCqFZuzzx3WjtwxG098X+n4OuRkPg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-shared-array-buffer": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/is-shared-array-buffer/-/is-shared-array-buffer-1.0.4.tgz", + "integrity": "sha512-ISWac8drv4ZGfwKl5slpHG9OwPNty4jOWPRIhBpxOoD+hqITiwuipOQ2bNthAzwA3B4fIjO4Nln74N0S9byq8A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-string": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/is-string/-/is-string-1.1.1.tgz", + "integrity": "sha512-BtEeSsoaQjlSPBemMQIrY1MY0uM6vnS1g5fmufYOtnxLGUZM2178PKbhsk7Ffv58IX+ZtcvoGwccYsh0PglkAA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "has-tostringtag": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-symbol": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/is-symbol/-/is-symbol-1.1.1.tgz", + "integrity": "sha512-9gGx6GTtCQM73BgmHQXfDmLtfjjTUDSyoxTCbp5WtoixAhfgsDirWIcVQ/IHpvI5Vgd5i/J5F7B9cN/WlVbC/w==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "has-symbols": "^1.1.0", + "safe-regex-test": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-typed-array": { + "version": "1.1.15", + "resolved": "https://registry.npmjs.org/is-typed-array/-/is-typed-array-1.1.15.tgz", + "integrity": "sha512-p3EcsicXjit7SaskXHs1hA91QxgTw46Fv6EFKKGS5DRFLD8yKnohjF3hxoju94b/OcMZoQukzpPpBE9uLVKzgQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "which-typed-array": "^1.1.16" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-weakmap": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/is-weakmap/-/is-weakmap-2.0.2.tgz", + "integrity": "sha512-K5pXYOm9wqY1RgjpL3YTkF39tni1XajUIkawTLUo9EZEVUFga5gSQJF8nNS7ZwJQ02y+1YCNYcMh+HIf1ZqE+w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-weakref": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/is-weakref/-/is-weakref-1.1.1.tgz", + "integrity": "sha512-6i9mGWSlqzNMEqpCp93KwRS1uUOodk2OJ6b+sq7ZPDSy2WuI5NFIxp/254TytR8ftefexkWn5xNiHUNpPOfSew==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-weakset": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/is-weakset/-/is-weakset-2.0.4.tgz", + "integrity": "sha512-mfcwb6IzQyOKTs84CQMrOwW4gQcaTOAWJ0zzJCl2WSPDrWk/OzDaImWFH3djXhb24g4eudZfLRozAvPGw4d9hQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "get-intrinsic": "^1.2.6" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/isarray": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-2.0.5.tgz", + "integrity": "sha512-xHjhDr3cNBK0BzdUJSPXZntQUx/mwMS5Rw4A7lPJ90XGAO6ISP/ePDNuo0vhqOZU+UD5JoodwCAAoZQd3FeAKw==", + "dev": true, + "license": "MIT" + }, + "node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "dev": true, + "license": "ISC" + }, + "node_modules/iterator.prototype": { + "version": "1.1.5", + "resolved": "https://registry.npmjs.org/iterator.prototype/-/iterator.prototype-1.1.5.tgz", + "integrity": "sha512-H0dkQoCa3b2VEeKQBOxFph+JAbcrQdE7KC0UkqwpLmv2EC4P41QXP+rqo9wYodACiG5/WM5s9oDApTU8utwj9g==", + "dev": true, + "license": "MIT", + "dependencies": { + "define-data-property": "^1.1.4", + "es-object-atoms": "^1.0.0", + "get-intrinsic": "^1.2.6", + "get-proto": "^1.0.0", + "has-symbols": "^1.1.0", + "set-function-name": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/jiti": { + "version": "2.6.1", + "resolved": "https://registry.npmjs.org/jiti/-/jiti-2.6.1.tgz", + "integrity": "sha512-ekilCSN1jwRvIbgeg/57YFh8qQDNbwDb9xT/qu2DAHbFFZUicIl4ygVaAvzveMhMVr3LnpSKTNnwt8PoOfmKhQ==", + "dev": true, + "license": "MIT", + "bin": { + "jiti": "lib/jiti-cli.mjs" + } + }, + "node_modules/js-tokens": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", + "integrity": "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/js-yaml": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", + "dev": true, + "license": "MIT", + "dependencies": { + "argparse": "^2.0.1" + }, + "bin": { + "js-yaml": "bin/js-yaml.js" + } + }, + "node_modules/json-buffer": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/json-buffer/-/json-buffer-3.0.1.tgz", + "integrity": "sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/json-schema-migrate": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/json-schema-migrate/-/json-schema-migrate-2.0.0.tgz", + "integrity": "sha512-r38SVTtojDRp4eD6WsCqiE0eNDt4v1WalBXb9cyZYw9ai5cGtBwzRNWjHzJl38w6TxFkXAIA7h+fyX3tnrAFhQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ajv": "^8.0.0" + } + }, + "node_modules/json-schema-migrate/node_modules/ajv": { + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/json-schema-migrate/node_modules/json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "dev": true, + "license": "MIT" + }, + "node_modules/json-schema-traverse": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", + "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "dev": true, + "license": "MIT" + }, + "node_modules/json-stable-stringify-without-jsonify": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/json-stable-stringify-without-jsonify/-/json-stable-stringify-without-jsonify-1.0.1.tgz", + "integrity": "sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw==", + "dev": true, + "license": "MIT" + }, + "node_modules/json5": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/json5/-/json5-1.0.2.tgz", + "integrity": "sha512-g1MWMLBiz8FKi1e4w0UyVL3w+iJceWAFBAaBnnGKOpNa5f8TLktkbre1+s6oICydWAm+HRUGTmI+//xv2hvXYA==", + "dev": true, + "license": "MIT", + "dependencies": { + "minimist": "^1.2.0" + }, + "bin": { + "json5": "lib/cli.js" + } + }, + "node_modules/jsonc-eslint-parser": { + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/jsonc-eslint-parser/-/jsonc-eslint-parser-2.4.1.tgz", + "integrity": "sha512-uuPNLJkKN8NXAlZlQ6kmUF9qO+T6Kyd7oV4+/7yy8Jz6+MZNyhPq8EdLpdfnPVzUC8qSf1b4j1azKaGnFsjmsw==", + "dev": true, + "license": "MIT", + "dependencies": { + "acorn": "^8.5.0", + "eslint-visitor-keys": "^3.0.0", + "espree": "^9.0.0", + "semver": "^7.3.5" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/ota-meshi" + } + }, + "node_modules/jsonc-eslint-parser/node_modules/eslint-visitor-keys": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.4.3.tgz", + "integrity": "sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/jsonc-eslint-parser/node_modules/espree": { + "version": "9.6.1", + "resolved": "https://registry.npmjs.org/espree/-/espree-9.6.1.tgz", + "integrity": "sha512-oruZaFkjorTpF32kDSI5/75ViwGeZginGGy2NoOSg3Q9bnwlnmDm4HLnkl0RE3n+njDXR037aY1+x58Z/zFdwQ==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "acorn": "^8.9.0", + "acorn-jsx": "^5.3.2", + "eslint-visitor-keys": "^3.4.1" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/jsonc-eslint-parser/node_modules/semver": { + "version": "7.7.3", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.3.tgz", + "integrity": "sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/jsx-ast-utils": { + "version": "3.3.5", + "resolved": "https://registry.npmjs.org/jsx-ast-utils/-/jsx-ast-utils-3.3.5.tgz", + "integrity": "sha512-ZZow9HBI5O6EPgSJLUb8n2NKgmVWTwCvHGwFuJlMjvLFqlGG6pjirPhtdsseaLZjSibD8eegzmYpUZwoIlj2cQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "array-includes": "^3.1.6", + "array.prototype.flat": "^1.3.1", + "object.assign": "^4.1.4", + "object.values": "^1.1.6" + }, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/keyv": { + "version": "4.5.4", + "resolved": "https://registry.npmjs.org/keyv/-/keyv-4.5.4.tgz", + "integrity": "sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==", + "dev": true, + "license": "MIT", + "dependencies": { + "json-buffer": "3.0.1" + } + }, + "node_modules/levn": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/levn/-/levn-0.4.1.tgz", + "integrity": "sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "prelude-ls": "^1.2.1", + "type-check": "~0.4.0" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/locate-path": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz", + "integrity": "sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==", + "dev": true, + "license": "MIT", + "dependencies": { + "p-locate": "^5.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/lodash.merge": { + "version": "4.6.2", + "resolved": "https://registry.npmjs.org/lodash.merge/-/lodash.merge-4.6.2.tgz", + "integrity": "sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/loose-envify": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz", + "integrity": "sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "js-tokens": "^3.0.0 || ^4.0.0" + }, + "bin": { + "loose-envify": "cli.js" + } + }, + "node_modules/math-intrinsics": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", + "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/merge2": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", + "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 8" + } + }, + "node_modules/micromatch": { + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", + "dev": true, + "license": "MIT", + "dependencies": { + "braces": "^3.0.3", + "picomatch": "^2.3.1" + }, + "engines": { + "node": ">=8.6" + } + }, + "node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/minimist": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz", + "integrity": "sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/module-replacements": { + "version": "2.10.1", + "resolved": "https://registry.npmjs.org/module-replacements/-/module-replacements-2.10.1.tgz", + "integrity": "sha512-qkKuLpMHDqRSM676OPL7HUpCiiP3NSxgf8NNR1ga2h/iJLNKTsOSjMEwrcT85DMSti2vmOqxknOVBGWj6H6etQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/moment": { + "version": "2.29.4", + "resolved": "https://registry.npmjs.org/moment/-/moment-2.29.4.tgz", + "integrity": "sha512-5LC9SOxjSc2HF6vO2CyuTDNivEdoz2IvyJJGj6X8DJ0eFyfszE0QiEd+iXmBvUP3WHxSjFH/vIsA0EN00cgr8w==", + "license": "MIT", + "engines": { + "node": "*" + } + }, + "node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "dev": true, + "license": "MIT" + }, + "node_modules/natural-compare": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz", + "integrity": "sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==", + "dev": true, + "license": "MIT" + }, + "node_modules/object-assign": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", + "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/object-inspect": { + "version": "1.13.4", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz", + "integrity": "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/object-keys": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/object-keys/-/object-keys-1.1.1.tgz", + "integrity": "sha512-NuAESUOUMrlIXOfHKzD6bpPu3tYt3xvjNdRIQ+FeT0lNb4K8WR70CaDxhuNguS2XG+GjkyMwOzsN5ZktImfhLA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/object.assign": { + "version": "4.1.7", + "resolved": "https://registry.npmjs.org/object.assign/-/object.assign-4.1.7.tgz", + "integrity": "sha512-nK28WOo+QIjBkDduTINE4JkF/UJJKyf2EJxvJKfblDpyg0Q+pkOHNTL0Qwy6NP6FhE/EnzV73BxxqcJaXY9anw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.3", + "define-properties": "^1.2.1", + "es-object-atoms": "^1.0.0", + "has-symbols": "^1.1.0", + "object-keys": "^1.1.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/object.entries": { + "version": "1.1.9", + "resolved": "https://registry.npmjs.org/object.entries/-/object.entries-1.1.9.tgz", + "integrity": "sha512-8u/hfXFRBD1O0hPUjioLhoWFHRmt6tKA4/vZPyckBr18l1KE9uHrFaFaUi8MDRTpi4uak2goyPTSNJLXX2k2Hw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.4", + "define-properties": "^1.2.1", + "es-object-atoms": "^1.1.1" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/object.fromentries": { + "version": "2.0.8", + "resolved": "https://registry.npmjs.org/object.fromentries/-/object.fromentries-2.0.8.tgz", + "integrity": "sha512-k6E21FzySsSK5a21KRADBd/NGneRegFO5pLHfdQLpRDETUNJueLXs3WCzyQ3tFRDYgbq3KHGXfTbi2bs8WQ6rQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.7", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.2", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/object.groupby": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/object.groupby/-/object.groupby-1.0.3.tgz", + "integrity": "sha512-+Lhy3TQTuzXI5hevh8sBGqbmurHbbIjAi0Z4S63nthVLmLxfbj4T54a4CfZrXIrt9iP4mVAPYMo/v99taj3wjQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.7", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/object.values": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/object.values/-/object.values-1.2.1.tgz", + "integrity": "sha512-gXah6aZrcUxjWg2zR2MwouP2eHlCBzdV4pygudehaKXSGW4v2AsRQUK+lwwXhii6KFZcunEnmSUoYp5CXibxtA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.3", + "define-properties": "^1.2.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/obsidian": { + "version": "1.10.3", + "resolved": "https://registry.npmjs.org/obsidian/-/obsidian-1.10.3.tgz", + "integrity": "sha512-VP+ZSxNMG7y6Z+sU9WqLvJAskCfkFrTz2kFHWmmzis+C+4+ELjk/sazwcTHrHXNZlgCeo8YOlM6SOrAFCynNew==", + "license": "MIT", + "dependencies": { + "@types/codemirror": "5.60.8", + "moment": "2.29.4" + }, + "peerDependencies": { + "@codemirror/state": "6.5.0", + "@codemirror/view": "6.38.6" + } + }, + "node_modules/optionator": { + "version": "0.9.4", + "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz", + "integrity": "sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==", + "dev": true, + "license": "MIT", + "dependencies": { + "deep-is": "^0.1.3", + "fast-levenshtein": "^2.0.6", + "levn": "^0.4.1", + "prelude-ls": "^1.2.1", + "type-check": "^0.4.0", + "word-wrap": "^1.2.5" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/own-keys": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/own-keys/-/own-keys-1.0.1.tgz", + "integrity": "sha512-qFOyK5PjiWZd+QQIh+1jhdb9LpxTF0qs7Pm8o5QHYZ0M3vKqSqzsZaEB6oWlxZ+q2sJBMI/Ktgd2N5ZwQoRHfg==", + "dev": true, + "license": "MIT", + "dependencies": { + "get-intrinsic": "^1.2.6", + "object-keys": "^1.1.1", + "safe-push-apply": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/p-limit": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", + "integrity": "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "yocto-queue": "^0.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-locate": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-5.0.0.tgz", + "integrity": "sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==", + "dev": true, + "license": "MIT", + "dependencies": { + "p-limit": "^3.0.2" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/parent-module": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", + "integrity": "sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==", + "dev": true, + "license": "MIT", + "dependencies": { + "callsites": "^3.0.0" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/path-exists": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", + "integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/path-parse": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz", + "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==", + "dev": true, + "license": "MIT" + }, + "node_modules/picomatch": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8.6" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/possible-typed-array-names": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/possible-typed-array-names/-/possible-typed-array-names-1.1.0.tgz", + "integrity": "sha512-/+5VFTchJDoVj3bhoqi6UeymcD00DAwb1nJwamzPvHEszJ4FpF6SNNbUbOS8yI56qHzdV8eK0qEfOSiodkTdxg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/prelude-ls": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz", + "integrity": "sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/prop-types": { + "version": "15.8.1", + "resolved": "https://registry.npmjs.org/prop-types/-/prop-types-15.8.1.tgz", + "integrity": "sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==", + "dev": true, + "license": "MIT", + "dependencies": { + "loose-envify": "^1.4.0", + "object-assign": "^4.1.1", + "react-is": "^16.13.1" + } + }, + "node_modules/punycode": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", + "integrity": "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/queue-microtask": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", + "integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, + "node_modules/react-is": { + "version": "16.13.1", + "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", + "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/reflect.getprototypeof": { + "version": "1.0.10", + "resolved": "https://registry.npmjs.org/reflect.getprototypeof/-/reflect.getprototypeof-1.0.10.tgz", + "integrity": "sha512-00o4I+DVrefhv+nX0ulyi3biSHCPDe+yLv5o/p6d/UVlirijB8E16FtfwSAi4g3tcqrQ4lRAqQSoFEZJehYEcw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.9", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.0.0", + "get-intrinsic": "^1.2.7", + "get-proto": "^1.0.1", + "which-builtin-type": "^1.2.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/regexp-tree": { + "version": "0.1.27", + "resolved": "https://registry.npmjs.org/regexp-tree/-/regexp-tree-0.1.27.tgz", + "integrity": "sha512-iETxpjK6YoRWJG5o6hXLwvjYAoW+FEZn9os0PD/b6AP6xQwsa/Y7lCVgIixBbUPMfhu+i2LtdeAqVTgGlQarfA==", + "dev": true, + "license": "MIT", + "bin": { + "regexp-tree": "bin/regexp-tree" + } + }, + "node_modules/regexp.prototype.flags": { + "version": "1.5.4", + "resolved": "https://registry.npmjs.org/regexp.prototype.flags/-/regexp.prototype.flags-1.5.4.tgz", + "integrity": "sha512-dYqgNSZbDwkaJ2ceRd9ojCGjBq+mOm9LmtXnAnEGyHhN/5R7iDW2TRw3h+o/jCFxus3P2LfWIIiwowAjANm7IA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "define-properties": "^1.2.1", + "es-errors": "^1.3.0", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "set-function-name": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/require-from-string": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", + "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/resolve": { + "version": "1.22.11", + "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.11.tgz", + "integrity": "sha512-RfqAvLnMl313r7c9oclB1HhUEAezcpLjz95wFH4LVuhk9JF/r22qmVP9AMmOU4vMX7Q8pN8jwNg/CSpdFnMjTQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-core-module": "^2.16.1", + "path-parse": "^1.0.7", + "supports-preserve-symlinks-flag": "^1.0.0" + }, + "bin": { + "resolve": "bin/resolve" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/resolve-from": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", + "integrity": "sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=4" + } + }, + "node_modules/resolve-pkg-maps": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/resolve-pkg-maps/-/resolve-pkg-maps-1.0.0.tgz", + "integrity": "sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://github.com/privatenumber/resolve-pkg-maps?sponsor=1" + } + }, + "node_modules/ret": { + "version": "0.1.15", + "resolved": "https://registry.npmjs.org/ret/-/ret-0.1.15.tgz", + "integrity": "sha512-TTlYpa+OL+vMMNG24xSlQGEJ3B/RzEfUlLct7b5G/ytav+wPrplCpVMFuwzXbkecJrb6IYo1iFb0S9v37754mg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.12" + } + }, + "node_modules/reusify": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.1.0.tgz", + "integrity": "sha512-g6QUff04oZpHs0eG5p83rFLhHeV00ug/Yf9nZM6fLeUrPguBTkTQOdpAWWspMh55TZfVQDPaN3NQJfbVRAxdIw==", + "dev": true, + "license": "MIT", + "engines": { + "iojs": ">=1.0.0", + "node": ">=0.10.0" + } + }, + "node_modules/run-parallel": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", + "integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT", + "dependencies": { + "queue-microtask": "^1.2.2" + } + }, + "node_modules/safe-array-concat": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/safe-array-concat/-/safe-array-concat-1.1.3.tgz", + "integrity": "sha512-AURm5f0jYEOydBj7VQlVvDrjeFgthDdEF5H1dP+6mNpoXOMo1quQqJ4wvJDyRZ9+pO3kGWoOdmV08cSv2aJV6Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.2", + "get-intrinsic": "^1.2.6", + "has-symbols": "^1.1.0", + "isarray": "^2.0.5" + }, + "engines": { + "node": ">=0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/safe-buffer": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", + "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, + "node_modules/safe-push-apply": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/safe-push-apply/-/safe-push-apply-1.0.0.tgz", + "integrity": "sha512-iKE9w/Z7xCzUMIZqdBsp6pEQvwuEebH4vdpjcDWnyzaI6yl6O9FHvVpmGelvEHNsoY6wGblkxR6Zty/h00WiSA==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "isarray": "^2.0.5" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/safe-regex": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/safe-regex/-/safe-regex-2.1.1.tgz", + "integrity": "sha512-rx+x8AMzKb5Q5lQ95Zoi6ZbJqwCLkqi3XuJXp5P3rT8OEc6sZCJG5AE5dU3lsgRr/F4Bs31jSlVN+j5KrsGu9A==", + "dev": true, + "license": "MIT", + "dependencies": { + "regexp-tree": "~0.1.1" + } + }, + "node_modules/safe-regex-test": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/safe-regex-test/-/safe-regex-test-1.1.0.tgz", + "integrity": "sha512-x/+Cz4YrimQxQccJf5mKEbIa1NzeCRNI5Ecl/ekmlYaampdNLPalVyIcCZNNH3MvmqBugV5TMYZXv0ljslUlaw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "is-regex": "^1.2.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/semver": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + } + }, + "node_modules/set-function-length": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", + "integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", + "dev": true, + "license": "MIT", + "dependencies": { + "define-data-property": "^1.1.4", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "gopd": "^1.0.1", + "has-property-descriptors": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/set-function-name": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/set-function-name/-/set-function-name-2.0.2.tgz", + "integrity": "sha512-7PGFlmtwsEADb0WYyvCMa1t+yke6daIG4Wirafur5kcf+MhUnPms1UeR0CKQdTZD81yESwMHbtn+TR+dMviakQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "define-data-property": "^1.1.4", + "es-errors": "^1.3.0", + "functions-have-names": "^1.2.3", + "has-property-descriptors": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/set-proto": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/set-proto/-/set-proto-1.0.0.tgz", + "integrity": "sha512-RJRdvCo6IAnPdsvP/7m6bsQqNnn1FCBX5ZNtFL98MmFF/4xAIJTIg1YbHW5DC2W5SKZanrC6i4HsJqlajw/dZw==", + "dev": true, + "license": "MIT", + "dependencies": { + "dunder-proto": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "dev": true, + "license": "MIT", + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/side-channel": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.1.0.tgz", + "integrity": "sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3", + "side-channel-list": "^1.0.0", + "side-channel-map": "^1.0.1", + "side-channel-weakmap": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-list": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/side-channel-list/-/side-channel-list-1.0.0.tgz", + "integrity": "sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-map": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/side-channel-map/-/side-channel-map-1.0.1.tgz", + "integrity": "sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-weakmap": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/side-channel-weakmap/-/side-channel-weakmap-1.0.2.tgz", + "integrity": "sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3", + "side-channel-map": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/stop-iteration-iterator": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/stop-iteration-iterator/-/stop-iteration-iterator-1.1.0.tgz", + "integrity": "sha512-eLoXW/DHyl62zxY4SCaIgnRhuMr6ri4juEYARS8E6sCEqzKpOiE521Ucofdx+KnDZl5xmvGYaaKCk5FEOxJCoQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "internal-slot": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/string.prototype.matchall": { + "version": "4.0.12", + "resolved": "https://registry.npmjs.org/string.prototype.matchall/-/string.prototype.matchall-4.0.12.tgz", + "integrity": "sha512-6CC9uyBL+/48dYizRf7H7VAYCMCNTBeM78x/VTUe9bFEaxBepPJDa1Ow99LqI/1yF7kuy7Q3cQsYMrcjGUcskA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.3", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.6", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.0.0", + "get-intrinsic": "^1.2.6", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "internal-slot": "^1.1.0", + "regexp.prototype.flags": "^1.5.3", + "set-function-name": "^2.0.2", + "side-channel": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/string.prototype.repeat": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/string.prototype.repeat/-/string.prototype.repeat-1.0.0.tgz", + "integrity": "sha512-0u/TldDbKD8bFCQ/4f5+mNRrXwZ8hg2w7ZR8wa16e8z9XpePWl3eGEcUD0OXpEH/VJH/2G3gjUtR3ZOiBe2S/w==", + "dev": true, + "license": "MIT", + "dependencies": { + "define-properties": "^1.1.3", + "es-abstract": "^1.17.5" + } + }, + "node_modules/string.prototype.trim": { + "version": "1.2.10", + "resolved": "https://registry.npmjs.org/string.prototype.trim/-/string.prototype.trim-1.2.10.tgz", + "integrity": "sha512-Rs66F0P/1kedk5lyYyH9uBzuiI/kNRmwJAR9quK6VOtIpZ2G+hMZd+HQbbv25MgCA6gEffoMZYxlTod4WcdrKA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.2", + "define-data-property": "^1.1.4", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.5", + "es-object-atoms": "^1.0.0", + "has-property-descriptors": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/string.prototype.trimend": { + "version": "1.0.9", + "resolved": "https://registry.npmjs.org/string.prototype.trimend/-/string.prototype.trimend-1.0.9.tgz", + "integrity": "sha512-G7Ok5C6E/j4SGfyLCloXTrngQIQU3PWtXGst3yM7Bea9FRURf1S42ZHlZZtsNque2FN2PoUhfZXYLNWwEr4dLQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.2", + "define-properties": "^1.2.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/string.prototype.trimstart": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/string.prototype.trimstart/-/string.prototype.trimstart-1.0.8.tgz", + "integrity": "sha512-UXSH262CSZY1tfu3G3Secr6uGLCFVPMhIqHjlgCUtCCcgihYc/xKs9djMTMUOb2j1mVSeU8EU6NWc/iQKU6Gfg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.7", + "define-properties": "^1.2.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/strip-bom": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/strip-bom/-/strip-bom-3.0.0.tgz", + "integrity": "sha512-vavAMRXOgBVNF6nyEEmL3DBK19iRpDcoIwW+swQ+CbGiu7lju6t+JklA1MHweoWtadgt4ISVUsXLyDq34ddcwA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=4" + } + }, + "node_modules/strip-json-comments": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-3.1.1.tgz", + "integrity": "sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/style-mod": { + "version": "4.1.3", + "resolved": "https://registry.npmjs.org/style-mod/-/style-mod-4.1.3.tgz", + "integrity": "sha512-i/n8VsZydrugj3Iuzll8+x/00GH2vnYsk1eomD8QiRrSAeW6ItbCQDtfXCeJHd0iwiNagqjQkvpvREEPtW3IoQ==", + "license": "MIT", + "peer": true + }, + "node_modules/supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "license": "MIT", + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/supports-preserve-symlinks-flag": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", + "integrity": "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/synckit": { + "version": "0.9.3", + "resolved": "https://registry.npmjs.org/synckit/-/synckit-0.9.3.tgz", + "integrity": "sha512-JJoOEKTfL1urb1mDoEblhD9NhEbWmq9jHEMEnxoC4ujUaZ4itA8vKgwkFAyNClgxplLi9tsUKX+EduK0p/l7sg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@pkgr/core": "^0.1.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": "^14.18.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/unts" + } + }, + "node_modules/synckit/node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "dev": true, + "license": "0BSD" + }, + "node_modules/tapable": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.3.0.tgz", + "integrity": "sha512-g9ljZiwki/LfxmQADO3dEY1CbpmXT5Hm2fJ+QaGKwSXUylMybePR7/67YW7jOrrvjEgL1Fmz5kzyAjWVWLlucg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + } + }, + "node_modules/to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-number": "^7.0.0" + }, + "engines": { + "node": ">=8.0" + } + }, + "node_modules/toml-eslint-parser": { + "version": "0.9.3", + "resolved": "https://registry.npmjs.org/toml-eslint-parser/-/toml-eslint-parser-0.9.3.tgz", + "integrity": "sha512-moYoCvkNUAPCxSW9jmHmRElhm4tVJpHL8ItC/+uYD0EpPSFXbck7yREz9tNdJVTSpHVod8+HoipcpbQ0oE6gsw==", + "dev": true, + "license": "MIT", + "dependencies": { + "eslint-visitor-keys": "^3.0.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/ota-meshi" + } + }, + "node_modules/toml-eslint-parser/node_modules/eslint-visitor-keys": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.4.3.tgz", + "integrity": "sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/ts-api-utils": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-2.1.0.tgz", + "integrity": "sha512-CUgTZL1irw8u29bzrOD/nH85jqyc74D6SshFgujOIA7osm2Rz7dYH77agkx7H4FBNxDq7Cjf+IjaX/8zwFW+ZQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18.12" + }, + "peerDependencies": { + "typescript": ">=4.8.4" + } + }, + "node_modules/tsconfig-paths": { + "version": "3.15.0", + "resolved": "https://registry.npmjs.org/tsconfig-paths/-/tsconfig-paths-3.15.0.tgz", + "integrity": "sha512-2Ac2RgzDe/cn48GvOe3M+o82pEFewD3UPbyoUHHdKasHwJKjds4fLXWf/Ux5kATBKN20oaFGu+jbElp1pos0mg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/json5": "^0.0.29", + "json5": "^1.0.2", + "minimist": "^1.2.6", + "strip-bom": "^3.0.0" + } + }, + "node_modules/tslib": { + "version": "2.4.0", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.4.0.tgz", + "integrity": "sha512-d6xOpEDfsi2CZVlPQzGeux8XMwLT9hssAsaPYExaQMuYskwb+x1x7J371tWlbBdWHroy99KnVB6qIkUbs5X3UQ==", + "dev": true, + "license": "0BSD" + }, + "node_modules/tunnel-agent": { + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/tunnel-agent/-/tunnel-agent-0.6.0.tgz", + "integrity": "sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "safe-buffer": "^5.0.1" + }, + "engines": { + "node": "*" + } + }, + "node_modules/type-check": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", + "integrity": "sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==", + "dev": true, + "license": "MIT", + "dependencies": { + "prelude-ls": "^1.2.1" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/typed-array-buffer": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/typed-array-buffer/-/typed-array-buffer-1.0.3.tgz", + "integrity": "sha512-nAYYwfY3qnzX30IkA6AQZjVbtK6duGontcQm1WSG1MD94YLqK0515GNApXkoxKOWMusVssAHWLh9SeaoefYFGw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "es-errors": "^1.3.0", + "is-typed-array": "^1.1.14" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/typed-array-byte-length": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/typed-array-byte-length/-/typed-array-byte-length-1.0.3.tgz", + "integrity": "sha512-BaXgOuIxz8n8pIq3e7Atg/7s+DpiYrxn4vdot3w9KbnBhcRQq6o3xemQdIfynqSeXeDrF32x+WvfzmOjPiY9lg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "for-each": "^0.3.3", + "gopd": "^1.2.0", + "has-proto": "^1.2.0", + "is-typed-array": "^1.1.14" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/typed-array-byte-offset": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/typed-array-byte-offset/-/typed-array-byte-offset-1.0.4.tgz", + "integrity": "sha512-bTlAFB/FBYMcuX81gbL4OcpH5PmlFHqlCCpAl8AlEzMz5k53oNDvN8p1PNOWLEmI2x4orp3raOFB51tv9X+MFQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "available-typed-arrays": "^1.0.7", + "call-bind": "^1.0.8", + "for-each": "^0.3.3", + "gopd": "^1.2.0", + "has-proto": "^1.2.0", + "is-typed-array": "^1.1.15", + "reflect.getprototypeof": "^1.0.9" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/typed-array-length": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/typed-array-length/-/typed-array-length-1.0.7.tgz", + "integrity": "sha512-3KS2b+kL7fsuk/eJZ7EQdnEmQoaho/r6KUef7hxvltNA5DR8NAUM+8wJMbJyZ4G9/7i3v5zPBIMN5aybAh2/Jg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.7", + "for-each": "^0.3.3", + "gopd": "^1.0.1", + "is-typed-array": "^1.1.13", + "possible-typed-array-names": "^1.0.0", + "reflect.getprototypeof": "^1.0.6" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/typescript": { + "version": "5.8.3", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.8.3.tgz", + "integrity": "sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/typescript-eslint": { + "version": "8.35.1", + "resolved": "https://registry.npmjs.org/typescript-eslint/-/typescript-eslint-8.35.1.tgz", + "integrity": "sha512-xslJjFzhOmHYQzSB/QTeASAHbjmxOGEP6Coh93TXmUBFQoJ1VU35UHIDmG06Jd6taf3wqqC1ntBnCMeymy5Ovw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/eslint-plugin": "8.35.1", + "@typescript-eslint/parser": "8.35.1", + "@typescript-eslint/utils": "8.35.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <5.9.0" + } + }, + "node_modules/unbox-primitive": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/unbox-primitive/-/unbox-primitive-1.1.0.tgz", + "integrity": "sha512-nWJ91DjeOkej/TA8pXQ3myruKpKEYgqvpw9lz4OPHj/NWFNluYrjbz9j01CJ8yKQd2g4jFoOkINCTW2I5LEEyw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "has-bigints": "^1.0.2", + "has-symbols": "^1.1.0", + "which-boxed-primitive": "^1.1.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/undici-types": { + "version": "5.26.5", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", + "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==", + "dev": true, + "license": "MIT" + }, + "node_modules/uri-js": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", + "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "punycode": "^2.1.0" + } + }, + "node_modules/w3c-keyname": { + "version": "2.2.8", + "resolved": "https://registry.npmjs.org/w3c-keyname/-/w3c-keyname-2.2.8.tgz", + "integrity": "sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ==", + "license": "MIT", + "peer": true + }, + "node_modules/which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "dev": true, + "license": "ISC", + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/which-boxed-primitive": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/which-boxed-primitive/-/which-boxed-primitive-1.1.1.tgz", + "integrity": "sha512-TbX3mj8n0odCBFVlY8AxkqcHASw3L60jIuF8jFP78az3C2YhmGvqbHBpAjTRH2/xqYunrJ9g1jSyjCjpoWzIAA==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-bigint": "^1.1.0", + "is-boolean-object": "^1.2.1", + "is-number-object": "^1.1.1", + "is-string": "^1.1.1", + "is-symbol": "^1.1.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/which-builtin-type": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/which-builtin-type/-/which-builtin-type-1.2.1.tgz", + "integrity": "sha512-6iBczoX+kDQ7a3+YJBnh3T+KZRxM/iYNPXicqk66/Qfm1b93iu+yOImkg0zHbj5LNOcNv1TEADiZ0xa34B4q6Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "function.prototype.name": "^1.1.6", + "has-tostringtag": "^1.0.2", + "is-async-function": "^2.0.0", + "is-date-object": "^1.1.0", + "is-finalizationregistry": "^1.1.0", + "is-generator-function": "^1.0.10", + "is-regex": "^1.2.1", + "is-weakref": "^1.0.2", + "isarray": "^2.0.5", + "which-boxed-primitive": "^1.1.0", + "which-collection": "^1.0.2", + "which-typed-array": "^1.1.16" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/which-collection": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/which-collection/-/which-collection-1.0.2.tgz", + "integrity": "sha512-K4jVyjnBdgvc86Y6BkaLZEN933SwYOuBFkdmBu9ZfkcAbdVbpITnDmjvZ/aQjRXQrv5EPkTnD1s39GiiqbngCw==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-map": "^2.0.3", + "is-set": "^2.0.3", + "is-weakmap": "^2.0.2", + "is-weakset": "^2.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/which-typed-array": { + "version": "1.1.19", + "resolved": "https://registry.npmjs.org/which-typed-array/-/which-typed-array-1.1.19.tgz", + "integrity": "sha512-rEvr90Bck4WZt9HHFC4DJMsjvu7x+r6bImz0/BrbWb7A2djJ8hnZMrWnHo9F8ssv0OMErasDhftrfROTyqSDrw==", + "dev": true, + "license": "MIT", + "dependencies": { + "available-typed-arrays": "^1.0.7", + "call-bind": "^1.0.8", + "call-bound": "^1.0.4", + "for-each": "^0.3.5", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-tostringtag": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/word-wrap": { + "version": "1.2.5", + "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.5.tgz", + "integrity": "sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/yaml": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.1.tgz", + "integrity": "sha512-lcYcMxX2PO9XMGvAJkJ3OsNMw+/7FKes7/hgerGUYWIoWu5j/+YQqcZr5JnPZWzOsEBgMbSbiSTn/dv/69Mkpw==", + "dev": true, + "license": "ISC", + "bin": { + "yaml": "bin.mjs" + }, + "engines": { + "node": ">= 14.6" + } + }, + "node_modules/yaml-eslint-parser": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/yaml-eslint-parser/-/yaml-eslint-parser-1.3.0.tgz", + "integrity": "sha512-E/+VitOorXSLiAqtTd7Yqax0/pAS3xaYMP+AUUJGOK1OZG3rhcj9fcJOM5HJ2VrP1FrStVCWr1muTfQCdj4tAA==", + "dev": true, + "license": "MIT", + "dependencies": { + "eslint-visitor-keys": "^3.0.0", + "yaml": "^2.0.0" + }, + "engines": { + "node": "^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/ota-meshi" + } + }, + "node_modules/yaml-eslint-parser/node_modules/eslint-visitor-keys": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.4.3.tgz", + "integrity": "sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/yocto-queue": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", + "integrity": "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + } + } +} diff --git a/surfsense_obsidian/package.json b/surfsense_obsidian/package.json new file mode 100644 index 000000000..17268d72a --- /dev/null +++ b/surfsense_obsidian/package.json @@ -0,0 +1,29 @@ +{ + "name": "obsidian-sample-plugin", + "version": "1.0.0", + "description": "This is a sample plugin for Obsidian (https://obsidian.md)", + "main": "main.js", + "type": "module", + "scripts": { + "dev": "node esbuild.config.mjs", + "build": "tsc -noEmit -skipLibCheck && node esbuild.config.mjs production", + "version": "node version-bump.mjs && git add manifest.json versions.json", + "lint": "eslint ." + }, + "keywords": [], + "license": "0-BSD", + "devDependencies": { + "@types/node": "^16.11.6", + "esbuild": "0.25.5", + "eslint-plugin-obsidianmd": "0.1.9", + "globals": "14.0.0", + "tslib": "2.4.0", + "typescript": "^5.8.3", + "typescript-eslint": "8.35.1", + "@eslint/js": "9.30.1", + "jiti": "2.6.1" + }, + "dependencies": { + "obsidian": "latest" + } +} diff --git a/surfsense_obsidian/src/main.ts b/surfsense_obsidian/src/main.ts new file mode 100644 index 000000000..6fe0c83a8 --- /dev/null +++ b/surfsense_obsidian/src/main.ts @@ -0,0 +1,99 @@ +import {App, Editor, MarkdownView, Modal, Notice, Plugin} from 'obsidian'; +import {DEFAULT_SETTINGS, MyPluginSettings, SampleSettingTab} from "./settings"; + +// Remember to rename these classes and interfaces! + +export default class MyPlugin extends Plugin { + settings: MyPluginSettings; + + async onload() { + await this.loadSettings(); + + // This creates an icon in the left ribbon. + this.addRibbonIcon('dice', 'Sample', (evt: MouseEvent) => { + // Called when the user clicks the icon. + new Notice('This is a notice!'); + }); + + // This adds a status bar item to the bottom of the app. Does not work on mobile apps. + const statusBarItemEl = this.addStatusBarItem(); + statusBarItemEl.setText('Status bar text'); + + // This adds a simple command that can be triggered anywhere + this.addCommand({ + id: 'open-modal-simple', + name: 'Open modal (simple)', + callback: () => { + new SampleModal(this.app).open(); + } + }); + // This adds an editor command that can perform some operation on the current editor instance + this.addCommand({ + id: 'replace-selected', + name: 'Replace selected content', + editorCallback: (editor: Editor, view: MarkdownView) => { + editor.replaceSelection('Sample editor command'); + } + }); + // This adds a complex command that can check whether the current state of the app allows execution of the command + this.addCommand({ + id: 'open-modal-complex', + name: 'Open modal (complex)', + checkCallback: (checking: boolean) => { + // Conditions to check + const markdownView = this.app.workspace.getActiveViewOfType(MarkdownView); + if (markdownView) { + // If checking is true, we're simply "checking" if the command can be run. + // If checking is false, then we want to actually perform the operation. + if (!checking) { + new SampleModal(this.app).open(); + } + + // This command will only show up in Command Palette when the check function returns true + return true; + } + return false; + } + }); + + // This adds a settings tab so the user can configure various aspects of the plugin + this.addSettingTab(new SampleSettingTab(this.app, this)); + + // If the plugin hooks up any global DOM events (on parts of the app that doesn't belong to this plugin) + // Using this function will automatically remove the event listener when this plugin is disabled. + this.registerDomEvent(document, 'click', (evt: MouseEvent) => { + new Notice("Click"); + }); + + // When registering intervals, this function will automatically clear the interval when the plugin is disabled. + this.registerInterval(window.setInterval(() => console.log('setInterval'), 5 * 60 * 1000)); + + } + + onunload() { + } + + async loadSettings() { + this.settings = Object.assign({}, DEFAULT_SETTINGS, await this.loadData() as Partial); + } + + async saveSettings() { + await this.saveData(this.settings); + } +} + +class SampleModal extends Modal { + constructor(app: App) { + super(app); + } + + onOpen() { + let {contentEl} = this; + contentEl.setText('Woah!'); + } + + onClose() { + const {contentEl} = this; + contentEl.empty(); + } +} diff --git a/surfsense_obsidian/src/settings.ts b/surfsense_obsidian/src/settings.ts new file mode 100644 index 000000000..352121e07 --- /dev/null +++ b/surfsense_obsidian/src/settings.ts @@ -0,0 +1,36 @@ +import {App, PluginSettingTab, Setting} from "obsidian"; +import MyPlugin from "./main"; + +export interface MyPluginSettings { + mySetting: string; +} + +export const DEFAULT_SETTINGS: MyPluginSettings = { + mySetting: 'default' +} + +export class SampleSettingTab extends PluginSettingTab { + plugin: MyPlugin; + + constructor(app: App, plugin: MyPlugin) { + super(app, plugin); + this.plugin = plugin; + } + + display(): void { + const {containerEl} = this; + + containerEl.empty(); + + new Setting(containerEl) + .setName('Settings #1') + .setDesc('It\'s a secret') + .addText(text => text + .setPlaceholder('Enter your secret') + .setValue(this.plugin.settings.mySetting) + .onChange(async (value) => { + this.plugin.settings.mySetting = value; + await this.plugin.saveSettings(); + })); + } +} diff --git a/surfsense_obsidian/styles.css b/surfsense_obsidian/styles.css new file mode 100644 index 000000000..71cc60fd4 --- /dev/null +++ b/surfsense_obsidian/styles.css @@ -0,0 +1,8 @@ +/* + +This CSS file will be included with your plugin, and +available in the app when your plugin is enabled. + +If your plugin does not need CSS, delete this file. + +*/ diff --git a/surfsense_obsidian/tsconfig.json b/surfsense_obsidian/tsconfig.json new file mode 100644 index 000000000..222535dee --- /dev/null +++ b/surfsense_obsidian/tsconfig.json @@ -0,0 +1,30 @@ +{ + "compilerOptions": { + "baseUrl": "src", + "inlineSourceMap": true, + "inlineSources": true, + "module": "ESNext", + "target": "ES6", + "allowJs": true, + "noImplicitAny": true, + "noImplicitThis": true, + "noImplicitReturns": true, + "moduleResolution": "node", + "importHelpers": true, + "noUncheckedIndexedAccess": true, + "isolatedModules": true, + "strictNullChecks": true, + "strictBindCallApply": true, + "allowSyntheticDefaultImports": true, + "useUnknownInCatchVariables": true, + "lib": [ + "DOM", + "ES5", + "ES6", + "ES7" + ] + }, + "include": [ + "src/**/*.ts" + ] +} diff --git a/surfsense_obsidian/version-bump.mjs b/surfsense_obsidian/version-bump.mjs new file mode 100644 index 000000000..55d631fb6 --- /dev/null +++ b/surfsense_obsidian/version-bump.mjs @@ -0,0 +1,17 @@ +import { readFileSync, writeFileSync } from "fs"; + +const targetVersion = process.env.npm_package_version; + +// read minAppVersion from manifest.json and bump version to target version +const manifest = JSON.parse(readFileSync("manifest.json", "utf8")); +const { minAppVersion } = manifest; +manifest.version = targetVersion; +writeFileSync("manifest.json", JSON.stringify(manifest, null, "\t")); + +// update versions.json with target version and minAppVersion from manifest.json +// but only if the target version is not already in versions.json +const versions = JSON.parse(readFileSync('versions.json', 'utf8')); +if (!Object.values(versions).includes(minAppVersion)) { + versions[targetVersion] = minAppVersion; + writeFileSync('versions.json', JSON.stringify(versions, null, '\t')); +} diff --git a/surfsense_obsidian/versions.json b/surfsense_obsidian/versions.json new file mode 100644 index 000000000..26382a157 --- /dev/null +++ b/surfsense_obsidian/versions.json @@ -0,0 +1,3 @@ +{ + "1.0.0": "0.15.0" +} From f903bcc80d0f48723ad2e171e8f22667f8879280 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 20 Apr 2026 04:02:54 +0530 Subject: [PATCH 003/299] feat: add GitHub workflows for linting and releasing Obsidian plugin --- .github/workflows/obsidian-plugin-lint.yml | 44 ++++++++ .github/workflows/release-obsidian-plugin.yml | 102 ++++++++++++++++++ 2 files changed, 146 insertions(+) create mode 100644 .github/workflows/obsidian-plugin-lint.yml create mode 100644 .github/workflows/release-obsidian-plugin.yml diff --git a/.github/workflows/obsidian-plugin-lint.yml b/.github/workflows/obsidian-plugin-lint.yml new file mode 100644 index 000000000..237087d39 --- /dev/null +++ b/.github/workflows/obsidian-plugin-lint.yml @@ -0,0 +1,44 @@ +name: Obsidian Plugin Lint + +# Lints + type-checks + builds the Obsidian plugin on every push/PR that +# touches its sources. The official obsidian-sample-plugin template ships +# its own ESLint+esbuild setup; we run that here instead of folding the +# plugin into the monorepo's Biome-based code-quality.yml so the tooling +# stays aligned with what `obsidianmd/eslint-plugin-obsidianmd` checks +# against. + +on: + push: + branches: ["**"] + paths: + - "surfsense_obsidian/**" + - ".github/workflows/obsidian-plugin-lint.yml" + pull_request: + branches: ["**"] + paths: + - "surfsense_obsidian/**" + - ".github/workflows/obsidian-plugin-lint.yml" + +jobs: + lint: + runs-on: ubuntu-latest + defaults: + run: + working-directory: surfsense_obsidian + strategy: + fail-fast: false + matrix: + node-version: [20.x, 22.x] + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-node@v4 + with: + node-version: ${{ matrix.node-version }} + cache: npm + cache-dependency-path: surfsense_obsidian/package-lock.json + + - run: npm ci + - run: npm run lint + - run: npm run build diff --git a/.github/workflows/release-obsidian-plugin.yml b/.github/workflows/release-obsidian-plugin.yml new file mode 100644 index 000000000..c97d45023 --- /dev/null +++ b/.github/workflows/release-obsidian-plugin.yml @@ -0,0 +1,102 @@ +name: Release Obsidian Plugin + +# Triggered on tags of the form `obsidian-v0.1.0`. The version after the +# prefix MUST exactly equal `surfsense_obsidian/manifest.json`'s `version` +# (no leading `v`) — this is what BRAT and the Obsidian community plugin +# store both verify. +on: + push: + tags: + - "obsidian-v*" + workflow_dispatch: + inputs: + tag: + description: "Tag to build (e.g. obsidian-v0.1.0). Dry-run only when run manually." + required: true + default: "obsidian-v0.0.0-test" + +permissions: + contents: write + +jobs: + build-and-release: + runs-on: ubuntu-latest + defaults: + run: + working-directory: surfsense_obsidian + + steps: + - uses: actions/checkout@v4 + with: + # Need write access for the manifest/versions.json mirror commit + # back to main further down. + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - uses: actions/setup-node@v4 + with: + node-version: 20.x + cache: npm + cache-dependency-path: surfsense_obsidian/package-lock.json + + - name: Resolve plugin version + id: version + run: | + tag="${GITHUB_REF_NAME:-${{ github.event.inputs.tag }}}" + version="${tag#obsidian-v}" + manifest_version=$(node -p "require('./manifest.json').version") + if [ "$version" != "$manifest_version" ]; then + echo "::error::Tag version '$version' does not match manifest version '$manifest_version'" + exit 1 + fi + echo "tag=$tag" >> "$GITHUB_OUTPUT" + echo "version=$version" >> "$GITHUB_OUTPUT" + + - run: npm ci + + - run: npm run lint + + - run: npm run build + + - name: Verify build artifacts + run: | + for f in main.js manifest.json styles.css; do + test -f "$f" || (echo "::error::Missing release artifact: $f" && exit 1) + done + + - name: Mirror manifest.json + versions.json to repo root + if: github.event_name == 'push' + working-directory: ${{ github.workspace }} + run: | + cp surfsense_obsidian/manifest.json manifest.json + cp surfsense_obsidian/versions.json versions.json + if git diff --quiet manifest.json versions.json; then + echo "Root manifest/versions already up to date." + exit 0 + fi + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git add manifest.json versions.json + git commit -m "chore(obsidian-plugin): mirror manifest+versions for ${{ steps.version.outputs.tag }}" + # Push to the default branch so Obsidian can fetch raw files from HEAD. + git push origin HEAD:${{ github.event.repository.default_branch }} + + # IMPORTANT: BRAT and the Obsidian community plugin store look up the + # release by the bare manifest `version` (e.g. `0.1.0`), NOT by the + # build-trigger tag (`obsidian-v0.1.0`). So we publish the GitHub + # release with `tag_name: ` — `softprops/action-gh-release` + # will create that tag if it doesn't already exist, pointing at the + # commit referenced by the build-trigger tag. Verified against + # https://github.com/khoj-ai/khoj/releases (their tags are bare + # versions like `2.0.0-beta.28`, no prefix). + - name: Create GitHub release + if: github.event_name == 'push' + uses: softprops/action-gh-release@v2 + with: + tag_name: ${{ steps.version.outputs.version }} + name: SurfSense Obsidian Plugin ${{ steps.version.outputs.version }} + generate_release_notes: true + files: | + surfsense_obsidian/main.js + surfsense_obsidian/manifest.json + surfsense_obsidian/styles.css From e8fc1069bcf6316a561a270cc4136ef7f32298b2 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 20 Apr 2026 04:03:19 +0530 Subject: [PATCH 004/299] feat: implement Obsidian plugin ingestion routes and indexing service --- ...9_deactivate_legacy_obsidian_connectors.py | 75 +++ surfsense_backend/app/routes/__init__.py | 2 + .../app/routes/obsidian_plugin_routes.py | 450 ++++++++++++++++++ .../app/schemas/obsidian_plugin.py | 147 ++++++ .../app/services/obsidian_plugin_indexer.py | 400 ++++++++++++++++ 5 files changed, 1074 insertions(+) create mode 100644 surfsense_backend/alembic/versions/129_deactivate_legacy_obsidian_connectors.py create mode 100644 surfsense_backend/app/routes/obsidian_plugin_routes.py create mode 100644 surfsense_backend/app/schemas/obsidian_plugin.py create mode 100644 surfsense_backend/app/services/obsidian_plugin_indexer.py diff --git a/surfsense_backend/alembic/versions/129_deactivate_legacy_obsidian_connectors.py b/surfsense_backend/alembic/versions/129_deactivate_legacy_obsidian_connectors.py new file mode 100644 index 000000000..42808b1ca --- /dev/null +++ b/surfsense_backend/alembic/versions/129_deactivate_legacy_obsidian_connectors.py @@ -0,0 +1,75 @@ +"""129_deactivate_legacy_obsidian_connectors + +Revision ID: 129 +Revises: 128 +Create Date: 2026-04-18 + +Marks every pre-plugin OBSIDIAN_CONNECTOR row as legacy. We keep the +rows (and their indexed Documents) so existing search results don't +suddenly disappear, but we: + +* set ``is_indexable = false`` and ``periodic_indexing_enabled = false`` + so the scheduler will never fire a server-side scan again, +* clear ``next_scheduled_at`` so the scheduler stops considering the + row, +* merge ``{"legacy": true, "deactivated_at": ""}`` into ``config`` + so the new ObsidianConfig view in the web UI can render the + migration banner (and so a future cleanup script can find them). + +A row is "pre-plugin" when its ``config`` does not already have +``source = "plugin"``. The new plugin indexer always writes +``config.source = "plugin"`` on first /obsidian/connect, so this +predicate is stable. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "129" +down_revision: str | None = "128" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + conn = op.get_bind() + conn.execute( + sa.text( + """ + UPDATE search_source_connectors + SET + is_indexable = false, + periodic_indexing_enabled = false, + next_scheduled_at = NULL, + config = COALESCE(config, '{}'::json)::jsonb + || jsonb_build_object( + 'legacy', true, + 'deactivated_at', to_char( + now() AT TIME ZONE 'UTC', + 'YYYY-MM-DD"T"HH24:MI:SS"Z"' + ) + ) + WHERE connector_type = 'OBSIDIAN_CONNECTOR' + AND COALESCE((config::jsonb)->>'source', '') <> 'plugin' + """ + ) + ) + + +def downgrade() -> None: + conn = op.get_bind() + conn.execute( + sa.text( + """ + UPDATE search_source_connectors + SET config = (config::jsonb - 'legacy' - 'deactivated_at')::json + WHERE connector_type = 'OBSIDIAN_CONNECTOR' + AND (config::jsonb) ? 'legacy' + """ + ) + ) diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index ad40666cd..070060878 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -37,6 +37,7 @@ from .new_llm_config_routes import router as new_llm_config_router from .notes_routes import router as notes_router from .notifications_routes import router as notifications_router from .notion_add_connector_route import router as notion_add_connector_router +from .obsidian_plugin_routes import router as obsidian_plugin_router from .onedrive_add_connector_route import router as onedrive_add_connector_router from .podcasts_routes import router as podcasts_router from .prompts_routes import router as prompts_router @@ -84,6 +85,7 @@ router.include_router(notion_add_connector_router) router.include_router(slack_add_connector_router) router.include_router(teams_add_connector_router) router.include_router(onedrive_add_connector_router) +router.include_router(obsidian_plugin_router) # Obsidian plugin push API router.include_router(discord_add_connector_router) router.include_router(jira_add_connector_router) router.include_router(confluence_add_connector_router) diff --git a/surfsense_backend/app/routes/obsidian_plugin_routes.py b/surfsense_backend/app/routes/obsidian_plugin_routes.py new file mode 100644 index 000000000..c7656332d --- /dev/null +++ b/surfsense_backend/app/routes/obsidian_plugin_routes.py @@ -0,0 +1,450 @@ +""" +Obsidian plugin ingestion routes. + +This is the public surface that the SurfSense Obsidian plugin +(``surfsense_obsidian/``) speaks to. It is a separate router from the +legacy server-path Obsidian connector — the legacy code stays in place +until the ``obsidian-legacy-cleanup`` plan ships. + +Endpoints +--------- + +- ``GET /api/v1/obsidian/health`` — version handshake +- ``POST /api/v1/obsidian/connect`` — register or get a vault row +- ``POST /api/v1/obsidian/sync`` — batch upsert +- ``POST /api/v1/obsidian/rename`` — batch rename +- ``DELETE /api/v1/obsidian/notes`` — batch soft-delete +- ``GET /api/v1/obsidian/manifest`` — reconcile manifest + +Auth contract +------------- + +Every endpoint requires ``Depends(current_active_user)`` — the same JWT +bearer the rest of the API uses; future PAT migration is transparent. + +API stability is provided by the ``/api/v1/...`` URL prefix and the +``capabilities`` array advertised on ``/health`` (additive only). There +is no plugin-version gate; "your plugin is out of date" notices are +delegated to Obsidian's built-in community-store updater. +""" + +from __future__ import annotations + +import logging +from datetime import UTC, datetime + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy import and_ +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.db import ( + SearchSourceConnector, + SearchSourceConnectorType, + SearchSpace, + User, + get_async_session, +) +from app.schemas.obsidian_plugin import ( + ConnectRequest, + ConnectResponse, + DeleteBatchRequest, + HealthResponse, + ManifestResponse, + RenameBatchRequest, + SyncBatchRequest, +) +from app.services.obsidian_plugin_indexer import ( + delete_note, + get_manifest, + rename_note, + upsert_note, +) +from app.users import current_active_user + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/obsidian", tags=["obsidian-plugin"]) + + +# Bumped manually whenever the wire contract gains a non-additive change. +# Additive (extra='ignore'-safe) changes do NOT bump this. +OBSIDIAN_API_VERSION = "1" + +# Capabilities advertised on /health and /connect. Plugins use this list +# for feature gating ("does this server understand attachments_v2?"). Add +# new strings, never rename/remove existing ones — older plugins ignore +# unknown entries safely. +OBSIDIAN_CAPABILITIES: list[str] = ["sync", "rename", "delete", "manifest"] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_handshake() -> dict[str, object]: + return { + "api_version": OBSIDIAN_API_VERSION, + "capabilities": list(OBSIDIAN_CAPABILITIES), + } + + +async def _resolve_vault_connector( + session: AsyncSession, + *, + user: User, + vault_id: str, +) -> SearchSourceConnector: + """Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user. + + Looked up by the (user_id, connector_type, config['vault_id']) tuple + so users can have multiple vaults each backed by its own connector + row (one per search space). + """ + result = await session.execute( + select(SearchSourceConnector).where( + and_( + SearchSourceConnector.user_id == user.id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.OBSIDIAN_CONNECTOR, + ) + ) + ) + candidates = result.scalars().all() + for connector in candidates: + cfg = connector.config or {} + if cfg.get("vault_id") == vault_id and cfg.get("source") == "plugin": + return connector + + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "code": "VAULT_NOT_REGISTERED", + "message": ( + "No Obsidian plugin connector found for this vault. " + "Call POST /obsidian/connect first." + ), + "vault_id": vault_id, + }, + ) + + +async def _ensure_search_space_access( + session: AsyncSession, + *, + user: User, + search_space_id: int, +) -> SearchSpace: + """Confirm the user owns the requested search space. + + Plugin currently does not support shared search spaces (RBAC roles) + — that's a follow-up. Restricting to owner-only here keeps the + surface narrow and avoids leaking other members' connectors. + """ + result = await session.execute( + select(SearchSpace).where( + and_(SearchSpace.id == search_space_id, SearchSpace.user_id == user.id) + ) + ) + space = result.scalars().first() + if space is None: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "code": "SEARCH_SPACE_FORBIDDEN", + "message": "You don't own that search space.", + }, + ) + return space + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + + +@router.get("/health", response_model=HealthResponse) +async def obsidian_health( + user: User = Depends(current_active_user), +) -> HealthResponse: + """Return the API contract handshake. + + The plugin calls this once per ``onload`` and caches the result for + capability-gating decisions. + """ + return HealthResponse( + **_build_handshake(), + server_time_utc=datetime.now(UTC), + ) + + +@router.post("/connect", response_model=ConnectResponse) +async def obsidian_connect( + payload: ConnectRequest, + user: User = Depends(current_active_user), + session: AsyncSession = Depends(get_async_session), +) -> ConnectResponse: + """Register a vault, or return the existing connector row. + + Idempotent on the (user_id, OBSIDIAN_CONNECTOR, vault_id) tuple so + re-installing the plugin or reconnecting from a new device picks up + the same connector — and therefore the same documents. + """ + await _ensure_search_space_access( + session, user=user, search_space_id=payload.search_space_id + ) + + result = await session.execute( + select(SearchSourceConnector).where( + and_( + SearchSourceConnector.user_id == user.id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.OBSIDIAN_CONNECTOR, + ) + ) + ) + existing: SearchSourceConnector | None = None + for candidate in result.scalars().all(): + cfg = candidate.config or {} + if cfg.get("vault_id") == payload.vault_id: + existing = candidate + break + + now_iso = datetime.now(UTC).isoformat() + + if existing is not None: + cfg = dict(existing.config or {}) + cfg.update( + { + "vault_id": payload.vault_id, + "vault_name": payload.vault_name, + "source": "plugin", + "plugin_version": payload.plugin_version, + "device_id": payload.device_id, + "last_connect_at": now_iso, + } + ) + if payload.device_label: + cfg["device_label"] = payload.device_label + cfg.pop("legacy", None) + cfg.pop("vault_path", None) + existing.config = cfg + existing.is_indexable = False + existing.search_space_id = payload.search_space_id + await session.commit() + await session.refresh(existing) + connector = existing + else: + connector = SearchSourceConnector( + name=f"Obsidian — {payload.vault_name}", + connector_type=SearchSourceConnectorType.OBSIDIAN_CONNECTOR, + is_indexable=False, + config={ + "vault_id": payload.vault_id, + "vault_name": payload.vault_name, + "source": "plugin", + "plugin_version": payload.plugin_version, + "device_id": payload.device_id, + "device_label": payload.device_label, + "files_synced": 0, + "last_connect_at": now_iso, + }, + user_id=user.id, + search_space_id=payload.search_space_id, + ) + session.add(connector) + await session.commit() + await session.refresh(connector) + + return ConnectResponse( + connector_id=connector.id, + vault_id=payload.vault_id, + search_space_id=connector.search_space_id, + **_build_handshake(), + ) + + +@router.post("/sync") +async def obsidian_sync( + payload: SyncBatchRequest, + user: User = Depends(current_active_user), + session: AsyncSession = Depends(get_async_session), +) -> dict[str, object]: + """Batch-upsert notes pushed by the plugin. + + Returns per-note ack so the plugin can dequeue successes and retry + failures. + """ + connector = await _resolve_vault_connector( + session, user=user, vault_id=payload.vault_id + ) + + results: list[dict[str, object]] = [] + indexed = 0 + failed = 0 + + for note in payload.notes: + try: + doc = await upsert_note( + session, connector=connector, payload=note, user_id=str(user.id) + ) + indexed += 1 + results.append( + {"path": note.path, "status": "ok", "document_id": doc.id} + ) + except HTTPException: + raise + except Exception as exc: + failed += 1 + logger.exception( + "obsidian /sync failed for path=%s vault=%s", + note.path, + payload.vault_id, + ) + results.append( + {"path": note.path, "status": "error", "error": str(exc)[:300]} + ) + + cfg = dict(connector.config or {}) + cfg["last_sync_at"] = datetime.now(UTC).isoformat() + cfg["files_synced"] = int(cfg.get("files_synced", 0)) + indexed + connector.config = cfg + await session.commit() + + return { + "vault_id": payload.vault_id, + "indexed": indexed, + "failed": failed, + "results": results, + } + + +@router.post("/rename") +async def obsidian_rename( + payload: RenameBatchRequest, + user: User = Depends(current_active_user), + session: AsyncSession = Depends(get_async_session), +) -> dict[str, object]: + """Apply a batch of vault rename events.""" + connector = await _resolve_vault_connector( + session, user=user, vault_id=payload.vault_id + ) + + results: list[dict[str, object]] = [] + renamed = 0 + missing = 0 + + for item in payload.renames: + try: + doc = await rename_note( + session, + connector=connector, + old_path=item.old_path, + new_path=item.new_path, + vault_id=payload.vault_id, + ) + if doc is None: + missing += 1 + results.append( + { + "old_path": item.old_path, + "new_path": item.new_path, + "status": "missing", + } + ) + else: + renamed += 1 + results.append( + { + "old_path": item.old_path, + "new_path": item.new_path, + "status": "ok", + "document_id": doc.id, + } + ) + except Exception as exc: + logger.exception( + "obsidian /rename failed for old=%s new=%s vault=%s", + item.old_path, + item.new_path, + payload.vault_id, + ) + results.append( + { + "old_path": item.old_path, + "new_path": item.new_path, + "status": "error", + "error": str(exc)[:300], + } + ) + + return { + "vault_id": payload.vault_id, + "renamed": renamed, + "missing": missing, + "results": results, + } + + +@router.delete("/notes") +async def obsidian_delete_notes( + payload: DeleteBatchRequest, + user: User = Depends(current_active_user), + session: AsyncSession = Depends(get_async_session), +) -> dict[str, object]: + """Soft-delete a batch of notes by vault-relative path.""" + connector = await _resolve_vault_connector( + session, user=user, vault_id=payload.vault_id + ) + + deleted = 0 + missing = 0 + results: list[dict[str, object]] = [] + for path in payload.paths: + try: + ok = await delete_note( + session, + connector=connector, + vault_id=payload.vault_id, + path=path, + ) + if ok: + deleted += 1 + results.append({"path": path, "status": "ok"}) + else: + missing += 1 + results.append({"path": path, "status": "missing"}) + except Exception as exc: + logger.exception( + "obsidian DELETE /notes failed for path=%s vault=%s", + path, + payload.vault_id, + ) + results.append( + {"path": path, "status": "error", "error": str(exc)[:300]} + ) + + return { + "vault_id": payload.vault_id, + "deleted": deleted, + "missing": missing, + "results": results, + } + + +@router.get("/manifest", response_model=ManifestResponse) +async def obsidian_manifest( + vault_id: str = Query(..., description="Plugin-side stable vault UUID"), + user: User = Depends(current_active_user), + session: AsyncSession = Depends(get_async_session), +) -> ManifestResponse: + """Return the server-side ``{path: {hash, mtime}}`` manifest. + + Used by the plugin's ``onload`` reconcile to find files that were + edited or deleted while the plugin was offline. + """ + connector = await _resolve_vault_connector( + session, user=user, vault_id=vault_id + ) + return await get_manifest(session, connector=connector, vault_id=vault_id) diff --git a/surfsense_backend/app/schemas/obsidian_plugin.py b/surfsense_backend/app/schemas/obsidian_plugin.py new file mode 100644 index 000000000..c4c3cd8d4 --- /dev/null +++ b/surfsense_backend/app/schemas/obsidian_plugin.py @@ -0,0 +1,147 @@ +""" +Obsidian Plugin connector schemas. + +Wire format spoken between the SurfSense Obsidian plugin +(``surfsense_obsidian/``) and the FastAPI backend. + +Stability contract +------------------ +Every request and response schema sets ``model_config = ConfigDict(extra='ignore')``. +This is the API stability contract — not just hygiene: + +- Old plugins talking to a newer backend silently drop any new response fields + they don't understand instead of failing validation. +- New plugins talking to an older backend can include forward-looking request + fields (e.g. attachments metadata) without the older backend rejecting them. + +Hard breaking changes are reserved for the URL prefix (``/api/v2/...``). +Additive evolution is signaled via the ``capabilities`` array on +``HealthResponse`` / ``ConnectResponse`` — older plugins ignore unknown +capability strings safely. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +_PLUGIN_MODEL_CONFIG = ConfigDict(extra="ignore") + + +class _PluginBase(BaseModel): + """Base class for all plugin payload schemas. + + Carries the forward-compatibility config so subclasses don't have to + repeat it. + """ + + model_config = _PLUGIN_MODEL_CONFIG + + +class NotePayload(_PluginBase): + """One Obsidian note as pushed by the plugin. + + The plugin is the source of truth: ``content`` is the post-frontmatter + body, ``frontmatter``/``tags``/``headings``/etc. are precomputed by the + plugin via ``app.metadataCache`` so the backend doesn't have to re-parse. + """ + + vault_id: str = Field(..., description="Stable plugin-generated UUID for this vault") + path: str = Field(..., description="Vault-relative path, e.g. 'notes/foo.md'") + name: str = Field(..., description="File stem (no extension)") + extension: str = Field(default="md", description="File extension without leading dot") + content: str = Field(default="", description="Raw markdown body (post-frontmatter)") + + frontmatter: dict[str, Any] = Field(default_factory=dict) + tags: list[str] = Field(default_factory=list) + headings: list[str] = Field(default_factory=list) + resolved_links: list[str] = Field(default_factory=list) + unresolved_links: list[str] = Field(default_factory=list) + embeds: list[str] = Field(default_factory=list) + aliases: list[str] = Field(default_factory=list) + + content_hash: str = Field(..., description="Plugin-computed SHA-256 of the raw content") + mtime: datetime + ctime: datetime + + +class SyncBatchRequest(_PluginBase): + """Batch upsert. Plugin sends 10-20 notes per request to amortize HTTP overhead.""" + + vault_id: str + notes: list[NotePayload] = Field(default_factory=list, max_length=100) + + +class RenameItem(_PluginBase): + old_path: str + new_path: str + + +class RenameBatchRequest(_PluginBase): + vault_id: str + renames: list[RenameItem] = Field(default_factory=list, max_length=200) + + +class DeleteBatchRequest(_PluginBase): + vault_id: str + paths: list[str] = Field(default_factory=list, max_length=500) + + +class ManifestEntry(_PluginBase): + """One row of the server-side manifest used by the plugin to reconcile.""" + + hash: str + mtime: datetime + + +class ManifestResponse(_PluginBase): + """Path-keyed manifest of every non-deleted note for a vault.""" + + vault_id: str + items: dict[str, ManifestEntry] = Field(default_factory=dict) + + +class ConnectRequest(_PluginBase): + """First-call handshake to register or look up a vault connector row.""" + + vault_id: str + vault_name: str + search_space_id: int + plugin_version: str + device_id: str + device_label: str | None = Field( + default=None, + description="User-friendly device name shown in the web UI (e.g. 'iPad Pro').", + ) + + +class ConnectResponse(_PluginBase): + """Returned from POST /connect. + + Carries the same handshake fields as ``HealthResponse`` so the plugin + learns the contract on its very first call without an extra round-trip + to ``GET /health``. + """ + + connector_id: int + vault_id: str + search_space_id: int + api_version: str + capabilities: list[str] + + +class HealthResponse(_PluginBase): + """API contract handshake. + + The plugin calls ``GET /health`` once per ``onload`` and caches the + result. ``capabilities`` is a forward-extensible string list: future + additions (``'pat_auth'``, ``'scoped_pat'``, ``'attachments_v2'``, + ``'shared_search_spaces'``...) ship without breaking older plugins + because they only enable extra behavior, never gate existing endpoints. + """ + + api_version: str + capabilities: list[str] + server_time_utc: datetime diff --git a/surfsense_backend/app/services/obsidian_plugin_indexer.py b/surfsense_backend/app/services/obsidian_plugin_indexer.py new file mode 100644 index 000000000..385c8e013 --- /dev/null +++ b/surfsense_backend/app/services/obsidian_plugin_indexer.py @@ -0,0 +1,400 @@ +""" +Obsidian plugin indexer service. + +Bridges the SurfSense Obsidian plugin's HTTP payloads +(see ``app/schemas/obsidian_plugin.py``) into the shared +``IndexingPipelineService``. + +Responsibilities: + +- ``upsert_note`` — push one note through the indexing pipeline; respects + unchanged content (skip) and version-snapshots existing rows before + rewrite. +- ``rename_note`` — rewrite path-derived fields (path metadata, + ``unique_identifier_hash``, ``source_url``) without re-indexing content. +- ``delete_note`` — soft delete with a tombstone in ``document_metadata`` + so reconciliation can distinguish "user explicitly killed this in the UI" + from "plugin hasn't synced yet". +- ``get_manifest`` — return ``{path: {hash, mtime}}`` for every non-deleted + note belonging to a vault, used by the plugin's reconcile pass on + ``onload``. + +Design notes +------------ + +The plugin's content hash and the backend's ``content_hash`` are computed +differently (plugin uses raw SHA-256 of the markdown body; backend salts +with ``search_space_id``). We persist the plugin's hash in +``document_metadata['plugin_content_hash']`` so the manifest endpoint can +return what the plugin sent — that's the only number the plugin can +compare without re-downloading content. +""" + +from __future__ import annotations + +import logging +from datetime import UTC, datetime +from typing import Any +from urllib.parse import quote + +from sqlalchemy import and_, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import ( + Document, + DocumentStatus, + DocumentType, + SearchSourceConnector, +) +from app.indexing_pipeline.connector_document import ConnectorDocument +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService +from app.schemas.obsidian_plugin import ( + ManifestEntry, + ManifestResponse, + NotePayload, +) +from app.services.llm_service import get_user_long_context_llm +from app.utils.document_converters import generate_unique_identifier_hash +from app.utils.document_versioning import create_version_snapshot + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _vault_path_unique_id(vault_id: str, path: str) -> str: + """Stable identifier for a note. Vault-scoped so the same path under two + different vaults doesn't collide.""" + return f"{vault_id}:{path}" + + +def _build_source_url(vault_name: str, path: str) -> str: + """Build the ``obsidian://`` deep link for the web UI's "Open in Obsidian" + button. Both segments are URL-encoded because vault names and paths can + contain spaces, ``#``, ``?``, etc. + """ + return ( + "obsidian://open" + f"?vault={quote(vault_name, safe='')}" + f"&file={quote(path, safe='')}" + ) + + +def _build_metadata( + payload: NotePayload, + *, + vault_name: str, + connector_id: int, + extra: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Flatten the rich plugin payload into the JSONB ``document_metadata`` + column. Keys here are what the chat UI / search UI surface to users. + """ + meta: dict[str, Any] = { + "source": "plugin", + "vault_id": payload.vault_id, + "vault_name": vault_name, + "file_path": payload.path, + "file_name": payload.name, + "extension": payload.extension, + "frontmatter": payload.frontmatter, + "tags": payload.tags, + "headings": payload.headings, + "outgoing_links": payload.resolved_links, + "unresolved_links": payload.unresolved_links, + "embeds": payload.embeds, + "aliases": payload.aliases, + "plugin_content_hash": payload.content_hash, + "mtime": payload.mtime.isoformat(), + "ctime": payload.ctime.isoformat(), + "connector_id": connector_id, + "url": _build_source_url(vault_name, payload.path), + } + if extra: + meta.update(extra) + return meta + + +def _build_document_string(payload: NotePayload, vault_name: str) -> str: + """Compose the indexable string the pipeline embeds and chunks. + + Mirrors the legacy obsidian indexer's METADATA + CONTENT framing so + existing search relevance heuristics keep working unchanged. + """ + tags_line = ", ".join(payload.tags) if payload.tags else "None" + links_line = ( + ", ".join(payload.resolved_links) if payload.resolved_links else "None" + ) + return ( + "\n" + f"Title: {payload.name}\n" + f"Vault: {vault_name}\n" + f"Path: {payload.path}\n" + f"Tags: {tags_line}\n" + f"Links to: {links_line}\n" + "\n\n" + "\n" + f"{payload.content}\n" + "\n" + ) + + +async def _find_existing_document( + session: AsyncSession, + *, + search_space_id: int, + vault_id: str, + path: str, +) -> Document | None: + unique_id = _vault_path_unique_id(vault_id, path) + uid_hash = generate_unique_identifier_hash( + DocumentType.OBSIDIAN_CONNECTOR, + unique_id, + search_space_id, + ) + result = await session.execute( + select(Document).where(Document.unique_identifier_hash == uid_hash) + ) + return result.scalars().first() + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +async def upsert_note( + session: AsyncSession, + *, + connector: SearchSourceConnector, + payload: NotePayload, + user_id: str, +) -> Document: + """Index or refresh a single note pushed by the plugin. + + Returns the resulting ``Document`` (whether newly created, updated, or + a skip-because-unchanged hit). + """ + vault_name: str = (connector.config or {}).get("vault_name") or "Vault" + search_space_id = connector.search_space_id + + existing = await _find_existing_document( + session, + search_space_id=search_space_id, + vault_id=payload.vault_id, + path=payload.path, + ) + + plugin_hash = payload.content_hash + if existing is not None: + existing_meta = existing.document_metadata or {} + was_tombstoned = bool(existing_meta.get("deleted_at")) + + if ( + not was_tombstoned + and existing_meta.get("plugin_content_hash") == plugin_hash + and DocumentStatus.is_state(existing.status, DocumentStatus.READY) + ): + return existing + + try: + await create_version_snapshot(session, existing) + except Exception: + logger.debug( + "version snapshot failed for obsidian doc %s", + existing.id, + exc_info=True, + ) + + document_string = _build_document_string(payload, vault_name) + metadata = _build_metadata( + payload, + vault_name=vault_name, + connector_id=connector.id, + ) + + connector_doc = ConnectorDocument( + title=payload.name, + source_markdown=document_string, + unique_id=_vault_path_unique_id(payload.vault_id, payload.path), + document_type=DocumentType.OBSIDIAN_CONNECTOR, + search_space_id=search_space_id, + connector_id=connector.id, + created_by_id=str(user_id), + should_summarize=connector.enable_summary, + fallback_summary=f"Obsidian Note: {payload.name}\n\n{payload.content}", + metadata=metadata, + ) + + pipeline = IndexingPipelineService(session) + prepared = await pipeline.prepare_for_indexing([connector_doc]) + if not prepared: + if existing is not None: + return existing + raise RuntimeError( + f"Indexing pipeline rejected obsidian note {payload.path}" + ) + + document = prepared[0] + + llm = await get_user_long_context_llm(session, str(user_id), search_space_id) + return await pipeline.index(document, connector_doc, llm) + + +async def rename_note( + session: AsyncSession, + *, + connector: SearchSourceConnector, + old_path: str, + new_path: str, + vault_id: str, +) -> Document | None: + """Rewrite path-derived columns without re-indexing content. + + Returns the updated document, or ``None`` if no row matched the + ``old_path`` (this happens when the plugin is renaming a file that was + never synced — safe to ignore, the next ``sync`` will create it under + the new path). + """ + vault_name: str = (connector.config or {}).get("vault_name") or "Vault" + search_space_id = connector.search_space_id + + existing = await _find_existing_document( + session, + search_space_id=search_space_id, + vault_id=vault_id, + path=old_path, + ) + if existing is None: + return None + + new_unique_id = _vault_path_unique_id(vault_id, new_path) + new_uid_hash = generate_unique_identifier_hash( + DocumentType.OBSIDIAN_CONNECTOR, + new_unique_id, + search_space_id, + ) + + collision = await session.execute( + select(Document).where( + and_( + Document.unique_identifier_hash == new_uid_hash, + Document.id != existing.id, + ) + ) + ) + collision_row = collision.scalars().first() + if collision_row is not None: + logger.warning( + "obsidian rename target already exists " + "(vault=%s old=%s new=%s); skipping rename so the next /sync " + "can resolve the conflict via content_hash", + vault_id, + old_path, + new_path, + ) + return existing + + new_filename = new_path.rsplit("/", 1)[-1] + new_stem = new_filename.rsplit(".", 1)[0] if "." in new_filename else new_filename + + existing.unique_identifier_hash = new_uid_hash + existing.title = new_stem + + meta = dict(existing.document_metadata or {}) + meta["file_path"] = new_path + meta["file_name"] = new_stem + meta["url"] = _build_source_url(vault_name, new_path) + existing.document_metadata = meta + existing.updated_at = datetime.now(UTC) + + await session.commit() + return existing + + +async def delete_note( + session: AsyncSession, + *, + connector: SearchSourceConnector, + vault_id: str, + path: str, +) -> bool: + """Soft-delete via tombstone in ``document_metadata``. + + The row is *not* removed and chunks are *not* dropped, so existing + citations in chat threads remain resolvable. The manifest endpoint + filters tombstoned rows out, so the plugin's reconcile pass will not + see this path and won't try to "resurrect" a note the user deleted in + the SurfSense UI. + + Returns True if a row was tombstoned, False if no matching row existed. + """ + existing = await _find_existing_document( + session, + search_space_id=connector.search_space_id, + vault_id=vault_id, + path=path, + ) + if existing is None: + return False + + meta = dict(existing.document_metadata or {}) + if meta.get("deleted_at"): + return True + + meta["deleted_at"] = datetime.now(UTC).isoformat() + meta["deleted_by_source"] = "plugin" + existing.document_metadata = meta + existing.updated_at = datetime.now(UTC) + + await session.commit() + return True + + +async def get_manifest( + session: AsyncSession, + *, + connector: SearchSourceConnector, + vault_id: str, +) -> ManifestResponse: + """Return ``{path: {hash, mtime}}`` for every non-deleted note in this + vault. + + The plugin compares this against its local vault on every ``onload`` to + catch up edits made while offline. Rows missing ``plugin_content_hash`` + (e.g. tombstoned, or somehow indexed without going through this + service) are excluded so the plugin doesn't get confused by partial + data. + """ + result = await session.execute( + select(Document).where( + and_( + Document.search_space_id == connector.search_space_id, + Document.connector_id == connector.id, + Document.document_type == DocumentType.OBSIDIAN_CONNECTOR, + ) + ) + ) + + items: dict[str, ManifestEntry] = {} + for doc in result.scalars().all(): + meta = doc.document_metadata or {} + if meta.get("deleted_at"): + continue + if meta.get("vault_id") != vault_id: + continue + path = meta.get("file_path") + plugin_hash = meta.get("plugin_content_hash") + mtime_raw = meta.get("mtime") + if not path or not plugin_hash or not mtime_raw: + continue + try: + mtime = datetime.fromisoformat(mtime_raw) + except ValueError: + continue + items[path] = ManifestEntry(hash=plugin_hash, mtime=mtime) + + return ManifestResponse(vault_id=vault_id, items=items) From ee2fb79e75490013d85aa0c04f4fdf7b791e6079 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 20 Apr 2026 04:03:45 +0530 Subject: [PATCH 005/299] feat: update Obsidian connector to support plugin-based syncing and improve UI components --- .../components/obsidian-connect-form.tsx | 466 +++++++----------- .../connect-forms/connector-benefits.ts | 10 +- .../components/obsidian-config.tsx | 346 ++++++++----- .../constants/connector-constants.ts | 2 +- 4 files changed, 415 insertions(+), 409 deletions(-) diff --git a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx index 08c1dd30c..b4bd76e8f 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx @@ -1,314 +1,212 @@ "use client"; -import { zodResolver } from "@hookform/resolvers/zod"; -import { Info } from "lucide-react"; -import type { FC } from "react"; -import { useRef, useState } from "react"; -import { useForm } from "react-hook-form"; -import * as z from "zod"; +import { Check, Copy, Download, Info, KeyRound, Settings2 } from "lucide-react"; +import { type FC, useCallback, useRef, useState } from "react"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; -import { - Form, - FormControl, - FormDescription, - FormField, - FormItem, - FormLabel, - FormMessage, -} from "@/components/ui/form"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Switch } from "@/components/ui/switch"; +import { Button } from "@/components/ui/button"; +import { useApiKey } from "@/hooks/use-api-key"; +import { copyToClipboard as copyToClipboardUtil } from "@/lib/utils"; import { EnumConnectorName } from "@/contracts/enums/connector"; import { getConnectorBenefits } from "../connector-benefits"; import type { ConnectFormProps } from "../index"; -const obsidianConnectorFormSchema = z.object({ - name: z.string().min(3, { - message: "Connector name must be at least 3 characters.", - }), - vault_path: z.string().min(1, { - message: "Vault path is required.", - }), - vault_name: z.string().min(1, { - message: "Vault name is required.", - }), - exclude_folders: z.string().optional(), - include_attachments: z.boolean(), -}); +const PLUGIN_RELEASES_URL = + "https://github.com/MODSetter/SurfSense/releases?q=obsidian&expanded=true"; -type ObsidianConnectorFormValues = z.infer; +const BACKEND_URL = + process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL ?? "https://api.surfsense.com"; -export const ObsidianConnectForm: FC = ({ onSubmit, isSubmitting }) => { - const isSubmittingRef = useRef(false); - const [periodicEnabled, setPeriodicEnabled] = useState(true); - const [frequencyMinutes, setFrequencyMinutes] = useState("60"); - const form = useForm({ - resolver: zodResolver(obsidianConnectorFormSchema), - defaultValues: { - name: "Obsidian Vault", - vault_path: "", - vault_name: "", - exclude_folders: ".obsidian,.trash", - include_attachments: false, - }, - }); +/** + * Obsidian connect form for the plugin-only architecture. + * + * The legacy `vault_path` form was removed because it only worked on + * self-hosted with a server-side bind mount and broke for everyone else. + * The plugin pushes data over HTTPS so this UI is purely instructional — + * there is no backend create call here. The connector row is created + * server-side the first time the plugin calls `POST /obsidian/connect`. + * + * The footer "Connect" button in `ConnectorConnectView` triggers this + * form's submit; we just close the dialog (`onBack()`) since there's + * nothing to validate or persist from this side. + */ +export const ObsidianConnectForm: FC = ({ onBack }) => { + const { apiKey, isLoading, copied, copyToClipboard } = useApiKey(); + const [copiedUrl, setCopiedUrl] = useState(false); + const urlCopyTimerRef = useRef | undefined>( + undefined + ); - const handleSubmit = async (values: ObsidianConnectorFormValues) => { - // Prevent multiple submissions - if (isSubmittingRef.current || isSubmitting) { - return; - } + const copyServerUrl = useCallback(async () => { + const ok = await copyToClipboardUtil(BACKEND_URL); + if (!ok) return; + setCopiedUrl(true); + if (urlCopyTimerRef.current) clearTimeout(urlCopyTimerRef.current); + urlCopyTimerRef.current = setTimeout(() => setCopiedUrl(false), 2000); + }, []); - isSubmittingRef.current = true; - try { - // Parse exclude_folders into an array - const excludeFolders = values.exclude_folders - ? values.exclude_folders - .split(",") - .map((f) => f.trim()) - .filter(Boolean) - : [".obsidian", ".trash"]; - - await onSubmit({ - name: values.name, - connector_type: EnumConnectorName.OBSIDIAN_CONNECTOR, - config: { - vault_path: values.vault_path, - vault_name: values.vault_name, - exclude_folders: excludeFolders, - include_attachments: values.include_attachments, - }, - is_indexable: true, - is_active: true, - last_indexed_at: null, - periodic_indexing_enabled: periodicEnabled, - indexing_frequency_minutes: periodicEnabled ? Number.parseInt(frequencyMinutes, 10) : null, - next_scheduled_at: null, - periodicEnabled, - frequencyMinutes, - }); - } finally { - isSubmittingRef.current = false; - } + const handleSubmit = (event: React.FormEvent) => { + event.preventDefault(); + onBack(); }; return (
- + {/* Form is intentionally empty so the footer Connect button is a no-op + that just closes the dialog (see component-level docstring). */} +
+ + - Self-Hosted Only + Plugin-based sync - This connector requires direct file system access and only works with self-hosted - SurfSense installations. + SurfSense now syncs Obsidian via an official plugin that runs inside + Obsidian itself. Works on desktop and mobile, in cloud and self-hosted + deployments — no server-side vault mounts required. -
- - - ( - - Connector Name - - - - - A friendly name to identify this connector. - - - - )} - /> + {/* Step 1 — Install plugin */} +
+
+
+ 1 +
+

Install the plugin

+
+

+ Grab the latest SurfSense plugin release. Once it's in the community + store, you'll also be able to install it from{" "} + Settings → Community plugins{" "} + inside Obsidian. +

+ + + +
- ( - - Vault Path - - - - - The absolute path to your Obsidian vault on the server. This must be accessible - from the SurfSense backend. - - - - )} - /> + {/* Step 2 — Copy API key */} +
+
+
+ 2 +
+

+ Copy your API key +

+ +
+

+ Paste this into the plugin's API token{" "} + setting. The token expires after 24 hours; long-lived personal access + tokens are coming in a future release. +

- ( - - Vault Name - - - - - A display name for your vault. This will be used in search results. - - - - )} - /> - - ( - - Exclude Folders - - - - - Comma-separated list of folder names to exclude from indexing. - - - - )} - /> - - ( - -
- Include Attachments - - Index attachment folders and embedded files (images, PDFs, etc.) - -
- - - -
- )} - /> - - {/* Indexing Configuration */} -
-

Indexing Configuration

- - {/* Periodic Sync Config */} -
-
-
-

Enable Periodic Sync

-

- Automatically re-index at regular intervals -

-
- -
- - {periodicEnabled && ( -
-
- - -
-
- )} -
+ {isLoading ? ( +
+ ) : apiKey ? ( +
+
+

+ {apiKey} +

- - -
+ +
+ ) : ( +

+ No API key available — try refreshing the page. +

+ )} +
+ + {/* Step 3 — Server URL */} +
+
+
+ 3 +
+

+ Point the plugin at this server +

+
+

+ Paste this URL into the plugin's Server URL{" "} + setting. We auto-detect it from your current dashboard origin. +

+
+
+

+ {BACKEND_URL} +

+
+ +
+
+ + {/* Step 4 — Pick search space */} +
+
+
+ 4 +
+

+ Pick this search space +

+ +
+

+ In the plugin's Search space{" "} + setting, choose the search space you want this vault to sync into. + The connector will appear here automatically once the plugin makes + its first sync. +

+
- {/* What you get section */} {getConnectorBenefits(EnumConnectorName.OBSIDIAN_CONNECTOR) && ( -
-

+
+

What you get with Obsidian integration:

-
    - {getConnectorBenefits(EnumConnectorName.OBSIDIAN_CONNECTOR)?.map((benefit) => ( -
  • {benefit}
  • - ))} +
      + {getConnectorBenefits(EnumConnectorName.OBSIDIAN_CONNECTOR)?.map( + (benefit) => ( +
    • {benefit}
    • + ) + )}
)} diff --git a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/connector-benefits.ts b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/connector-benefits.ts index 0dc093100..f4883fa36 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/connector-benefits.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/connector-benefits.ts @@ -104,11 +104,11 @@ export function getConnectorBenefits(connectorType: string): string[] | null { "No manual indexing required - meetings are added automatically", ], OBSIDIAN_CONNECTOR: [ - "Search through all your Obsidian notes and knowledge base", - "Access note content with YAML frontmatter metadata preserved", - "Wiki-style links ([[note]]) and #tags are indexed", - "Connect your personal knowledge base directly to your search space", - "Incremental sync - only changed files are re-indexed", + "Search through all of your Obsidian notes", + "Realtime sync as you create, edit, rename, or delete notes", + "YAML frontmatter, [[wiki links]], and #tags are preserved and indexed", + "Open any chat citation straight back in Obsidian via deep links", + "Each device is identifiable, so you can revoke a vault from one machine", "Full support for your vault's folder structure", ], }; diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx index 3da1d6e7e..acea1c51b 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx @@ -1,94 +1,58 @@ "use client"; -import type { FC } from "react"; -import { useState } from "react"; +import { AlertTriangle, Check, Copy, Download, Info } from "lucide-react"; +import { type FC, useCallback, useMemo, useRef, useState } from "react"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; -import { Switch } from "@/components/ui/switch"; +import { useApiKey } from "@/hooks/use-api-key"; +import { copyToClipboard as copyToClipboardUtil } from "@/lib/utils"; import type { ConnectorConfigProps } from "../index"; export interface ObsidianConfigProps extends ConnectorConfigProps { onNameChange?: (name: string) => void; } +const PLUGIN_RELEASES_URL = + "https://github.com/MODSetter/SurfSense/releases?q=obsidian&expanded=true"; + +function formatTimestamp(value: unknown): string { + if (typeof value !== "string" || !value) return "—"; + const d = new Date(value); + if (Number.isNaN(d.getTime())) return value; + return d.toLocaleString(); +} + +/** + * Obsidian connector config view. + * + * Renders one of two modes depending on the connector's `config`: + * + * 1. **Plugin connector** (`config.source === "plugin"`) — read-only stats + * panel showing what the plugin most recently reported. + * 2. **Legacy server-path connector** (`config.legacy === true`, set by the + * Phase 3 alembic) — migration banner plus an "Install Plugin" CTA. + * The user's existing notes stay searchable; only background sync stops. + */ export const ObsidianConfig: FC = ({ connector, - onConfigChange, onNameChange, }) => { - const [vaultPath, setVaultPath] = useState( - (connector.config?.vault_path as string) || "" - ); - const [vaultName, setVaultName] = useState( - (connector.config?.vault_name as string) || "" - ); - const [excludeFolders, setExcludeFolders] = useState(() => { - const folders = connector.config?.exclude_folders; - if (Array.isArray(folders)) { - return folders.join(", "); - } - return (folders as string) || ".obsidian, .trash"; - }); - const [includeAttachments, setIncludeAttachments] = useState( - (connector.config?.include_attachments as boolean) || false - ); const [name, setName] = useState(connector.name || ""); - - const handleVaultPathChange = (value: string) => { - setVaultPath(value); - if (onConfigChange) { - onConfigChange({ - ...connector.config, - vault_path: value, - }); - } - }; - - const handleVaultNameChange = (value: string) => { - setVaultName(value); - if (onConfigChange) { - onConfigChange({ - ...connector.config, - vault_name: value, - }); - } - }; - - const handleExcludeFoldersChange = (value: string) => { - setExcludeFolders(value); - const foldersArray = value - .split(",") - .map((f) => f.trim()) - .filter(Boolean); - if (onConfigChange) { - onConfigChange({ - ...connector.config, - exclude_folders: foldersArray, - }); - } - }; - - const handleIncludeAttachmentsChange = (value: boolean) => { - setIncludeAttachments(value); - if (onConfigChange) { - onConfigChange({ - ...connector.config, - include_attachments: value, - }); - } - }; + const config = (connector.config ?? {}) as Record; + const isLegacy = config.legacy === true; + const isPlugin = config.source === "plugin"; const handleNameChange = (value: string) => { setName(value); - if (onNameChange) { - onNameChange(value); - } + onNameChange?.(value); }; return (
- {/* Connector Name */} -
+ {/* Connector name (always editable) */} +
= ({
- {/* Configuration */} -
-
-

- Vault Configuration -

-
+ {isLegacy ? ( + + ) : isPlugin ? ( + + ) : ( + + )} +
+ ); +}; -
-
- - handleVaultPathChange(e.target.value)} - placeholder="/path/to/your/obsidian/vault" - className="border-slate-400/20 focus-visible:border-slate-400/40 font-mono" - /> -

- The absolute path to your Obsidian vault on the server. -

-
+const LegacyBanner: FC = () => { + return ( +
+ + + + This connector has been migrated + + + This Obsidian connector used the legacy server-path method, which has + been removed. To resume syncing, install the SurfSense Obsidian + plugin and connect with this account. Your existing notes remain + searchable. After the plugin re-indexes your vault, you can delete + this connector to remove older copies. + + -
- - handleVaultNameChange(e.target.value)} - placeholder="My Knowledge Base" - className="border-slate-400/20 focus-visible:border-slate-400/40" - /> -

- A display name for your vault in search results. -

-
+ + + -
- - handleExcludeFoldersChange(e.target.value)} - placeholder=".obsidian, .trash, templates" - className="border-slate-400/20 focus-visible:border-slate-400/40 font-mono" - /> -

- Comma-separated list of folder names to exclude from indexing. -

-
+ +
+ ); +}; -
-
- -

- Index attachment folders and embedded files +const PluginStats: FC<{ config: Record }> = ({ config }) => { + const stats: { label: string; value: string }[] = useMemo(() => { + const filesSynced = config.files_synced; + return [ + { label: "Vault", value: (config.vault_name as string) || "—" }, + { + label: "Plugin version", + value: (config.plugin_version as string) || "—", + }, + { + label: "Device", + value: + (config.device_label as string) || + (config.device_id as string) || + "—", + }, + { + label: "Last sync", + value: formatTimestamp(config.last_sync_at), + }, + { + label: "Files synced", + value: + typeof filesSynced === "number" ? filesSynced.toLocaleString() : "—", + }, + ]; + }, [config]); + + return ( +

+ + + Plugin connected + + Edits in Obsidian sync over HTTPS. To stop syncing, disable or + uninstall the plugin in Obsidian, or delete this connector. + + + +
+

Vault status

+
+ {stats.map((stat) => ( +
+
+ {stat.label} +
+
+ {stat.value} +
+
+ ))} +
+
+
+ ); +}; + +const UnknownConnectorState: FC = () => ( + + + Unrecognized config + + This connector has neither plugin metadata nor a legacy marker. It may + predate the migration — you can safely delete it and re-install the + SurfSense Obsidian plugin to resume syncing. + + +); + +const ApiKeyReminder: FC = () => { + const { apiKey, isLoading, copied, copyToClipboard } = useApiKey(); + const [copiedUrl, setCopiedUrl] = useState(false); + const urlCopyTimerRef = useRef | undefined>( + undefined + ); + + const backendUrl = + process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL ?? "https://api.surfsense.com"; + + const copyServerUrl = useCallback(async () => { + const ok = await copyToClipboardUtil(backendUrl); + if (!ok) return; + setCopiedUrl(true); + if (urlCopyTimerRef.current) clearTimeout(urlCopyTimerRef.current); + urlCopyTimerRef.current = setTimeout(() => setCopiedUrl(false), 2000); + }, [backendUrl]); + + return ( +
+

+ Plugin connection details +

+

+ Paste these into the plugin's settings inside Obsidian. +

+ +
+ + {isLoading ? ( +
+ ) : ( +
+
+

+ {apiKey || "No API key available"}

- +
+ )} +

+ Token expires after 24 hours; long-lived tokens are coming in a + future release. +

+
+ +
+ +
+
+

+ {backendUrl} +

+
+
diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index da6885ffe..86d214134 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -180,7 +180,7 @@ export const OTHER_CONNECTORS = [ { id: "obsidian-connector", title: "Obsidian", - description: "Index your Obsidian vault (Local folder scan on Desktop)", + description: "Sync your Obsidian vault on desktop or mobile via the SurfSense plugin", connectorType: EnumConnectorName.OBSIDIAN_CONNECTOR, }, ] as const; From 60d9e7ed8c95503ff69c53e295ea137b6156a2ca Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 20 Apr 2026 04:04:19 +0530 Subject: [PATCH 006/299] feat: introduce SurfSense plugin for Obsidian with syncing capabilities and enhanced settings management --- manifest.json | 10 + surfsense_obsidian/.github/workflows/lint.yml | 28 - surfsense_obsidian/eslint.config.mts | 21 + surfsense_obsidian/manifest.json | 15 +- surfsense_obsidian/package-lock.json | 10 +- surfsense_obsidian/package.json | 21 +- surfsense_obsidian/src/api-client.ts | 248 +++++++++ surfsense_obsidian/src/excludes.ts | 66 +++ surfsense_obsidian/src/main.ts | 257 ++++++--- surfsense_obsidian/src/payload.ts | 162 ++++++ surfsense_obsidian/src/queue.ts | 237 ++++++++ surfsense_obsidian/src/settings.ts | 330 +++++++++++- surfsense_obsidian/src/status-bar.ts | 61 +++ surfsense_obsidian/src/sync-engine.ts | 505 ++++++++++++++++++ surfsense_obsidian/src/types.ts | 145 +++++ surfsense_obsidian/styles.css | 66 ++- surfsense_obsidian/versions.json | 2 +- .../hooks/use-connector-dialog.ts | 32 +- versions.json | 3 + 19 files changed, 2044 insertions(+), 175 deletions(-) create mode 100644 manifest.json delete mode 100644 surfsense_obsidian/.github/workflows/lint.yml create mode 100644 surfsense_obsidian/src/api-client.ts create mode 100644 surfsense_obsidian/src/excludes.ts create mode 100644 surfsense_obsidian/src/payload.ts create mode 100644 surfsense_obsidian/src/queue.ts create mode 100644 surfsense_obsidian/src/status-bar.ts create mode 100644 surfsense_obsidian/src/sync-engine.ts create mode 100644 surfsense_obsidian/src/types.ts create mode 100644 versions.json diff --git a/manifest.json b/manifest.json new file mode 100644 index 000000000..f65bb8844 --- /dev/null +++ b/manifest.json @@ -0,0 +1,10 @@ +{ + "id": "surfsense", + "name": "SurfSense", + "version": "0.1.0", + "minAppVersion": "1.4.0", + "description": "Sync your Obsidian vault to SurfSense for AI-powered search across all your knowledge sources.", + "author": "SurfSense", + "authorUrl": "https://github.com/MODSetter/SurfSense", + "isDesktopOnly": false +} diff --git a/surfsense_obsidian/.github/workflows/lint.yml b/surfsense_obsidian/.github/workflows/lint.yml deleted file mode 100644 index 7748ceb77..000000000 --- a/surfsense_obsidian/.github/workflows/lint.yml +++ /dev/null @@ -1,28 +0,0 @@ -name: Node.js build - -on: - push: - branches: ["**"] - pull_request: - branches: ["**"] - -jobs: - build: - runs-on: ubuntu-latest - - strategy: - matrix: - node-version: [20.x, 22.x] - # See supported Node.js release schedule at https://nodejs.org/en/about/releases/ - - steps: - - uses: actions/checkout@v4 - - name: Use Node.js ${{ matrix.node-version }} - uses: actions/setup-node@v4 - with: - node-version: ${{ matrix.node-version }} - cache: "npm" - - run: npm ci - - run: npm run build --if-present - - run: npm run lint - diff --git a/surfsense_obsidian/eslint.config.mts b/surfsense_obsidian/eslint.config.mts index 3062c4a07..a2615ae6d 100644 --- a/surfsense_obsidian/eslint.config.mts +++ b/surfsense_obsidian/eslint.config.mts @@ -22,6 +22,27 @@ export default tseslint.config( }, }, ...obsidianmd.configs.recommended, + { + plugins: { obsidianmd }, + rules: { + "obsidianmd/ui/sentence-case": [ + "error", + { + brands: [ + "Surfsense", + "iOS", + "iPadOS", + "macOS", + "Windows", + "Android", + "Linux", + "Obsidian", + "Markdown", + ], + }, + ], + }, + }, globalIgnores([ "node_modules", "dist", diff --git a/surfsense_obsidian/manifest.json b/surfsense_obsidian/manifest.json index dfa940ed8..f65bb8844 100644 --- a/surfsense_obsidian/manifest.json +++ b/surfsense_obsidian/manifest.json @@ -1,11 +1,10 @@ { - "id": "sample-plugin", - "name": "Sample Plugin", - "version": "1.0.0", - "minAppVersion": "0.15.0", - "description": "Demonstrates some of the capabilities of the Obsidian API.", - "author": "Obsidian", - "authorUrl": "https://obsidian.md", - "fundingUrl": "https://obsidian.md/pricing", + "id": "surfsense", + "name": "SurfSense", + "version": "0.1.0", + "minAppVersion": "1.4.0", + "description": "Sync your Obsidian vault to SurfSense for AI-powered search across all your knowledge sources.", + "author": "SurfSense", + "authorUrl": "https://github.com/MODSetter/SurfSense", "isDesktopOnly": false } diff --git a/surfsense_obsidian/package-lock.json b/surfsense_obsidian/package-lock.json index d0dac397c..501ff01f9 100644 --- a/surfsense_obsidian/package-lock.json +++ b/surfsense_obsidian/package-lock.json @@ -1,13 +1,13 @@ { - "name": "obsidian-sample-plugin", - "version": "1.0.0", + "name": "surfsense-obsidian", + "version": "0.1.0", "lockfileVersion": 3, "requires": true, "packages": { "": { - "name": "obsidian-sample-plugin", - "version": "1.0.0", - "license": "0-BSD", + "name": "surfsense-obsidian", + "version": "0.1.0", + "license": "Apache-2.0", "dependencies": { "obsidian": "latest" }, diff --git a/surfsense_obsidian/package.json b/surfsense_obsidian/package.json index 17268d72a..aca91f9e3 100644 --- a/surfsense_obsidian/package.json +++ b/surfsense_obsidian/package.json @@ -1,7 +1,7 @@ { - "name": "obsidian-sample-plugin", - "version": "1.0.0", - "description": "This is a sample plugin for Obsidian (https://obsidian.md)", + "name": "surfsense-obsidian", + "version": "0.1.0", + "description": "SurfSense plugin for Obsidian: sync your vault to SurfSense for AI-powered search.", "main": "main.js", "type": "module", "scripts": { @@ -10,18 +10,23 @@ "version": "node version-bump.mjs && git add manifest.json versions.json", "lint": "eslint ." }, - "keywords": [], - "license": "0-BSD", + "keywords": [ + "obsidian", + "surfsense", + "sync", + "search" + ], + "license": "Apache-2.0", "devDependencies": { + "@eslint/js": "9.30.1", "@types/node": "^16.11.6", "esbuild": "0.25.5", "eslint-plugin-obsidianmd": "0.1.9", "globals": "14.0.0", + "jiti": "2.6.1", "tslib": "2.4.0", "typescript": "^5.8.3", - "typescript-eslint": "8.35.1", - "@eslint/js": "9.30.1", - "jiti": "2.6.1" + "typescript-eslint": "8.35.1" }, "dependencies": { "obsidian": "latest" diff --git a/surfsense_obsidian/src/api-client.ts b/surfsense_obsidian/src/api-client.ts new file mode 100644 index 000000000..d686f661f --- /dev/null +++ b/surfsense_obsidian/src/api-client.ts @@ -0,0 +1,248 @@ +import { Notice, requestUrl, type RequestUrlParam, type RequestUrlResponse } from "obsidian"; +import type { + ConnectResponse, + HealthResponse, + ManifestResponse, + NotePayload, + RenameItem, + SearchSpace, +} from "./types"; + +/** + * SurfSense backend client used by the Obsidian plugin. + * + * Mobile-safety contract (must hold for every transitive import): + * - Use Obsidian `requestUrl` only — no `fetch`, no `axios`, no + * `node:http`, no `node:https`. CORS is bypassed and mobile works. + * - No top-level `node:*` imports anywhere reachable from this file. + * - Hashing happens elsewhere via Web Crypto, not `node:crypto`. + * + * Auth + wire contract: + * - Every request carries `Authorization: Bearer ` only. No + * custom headers — the backend identifies the caller from the JWT + * and feature-detects the API via the `capabilities` array on + * `/health` and `/connect`. + * - 401 surfaces as `AuthError` so the orchestrator can show the + * "token expired, paste a fresh one" UX. + * - HealthResponse / ConnectResponse use index signatures so any + * additive backend field (e.g. new capabilities) parses without + * breaking the decoder. This mirrors `ConfigDict(extra='ignore')` + * on the server side. + */ + +export class AuthError extends Error { + constructor(message: string) { + super(message); + this.name = "AuthError"; + } +} + +export class TransientError extends Error { + readonly status: number; + constructor(status: number, message: string) { + super(message); + this.name = "TransientError"; + this.status = status; + } +} + +export class PermanentError extends Error { + readonly status: number; + constructor(status: number, message: string) { + super(message); + this.name = "PermanentError"; + this.status = status; + } +} + +export interface ApiClientOptions { + getServerUrl: () => string; + getToken: () => string; + pluginVersion: string; + onAuthError?: () => void; +} + +export class SurfSenseApiClient { + private readonly opts: ApiClientOptions; + + constructor(opts: ApiClientOptions) { + this.opts = opts; + } + + updateOptions(partial: Partial): void { + Object.assign(this.opts, partial); + } + + get pluginVersion(): string { + return this.opts.pluginVersion; + } + + async health(): Promise { + return await this.request("GET", "/api/v1/obsidian/health"); + } + + async listSearchSpaces(): Promise { + const resp = await this.request( + "GET", + "/api/v1/searchspaces/" + ); + if (Array.isArray(resp)) return resp; + if (resp && Array.isArray((resp as { items?: SearchSpace[] }).items)) { + return (resp as { items: SearchSpace[] }).items; + } + return []; + } + + async verifyToken(): Promise<{ ok: true; health: HealthResponse }> { + // /health is gated by current_active_user, so a successful response + // transitively proves the token works. Cheaper than fetching a list. + const health = await this.health(); + return { ok: true, health }; + } + + async connect(input: { + searchSpaceId: number; + vaultId: string; + vaultName: string; + deviceId: string; + deviceLabel: string; + }): Promise { + return await this.request( + "POST", + `/api/v1/obsidian/connect?search_space_id=${encodeURIComponent( + String(input.searchSpaceId) + )}`, + { + vault_id: input.vaultId, + vault_name: input.vaultName, + plugin_version: this.opts.pluginVersion, + device_id: input.deviceId, + device_label: input.deviceLabel, + } + ); + } + + async syncBatch(input: { + vaultId: string; + notes: NotePayload[]; + }): Promise<{ accepted: number; rejected: string[] }> { + const resp = await this.request<{ accepted?: number; rejected?: string[] }>( + "POST", + "/api/v1/obsidian/sync", + { vault_id: input.vaultId, notes: input.notes } + ); + return { + accepted: typeof resp.accepted === "number" ? resp.accepted : input.notes.length, + rejected: Array.isArray(resp.rejected) ? resp.rejected : [], + }; + } + + async renameBatch(input: { + vaultId: string; + renames: Pick[]; + }): Promise<{ renamed: number }> { + const resp = await this.request<{ renamed?: number }>( + "POST", + "/api/v1/obsidian/rename", + { + vault_id: input.vaultId, + renames: input.renames.map((r) => ({ + old_path: r.oldPath, + new_path: r.newPath, + })), + } + ); + return { renamed: typeof resp.renamed === "number" ? resp.renamed : 0 }; + } + + async deleteBatch(input: { + vaultId: string; + paths: string[]; + }): Promise<{ deleted: number }> { + const resp = await this.request<{ deleted?: number }>( + "DELETE", + "/api/v1/obsidian/notes", + { vault_id: input.vaultId, paths: input.paths } + ); + return { deleted: typeof resp.deleted === "number" ? resp.deleted : 0 }; + } + + async getManifest(vaultId: string): Promise { + return await this.request( + "GET", + `/api/v1/obsidian/manifest?vault_id=${encodeURIComponent(vaultId)}` + ); + } + + private async request( + method: RequestUrlParam["method"], + path: string, + body?: unknown + ): Promise { + const baseUrl = this.opts.getServerUrl().replace(/\/+$/, ""); + const token = this.opts.getToken(); + if (!token) { + throw new AuthError("Missing API token. Open SurfSense settings to paste one."); + } + const headers: Record = { + Authorization: `Bearer ${token}`, + Accept: "application/json", + }; + if (body !== undefined) headers["Content-Type"] = "application/json"; + + let resp: RequestUrlResponse; + try { + resp = await requestUrl({ + url: `${baseUrl}${path}`, + method, + headers, + body: body === undefined ? undefined : JSON.stringify(body), + throw: false, + }); + } catch (err) { + throw new TransientError(0, `Network error: ${(err as Error).message}`); + } + + if (resp.status >= 200 && resp.status < 300) { + return parseJson(resp); + } + + const detail = extractDetail(resp); + + if (resp.status === 401) { + this.opts.onAuthError?.(); + new Notice("Surfsense: token expired or invalid. Paste a fresh token in settings."); + throw new AuthError(detail || "Unauthorized"); + } + + if (resp.status >= 500 || resp.status === 429) { + throw new TransientError(resp.status, detail || `HTTP ${resp.status}`); + } + + throw new PermanentError(resp.status, detail || `HTTP ${resp.status}`); + } +} + +function parseJson(resp: RequestUrlResponse): T { + if (resp.text === undefined || resp.text === "") return undefined as unknown as T; + try { + return JSON.parse(resp.text) as T; + } catch { + return undefined as unknown as T; + } +} + +function safeJson(resp: RequestUrlResponse): Record { + try { + return resp.text ? (JSON.parse(resp.text) as Record) : {}; + } catch { + return {}; + } +} + +function extractDetail(resp: RequestUrlResponse): string { + const json = safeJson(resp); + if (typeof json.detail === "string") return json.detail; + if (typeof json.message === "string") return json.message; + return resp.text?.slice(0, 200) ?? ""; +} diff --git a/surfsense_obsidian/src/excludes.ts b/surfsense_obsidian/src/excludes.ts new file mode 100644 index 000000000..67a59bc50 --- /dev/null +++ b/surfsense_obsidian/src/excludes.ts @@ -0,0 +1,66 @@ +/** + * Tiny glob matcher for exclude patterns. + * + * Supports `*` (any chars except `/`), `**` (any chars including `/`), and + * literal segments. Patterns without a slash are matched against any path + * segment (so `templates` excludes `templates/foo.md` and `notes/templates/x.md`). + * + * Intentionally not a full minimatch — Obsidian users overwhelmingly type + * folder names ("templates", ".trash") and the obvious wildcards. Avoiding + * the dependency keeps the bundle small and the mobile attack surface tiny. + */ + +const cache = new Map(); + +function compile(pattern: string): RegExp { + const cached = cache.get(pattern); + if (cached) return cached; + + let body = ""; + let i = 0; + while (i < pattern.length) { + const ch = pattern[i] ?? ""; + if (ch === "*") { + if (pattern[i + 1] === "*") { + body += ".*"; + i += 2; + if (pattern[i] === "/") i += 1; + continue; + } + body += "[^/]*"; + i += 1; + continue; + } + if (".+^${}()|[]\\".includes(ch)) { + body += "\\" + ch; + i += 1; + continue; + } + body += ch; + i += 1; + } + + const anchored = pattern.includes("/") + ? `^${body}(/.*)?$` + : `(^|/)${body}(/.*)?$`; + const re = new RegExp(anchored); + cache.set(pattern, re); + return re; +} + +export function isExcluded(path: string, patterns: string[]): boolean { + if (!patterns.length) return false; + for (const raw of patterns) { + const trimmed = raw.trim(); + if (!trimmed || trimmed.startsWith("#")) continue; + if (compile(trimmed).test(path)) return true; + } + return false; +} + +export function parseExcludePatterns(raw: string): string[] { + return raw + .split(/\r?\n/) + .map((line) => line.trim()) + .filter((line) => line.length > 0 && !line.startsWith("#")); +} diff --git a/surfsense_obsidian/src/main.ts b/surfsense_obsidian/src/main.ts index 6fe0c83a8..34e5715a1 100644 --- a/surfsense_obsidian/src/main.ts +++ b/surfsense_obsidian/src/main.ts @@ -1,99 +1,216 @@ -import {App, Editor, MarkdownView, Modal, Notice, Plugin} from 'obsidian'; -import {DEFAULT_SETTINGS, MyPluginSettings, SampleSettingTab} from "./settings"; +import { Notice, Plugin } from "obsidian"; +import { SurfSenseApiClient } from "./api-client"; +import { PersistentQueue } from "./queue"; +import { SurfSenseSettingTab } from "./settings"; +import { StatusBar } from "./status-bar"; +import { SyncEngine } from "./sync-engine"; +import { + DEFAULT_SETTINGS, + type QueueItem, + type StatusState, + type SurfsensePluginSettings, +} from "./types"; -// Remember to rename these classes and interfaces! - -export default class MyPlugin extends Plugin { - settings: MyPluginSettings; +/** + * SurfSense plugin entry point. + * + * Replaces the obsidian-sample-plugin SampleModal/ribbon stub. Lifecycle: + * + * onload(): + * load settings → seed identity (vault_id, device_id) → + * wire api client + queue + sync engine + status bar → + * register settings tab → register vault + metadataCache events → + * register commands (resync, sync current note, open settings) → + * register status bar item → + * kick off engine.start() (health → drain → reconcile). + * + * onunload(): + * stop the queue's debounce timer; unregistered events and DOM + * handles auto-clean via the Plugin base class. + */ +export default class SurfSensePlugin extends Plugin { + settings!: SurfsensePluginSettings; + api!: SurfSenseApiClient; + queue!: PersistentQueue; + engine!: SyncEngine; + private statusBar: StatusBar | null = null; + lastStatus: StatusState = { kind: "idle", queueDepth: 0 }; + serverCapabilities: string[] = []; + serverApiVersion: string | null = null; + private settingTab: SurfSenseSettingTab | null = null; async onload() { await this.loadSettings(); + this.seedIdentity(); + await this.saveSettings(); - // This creates an icon in the left ribbon. - this.addRibbonIcon('dice', 'Sample', (evt: MouseEvent) => { - // Called when the user clicks the icon. - new Notice('This is a notice!'); + const pluginVersion = this.manifest.version; + + this.api = new SurfSenseApiClient({ + getServerUrl: () => this.settings.serverUrl, + getToken: () => this.settings.apiToken, + pluginVersion, }); - // This adds a status bar item to the bottom of the app. Does not work on mobile apps. - const statusBarItemEl = this.addStatusBarItem(); - statusBarItemEl.setText('Status bar text'); - - // This adds a simple command that can be triggered anywhere - this.addCommand({ - id: 'open-modal-simple', - name: 'Open modal (simple)', - callback: () => { - new SampleModal(this.app).open(); - } + this.queue = new PersistentQueue(this.settings.queue ?? [], { + persist: async (items) => { + this.settings.queue = items; + await this.saveData(this.settings); + }, }); - // This adds an editor command that can perform some operation on the current editor instance - this.addCommand({ - id: 'replace-selected', - name: 'Replace selected content', - editorCallback: (editor: Editor, view: MarkdownView) => { - editor.replaceSelection('Sample editor command'); - } - }); - // This adds a complex command that can check whether the current state of the app allows execution of the command - this.addCommand({ - id: 'open-modal-complex', - name: 'Open modal (complex)', - checkCallback: (checking: boolean) => { - // Conditions to check - const markdownView = this.app.workspace.getActiveViewOfType(MarkdownView); - if (markdownView) { - // If checking is true, we're simply "checking" if the command can be run. - // If checking is false, then we want to actually perform the operation. - if (!checking) { - new SampleModal(this.app).open(); - } - // This command will only show up in Command Palette when the check function returns true - return true; + this.engine = new SyncEngine({ + app: this.app, + apiClient: this.api, + queue: this.queue, + getSettings: () => this.settings, + saveSettings: async (mut) => { + mut(this.settings); + await this.saveSettings(); + this.settingTab?.renderStatus(); + }, + setStatus: (s) => { + this.lastStatus = s; + this.statusBar?.update(s); + this.settingTab?.renderStatus(); + }, + onCapabilities: (caps, apiVersion) => { + this.serverCapabilities = [...caps]; + this.serverApiVersion = apiVersion; + this.settingTab?.renderStatus(); + }, + }); + + this.queue.setFlushHandler(() => { + if (this.settings.syncMode !== "auto") return; + void this.engine.flushQueue(); + }); + + this.settingTab = new SurfSenseSettingTab(this.app, this); + this.addSettingTab(this.settingTab); + + const statusHost = this.addStatusBarItem(); + this.statusBar = new StatusBar(statusHost); + this.statusBar.update(this.lastStatus); + + this.registerEvent( + this.app.vault.on("create", (file) => this.engine.onCreate(file)), + ); + this.registerEvent( + this.app.vault.on("modify", (file) => this.engine.onModify(file)), + ); + this.registerEvent( + this.app.vault.on("delete", (file) => this.engine.onDelete(file)), + ); + this.registerEvent( + this.app.vault.on("rename", (file, oldPath) => + this.engine.onRename(file, oldPath), + ), + ); + this.registerEvent( + this.app.metadataCache.on("changed", (file, data, cache) => + this.engine.onMetadataChanged(file, data, cache), + ), + ); + + this.addCommand({ + id: "resync-vault", + name: "Re-sync entire vault", + callback: async () => { + try { + await this.engine.maybeReconcile(true); + new Notice("Surfsense: re-sync started."); + } catch (err) { + new Notice(`Surfsense: re-sync failed — ${(err as Error).message}`); } - return false; - } + }, }); - // This adds a settings tab so the user can configure various aspects of the plugin - this.addSettingTab(new SampleSettingTab(this.app, this)); - - // If the plugin hooks up any global DOM events (on parts of the app that doesn't belong to this plugin) - // Using this function will automatically remove the event listener when this plugin is disabled. - this.registerDomEvent(document, 'click', (evt: MouseEvent) => { - new Notice("Click"); + this.addCommand({ + id: "sync-current-note", + name: "Sync current note", + checkCallback: (checking) => { + const file = this.app.workspace.getActiveFile(); + if (!file || file.extension.toLowerCase() !== "md") return false; + if (checking) return true; + this.queue.enqueueUpsert(file.path); + void this.engine.flushQueue(); + return true; + }, }); - // When registering intervals, this function will automatically clear the interval when the plugin is disabled. - this.registerInterval(window.setInterval(() => console.log('setInterval'), 5 * 60 * 1000)); + this.addCommand({ + id: "open-settings", + name: "Open settings", + callback: () => { + // Obsidian exposes this through the Setting host on the workspace; + // fall back silently if the API moves so we never throw. + type SettingHost = { + open?: () => void; + openTabById?: (id: string) => void; + }; + const setting = (this.app as unknown as { setting?: SettingHost }).setting; + if (setting?.open) setting.open(); + if (setting?.openTabById) setting.openTabById(this.manifest.id); + }, + }); + // Kick off the start sequence after Obsidian finishes its own + // startup work, so the metadataCache is warm before reconcile. + this.app.workspace.onLayoutReady(() => { + void this.engine.start(); + }); } onunload() { + this.queue?.cancelFlush(); + this.queue?.requestStop(); + } + + get queueDepth(): number { + return this.queue?.size ?? 0; } async loadSettings() { - this.settings = Object.assign({}, DEFAULT_SETTINGS, await this.loadData() as Partial); + const data = (await this.loadData()) as Partial | null; + this.settings = { + ...DEFAULT_SETTINGS, + ...(data ?? {}), + queue: (data?.queue ?? []).map((i: QueueItem) => ({ ...i })), + tombstones: { ...(data?.tombstones ?? {}) }, + excludePatterns: data?.excludePatterns?.length + ? [...data.excludePatterns] + : [...DEFAULT_SETTINGS.excludePatterns], + }; } async saveSettings() { await this.saveData(this.settings); } -} -class SampleModal extends Modal { - constructor(app: App) { - super(app); - } - - onOpen() { - let {contentEl} = this; - contentEl.setText('Woah!'); - } - - onClose() { - const {contentEl} = this; - contentEl.empty(); + private seedIdentity(): void { + if (!this.settings.vaultId) { + this.settings.vaultId = generateUuid(); + } + if (!this.settings.deviceId) { + this.settings.deviceId = generateUuid(); + } + if (!this.settings.vaultName) { + this.settings.vaultName = this.app.vault.getName(); + } } } + +function generateUuid(): string { + const c = globalThis.crypto; + if (c?.randomUUID) return c.randomUUID(); + const buf = new Uint8Array(16); + c.getRandomValues(buf); + buf[6] = ((buf[6] ?? 0) & 0x0f) | 0x40; + buf[8] = ((buf[8] ?? 0) & 0x3f) | 0x80; + const hex = Array.from(buf, (b) => b.toString(16).padStart(2, "0")).join(""); + return `${hex.slice(0, 8)}-${hex.slice(8, 12)}-${hex.slice(12, 16)}-${hex.slice( + 16, + 20, + )}-${hex.slice(20)}`; +} diff --git a/surfsense_obsidian/src/payload.ts b/surfsense_obsidian/src/payload.ts new file mode 100644 index 000000000..86b889f89 --- /dev/null +++ b/surfsense_obsidian/src/payload.ts @@ -0,0 +1,162 @@ +import { + type App, + type CachedMetadata, + type FrontMatterCache, + type HeadingCache, + type ReferenceCache, + type TFile, +} from "obsidian"; +import type { HeadingRef, NotePayload } from "./types"; + +/** + * Build a NotePayload from an Obsidian TFile. + * + * Mobile-safety contract: + * - No top-level `node:fs` / `node:path` / `node:crypto` imports. + * File IO uses `vault.cachedRead` (works on the mobile WASM adapter). + * Hashing uses Web Crypto `subtle.digest`. + * - Caller MUST first wait for `metadataCache.changed` before calling + * this for a `.md` file, otherwise `frontmatter`/`tags`/`headings` + * can lag the actual file contents. + */ +export async function buildNotePayload( + app: App, + file: TFile, + vaultId: string, +): Promise { + const content = await app.vault.cachedRead(file); + const cache: CachedMetadata | null = app.metadataCache.getFileCache(file); + + const frontmatter = normalizeFrontmatter(cache?.frontmatter); + const tags = collectTags(cache); + const headings = collectHeadings(cache?.headings ?? []); + const aliases = collectAliases(frontmatter); + const { embeds, internalLinks } = collectLinks(cache); + const { resolved, unresolved } = resolveLinkTargets( + app, + file.path, + internalLinks, + ); + const contentHash = await computeContentHash(content); + + return { + vault_id: vaultId, + path: file.path, + name: file.basename, + extension: file.extension, + content, + frontmatter, + tags, + headings, + resolved_links: resolved, + unresolved_links: unresolved, + embeds, + aliases, + content_hash: contentHash, + mtime: file.stat.mtime, + ctime: file.stat.ctime, + }; +} + +export async function computeContentHash(content: string): Promise { + const bytes = new TextEncoder().encode(content); + const digest = await crypto.subtle.digest("SHA-256", bytes); + return bufferToHex(digest); +} + +function bufferToHex(buf: ArrayBuffer): string { + const view = new Uint8Array(buf); + let hex = ""; + for (let i = 0; i < view.length; i++) { + hex += (view[i] ?? 0).toString(16).padStart(2, "0"); + } + return hex; +} + +function normalizeFrontmatter( + fm: FrontMatterCache | undefined, +): Record { + if (!fm) return {}; + // FrontMatterCache extends a plain object; strip the `position` key + // the cache adds so the wire payload stays clean. + const rest: Record = { ...(fm as Record) }; + delete rest.position; + return rest; +} + +function collectTags(cache: CachedMetadata | null): string[] { + const out = new Set(); + for (const t of cache?.tags ?? []) { + const tag = t.tag.startsWith("#") ? t.tag.slice(1) : t.tag; + if (tag) out.add(tag); + } + const fmTags: unknown = + cache?.frontmatter?.tags ?? cache?.frontmatter?.tag; + if (Array.isArray(fmTags)) { + for (const t of fmTags) { + if (typeof t === "string" && t) out.add(t.replace(/^#/, "")); + } + } else if (typeof fmTags === "string" && fmTags) { + for (const t of fmTags.split(/[\s,]+/)) { + if (t) out.add(t.replace(/^#/, "")); + } + } + return [...out]; +} + +function collectHeadings(items: HeadingCache[]): HeadingRef[] { + return items.map((h) => ({ heading: h.heading, level: h.level })); +} + +function collectAliases(frontmatter: Record): string[] { + const raw = frontmatter.aliases ?? frontmatter.alias; + if (Array.isArray(raw)) { + return raw.filter((x): x is string => typeof x === "string" && x.length > 0); + } + if (typeof raw === "string" && raw) return [raw]; + return []; +} + +function collectLinks(cache: CachedMetadata | null): { + embeds: string[]; + internalLinks: ReferenceCache[]; +} { + const linkRefs: ReferenceCache[] = [ + ...((cache?.links) ?? []), + ...((cache?.embeds as ReferenceCache[] | undefined) ?? []), + ]; + const embeds = ((cache?.embeds as ReferenceCache[] | undefined) ?? []).map( + (e) => e.link, + ); + return { embeds, internalLinks: linkRefs }; +} + +function resolveLinkTargets( + app: App, + sourcePath: string, + links: ReferenceCache[], +): { resolved: string[]; unresolved: string[] } { + const resolved = new Set(); + const unresolved = new Set(); + for (const link of links) { + const target = app.metadataCache.getFirstLinkpathDest( + stripSubpath(link.link), + sourcePath, + ); + if (target) { + resolved.add(target.path); + } else { + unresolved.add(link.link); + } + } + return { resolved: [...resolved], unresolved: [...unresolved] }; +} + +function stripSubpath(link: string): string { + const hashIdx = link.indexOf("#"); + const pipeIdx = link.indexOf("|"); + let end = link.length; + if (hashIdx !== -1) end = Math.min(end, hashIdx); + if (pipeIdx !== -1) end = Math.min(end, pipeIdx); + return link.slice(0, end); +} diff --git a/surfsense_obsidian/src/queue.ts b/surfsense_obsidian/src/queue.ts new file mode 100644 index 000000000..9636da81c --- /dev/null +++ b/surfsense_obsidian/src/queue.ts @@ -0,0 +1,237 @@ +import type { QueueItem } from "./types"; + +/** + * Persistent upload queue. + * + * Mobile-safety contract: + * - Persistence is delegated to a save callback (which the plugin wires + * to `plugin.saveData()`); never `node:fs`. Items also live in the + * plugin's settings JSON so a crash mid-flight loses nothing. + * - No top-level `node:*` imports. + * + * Behavioural contract: + * - Per-file debounce: enqueueing the same path coalesces, the latest + * `enqueuedAt` wins so we don't ship a stale snapshot. + * - `delete` for a path drops any pending `upsert` for that path + * (otherwise we'd resurrect a note the user just deleted). + * - `rename` is a first-class op so the backend can update + * `unique_identifier_hash` instead of "delete + create" (which would + * blow away document versions, citations, and the document_id used + * in chat history). + * - Drain takes a worker, returns once the worker either succeeds for + * every batch or hits a stop signal (transient error, mid-drain + * stop request). + */ + +export interface QueueWorker { + processBatch(batch: QueueItem[]): Promise; +} + +export interface BatchResult { + /** Items that succeeded; they will be ack'd off the queue. */ + acked: QueueItem[]; + /** Items that should be retried; their `attempt` is bumped. */ + retry: QueueItem[]; + /** Items that failed permanently (4xx). They get dropped. */ + dropped: QueueItem[]; + /** If true, the drain loop stops (e.g. transient/network error). */ + stop: boolean; + /** Optional retry-after for transient errors (ms). */ + backoffMs?: number; +} + +export interface PersistentQueueOptions { + debounceMs?: number; + batchSize?: number; + maxAttempts?: number; + persist: (items: QueueItem[]) => Promise | void; + now?: () => number; +} + +const DEFAULTS = { + debounceMs: 2000, + batchSize: 15, + maxAttempts: 8, +}; + +export class PersistentQueue { + private items: QueueItem[]; + private readonly opts: Required< + Omit + > & { + persist: PersistentQueueOptions["persist"]; + now: () => number; + }; + private draining = false; + private stopRequested = false; + private flushTimer: ReturnType | null = null; + private onFlush: (() => void) | null = null; + + constructor(initial: QueueItem[], opts: PersistentQueueOptions) { + this.items = [...initial]; + this.opts = { + debounceMs: opts.debounceMs ?? DEFAULTS.debounceMs, + batchSize: opts.batchSize ?? DEFAULTS.batchSize, + maxAttempts: opts.maxAttempts ?? DEFAULTS.maxAttempts, + persist: opts.persist, + now: opts.now ?? (() => Date.now()), + }; + } + + get size(): number { + return this.items.length; + } + + snapshot(): QueueItem[] { + return this.items.map((i) => ({ ...i })); + } + + setFlushHandler(handler: () => void): void { + this.onFlush = handler; + } + + enqueueUpsert(path: string): void { + const now = this.opts.now(); + this.items = this.items.filter( + (i) => !(i.op === "upsert" && i.path === path), + ); + this.items.push({ op: "upsert", path, enqueuedAt: now, attempt: 0 }); + void this.persist(); + this.scheduleFlush(); + } + + enqueueDelete(path: string): void { + const now = this.opts.now(); + // A delete supersedes any pending upsert for the same path. + this.items = this.items.filter( + (i) => + !( + (i.op === "upsert" && i.path === path) || + (i.op === "delete" && i.path === path) + ), + ); + this.items.push({ op: "delete", path, enqueuedAt: now, attempt: 0 }); + void this.persist(); + this.scheduleFlush(); + } + + enqueueRename(oldPath: string, newPath: string): void { + const now = this.opts.now(); + this.items = this.items.filter( + (i) => + !( + (i.op === "upsert" && (i.path === oldPath || i.path === newPath)) || + (i.op === "rename" && i.oldPath === oldPath && i.newPath === newPath) + ), + ); + this.items.push({ + op: "rename", + oldPath, + newPath, + enqueuedAt: now, + attempt: 0, + }); + // Also enqueue an upsert of the new path so its content/metadata + // reflects whatever the editor flushed alongside the rename. + this.items.push({ op: "upsert", path: newPath, enqueuedAt: now, attempt: 0 }); + void this.persist(); + this.scheduleFlush(); + } + + requestStop(): void { + this.stopRequested = true; + } + + cancelFlush(): void { + if (this.flushTimer !== null) { + clearTimeout(this.flushTimer); + this.flushTimer = null; + } + } + + private scheduleFlush(): void { + if (!this.onFlush) return; + if (this.flushTimer !== null) clearTimeout(this.flushTimer); + this.flushTimer = setTimeout(() => { + this.flushTimer = null; + this.onFlush?.(); + }, this.opts.debounceMs); + } + + async drain(worker: QueueWorker): Promise { + if (this.draining) return { batches: 0, acked: 0, dropped: 0, stopped: false }; + this.draining = true; + this.stopRequested = false; + const summary: DrainSummary = { + batches: 0, + acked: 0, + dropped: 0, + stopped: false, + }; + try { + while (this.items.length > 0 && !this.stopRequested) { + const batch = this.takeBatch(); + summary.batches += 1; + + const result = await worker.processBatch(batch); + summary.acked += result.acked.length; + summary.dropped += result.dropped.length; + + const ackKeys = new Set(result.acked.map(itemKey)); + const dropKeys = new Set(result.dropped.map(itemKey)); + const retryKeys = new Set(result.retry.map(itemKey)); + + // Keep any item we didn't explicitly account for in `retry` + // so a partial-batch drop never silently loses work. + const unhandled = batch.filter( + (b) => + !ackKeys.has(itemKey(b)) && + !dropKeys.has(itemKey(b)) && + !retryKeys.has(itemKey(b)), + ); + const retry = [...result.retry, ...unhandled].map((i) => ({ + ...i, + attempt: i.attempt + 1, + })); + const survivors = retry.filter((i) => i.attempt <= this.opts.maxAttempts); + summary.dropped += retry.length - survivors.length; + + this.items = [...survivors, ...this.items]; + await this.persist(); + + if (result.stop) { + summary.stopped = true; + if (result.backoffMs) summary.backoffMs = result.backoffMs; + break; + } + } + if (this.stopRequested) summary.stopped = true; + return summary; + } finally { + this.draining = false; + } + } + + private takeBatch(): QueueItem[] { + const head = this.items.slice(0, this.opts.batchSize); + this.items = this.items.slice(this.opts.batchSize); + return head; + } + + private async persist(): Promise { + await this.opts.persist(this.snapshot()); + } +} + +export interface DrainSummary { + batches: number; + acked: number; + dropped: number; + stopped: boolean; + backoffMs?: number; +} + +export function itemKey(i: QueueItem): string { + if (i.op === "rename") return `rename:${i.oldPath}=>${i.newPath}`; + return `${i.op}:${i.path}`; +} diff --git a/surfsense_obsidian/src/settings.ts b/surfsense_obsidian/src/settings.ts index 352121e07..d22b66384 100644 --- a/surfsense_obsidian/src/settings.ts +++ b/surfsense_obsidian/src/settings.ts @@ -1,36 +1,322 @@ -import {App, PluginSettingTab, Setting} from "obsidian"; -import MyPlugin from "./main"; +import { + type App, + Notice, + PluginSettingTab, + Setting, +} from "obsidian"; +import { AuthError } from "./api-client"; +import { parseExcludePatterns } from "./excludes"; +import type SurfSensePlugin from "./main"; +import type { SearchSpace } from "./types"; -export interface MyPluginSettings { - mySetting: string; -} +/** + * Plugin settings tab. + * + * Replaces the obsidian-sample-plugin SampleSettingTab stub. Same module + * path so existing imports from main.ts keep resolving. + * + * Surface mirrors the per-plan list: + * server URL · api token · search space · vault name · sync mode · + * exclude patterns · include attachments · status panel. + * + * Vault id, device id, and device label are auto-generated UUIDs the + * first time settings load — they're displayed (read-only) so users can + * audit them, but never editable. Vault id is decoupled from the OS + * folder name so renaming the vault doesn't invalidate the connector + * (edge case #5 from the plan). + */ -export const DEFAULT_SETTINGS: MyPluginSettings = { - mySetting: 'default' -} +export class SurfSenseSettingTab extends PluginSettingTab { + private readonly plugin: SurfSensePlugin; + private searchSpaces: SearchSpace[] = []; + private loadingSpaces = false; + private statusEl: HTMLElement | null = null; -export class SampleSettingTab extends PluginSettingTab { - plugin: MyPlugin; - - constructor(app: App, plugin: MyPlugin) { + constructor(app: App, plugin: SurfSensePlugin) { super(app, plugin); this.plugin = plugin; } display(): void { - const {containerEl} = this; - + const { containerEl } = this; containerEl.empty(); + containerEl.addClass("surfsense-settings"); + + const settings = this.plugin.settings; + + new Setting(containerEl).setName("Connection").setHeading(); new Setting(containerEl) - .setName('Settings #1') - .setDesc('It\'s a secret') - .addText(text => text - .setPlaceholder('Enter your secret') - .setValue(this.plugin.settings.mySetting) - .onChange(async (value) => { - this.plugin.settings.mySetting = value; + .setName("Server URL") + .setDesc( + "https://api.surfsense.com for SurfSense Cloud, or your self-hosted URL.", + ) + .addText((text) => + text + .setPlaceholder("https://api.surfsense.com") + .setValue(settings.serverUrl) + .onChange(async (value) => { + this.plugin.settings.serverUrl = value.trim(); + await this.plugin.saveSettings(); + }), + ); + + new Setting(containerEl) + .setName("API token") + .setDesc( + "Paste your Surfsense API token (expires after 24 hours; re-paste when you see an auth error).", + ) + .addText((text) => { + text.inputEl.type = "password"; + text.inputEl.autocomplete = "off"; + text.inputEl.spellcheck = false; + text + .setPlaceholder("Paste token") + .setValue(settings.apiToken) + .onChange(async (value) => { + this.plugin.settings.apiToken = value.trim(); + await this.plugin.saveSettings(); + }); + }) + .addButton((btn) => + btn + .setButtonText("Verify") + .setCta() + .onClick(async () => { + btn.setDisabled(true); + try { + await this.plugin.api.verifyToken(); + new Notice("Surfsense: token verified."); + await this.refreshSearchSpaces(); + this.display(); + } catch (err) { + this.handleApiError(err); + } finally { + btn.setDisabled(false); + } + }), + ); + + new Setting(containerEl) + .setName("Search space") + .setDesc( + "Which Surfsense search space this vault syncs into. Reload after changing your token.", + ) + .addDropdown((drop) => { + drop.addOption("", this.loadingSpaces ? "Loading…" : "Select a search space"); + for (const space of this.searchSpaces) { + drop.addOption(String(space.id), space.name); + } + if (settings.searchSpaceId !== null) { + drop.setValue(String(settings.searchSpaceId)); + } + drop.onChange(async (value) => { + this.plugin.settings.searchSpaceId = value ? Number(value) : null; + this.plugin.settings.connectorId = null; await this.plugin.saveSettings(); - })); + if (this.plugin.settings.searchSpaceId !== null) { + try { + await this.plugin.engine.ensureConnected(); + new Notice("Surfsense: vault connected."); + } catch (err) { + this.handleApiError(err); + } + } + this.renderStatus(); + }); + }) + .addExtraButton((btn) => + btn + .setIcon("refresh-ccw") + .setTooltip("Reload search spaces") + .onClick(async () => { + await this.refreshSearchSpaces(); + this.display(); + }), + ); + + new Setting(containerEl).setName("Vault").setHeading(); + + new Setting(containerEl) + .setName("Vault name") + .setDesc( + "Friendly name for this vault. Defaults to your Obsidian vault folder name.", + ) + .addText((text) => + text + .setValue(settings.vaultName) + .onChange(async (value) => { + this.plugin.settings.vaultName = value.trim() || this.app.vault.getName(); + await this.plugin.saveSettings(); + }), + ); + + new Setting(containerEl) + .setName("Device label") + .setDesc( + "Optional human-readable label shown next to the device ID in the Surfsense web app.", + ) + .addText((text) => + text + .setPlaceholder("My laptop") + .setValue(settings.deviceLabel) + .onChange(async (value) => { + this.plugin.settings.deviceLabel = value.trim(); + await this.plugin.saveSettings(); + }), + ); + + new Setting(containerEl) + .setName("Sync mode") + .setDesc("Auto syncs on every edit. Manual only syncs when you trigger it via the command palette.") + .addDropdown((drop) => + drop + .addOption("auto", "Auto") + .addOption("manual", "Manual") + .setValue(settings.syncMode) + .onChange(async (value) => { + this.plugin.settings.syncMode = value === "manual" ? "manual" : "auto"; + await this.plugin.saveSettings(); + }), + ); + + new Setting(containerEl) + .setName("Exclude patterns") + .setDesc( + "One pattern per line. Supports * and **. Lines starting with # are comments. Files matching any pattern are skipped.", + ) + .addTextArea((area) => { + area.inputEl.rows = 4; + area + .setPlaceholder(".trash\n_attachments\ntemplates/**") + .setValue(settings.excludePatterns.join("\n")) + .onChange(async (value) => { + this.plugin.settings.excludePatterns = parseExcludePatterns(value); + await this.plugin.saveSettings(); + }); + }); + + new Setting(containerEl) + .setName("Include attachments") + .setDesc( + "Sync non-Markdown files (images, PDFs, …). Off by default — Markdown only.", + ) + .addToggle((toggle) => + toggle + .setValue(settings.includeAttachments) + .onChange(async (value) => { + this.plugin.settings.includeAttachments = value; + await this.plugin.saveSettings(); + }), + ); + + new Setting(containerEl).setName("Identity").setHeading(); + + new Setting(containerEl) + .setName("Vault ID") + .setDesc("Stable identifier for this vault. Used by the backend to keep separate vaults distinct even if their folder names change.") + .addText((text) => { + text.inputEl.disabled = true; + text.setValue(settings.vaultId); + }); + + new Setting(containerEl) + .setName("Device ID") + .setDesc("Stable identifier for this install. Used by the backend so you can revoke a single device without disconnecting the others.") + .addText((text) => { + text.inputEl.disabled = true; + text.setValue(settings.deviceId); + }); + + new Setting(containerEl).setName("Status").setHeading(); + this.statusEl = containerEl.createDiv({ cls: "surfsense-settings__status" }); + this.renderStatus(); + + new Setting(containerEl) + .addButton((btn) => + btn + .setButtonText("Re-sync entire vault") + .onClick(async () => { + btn.setDisabled(true); + try { + await this.plugin.engine.maybeReconcile(true); + new Notice("Surfsense: re-sync requested."); + } catch (err) { + this.handleApiError(err); + } finally { + btn.setDisabled(false); + this.renderStatus(); + } + }), + ) + .addButton((btn) => + btn.setButtonText("Open releases").onClick(() => { + window.open( + "https://github.com/MODSetter/SurfSense/releases?q=obsidian", + "_blank", + ); + }), + ); + } + + hide(): void { + this.statusEl = null; + } + + private async refreshSearchSpaces(): Promise { + this.loadingSpaces = true; + try { + this.searchSpaces = await this.plugin.api.listSearchSpaces(); + } catch (err) { + this.handleApiError(err); + this.searchSpaces = []; + } finally { + this.loadingSpaces = false; + } + } + + renderStatus(): void { + if (!this.statusEl) return; + const s = this.plugin.settings; + this.statusEl.empty(); + + const rows: { label: string; value: string }[] = [ + { label: "Status", value: this.plugin.lastStatus.kind }, + { + label: "Last sync", + value: s.lastSyncAt ? new Date(s.lastSyncAt).toLocaleString() : "—", + }, + { + label: "Last reconcile", + value: s.lastReconcileAt ? new Date(s.lastReconcileAt).toLocaleString() : "—", + }, + { label: "Files synced", value: String(s.filesSynced ?? 0) }, + { label: "Queue depth", value: String(this.plugin.queueDepth) }, + { + label: "API version", + value: this.plugin.serverApiVersion ?? "(not yet handshaken)", + }, + { + label: "Capabilities", + value: this.plugin.serverCapabilities.length + ? this.plugin.serverCapabilities.join(", ") + : "(not yet handshaken)", + }, + ]; + for (const row of rows) { + const wrap = this.statusEl.createDiv({ cls: "surfsense-settings__status-row" }); + wrap.createSpan({ cls: "surfsense-settings__status-label", text: row.label }); + wrap.createSpan({ cls: "surfsense-settings__status-value", text: row.value }); + } + } + + private handleApiError(err: unknown): void { + if (err instanceof AuthError) { + new Notice(`SurfSense: ${err.message}`); + return; + } + new Notice( + `SurfSense: request failed — ${(err as Error).message ?? "unknown error"}`, + ); } } diff --git a/surfsense_obsidian/src/status-bar.ts b/surfsense_obsidian/src/status-bar.ts new file mode 100644 index 000000000..4dc163778 --- /dev/null +++ b/surfsense_obsidian/src/status-bar.ts @@ -0,0 +1,61 @@ +import { setIcon } from "obsidian"; +import type { StatusKind, StatusState } from "./types"; + +/** + * Tiny status-bar adornment. + * + * Plain DOM (no HTML strings, no CSS-in-JS) so it stays cheap on mobile + * and Obsidian's lint doesn't complain about innerHTML. + */ + +interface StatusVisual { + icon: string; + label: string; + cls: string; +} + +const VISUALS: Record = { + idle: { icon: "check-circle", label: "Synced", cls: "surfsense-status--ok" }, + syncing: { icon: "refresh-ccw", label: "Syncing", cls: "surfsense-status--syncing" }, + queued: { icon: "upload", label: "Queued", cls: "surfsense-status--syncing" }, + offline: { icon: "wifi-off", label: "Offline", cls: "surfsense-status--warn" }, + "auth-error": { icon: "lock", label: "Auth error", cls: "surfsense-status--err" }, + error: { icon: "alert-circle", label: "Error", cls: "surfsense-status--err" }, +}; + +export class StatusBar { + private readonly el: HTMLElement; + private readonly icon: HTMLElement; + private readonly text: HTMLElement; + + constructor(host: HTMLElement) { + this.el = host; + this.el.addClass("surfsense-status"); + this.icon = this.el.createSpan({ cls: "surfsense-status__icon" }); + this.text = this.el.createSpan({ cls: "surfsense-status__text" }); + this.update({ kind: "idle", queueDepth: 0 }); + } + + update(state: StatusState): void { + const visual = VISUALS[state.kind]; + this.el.removeClass( + "surfsense-status--ok", + "surfsense-status--syncing", + "surfsense-status--warn", + "surfsense-status--err", + ); + this.el.addClass(visual.cls); + setIcon(this.icon, visual.icon); + + let label = `SurfSense: ${visual.label}`; + if (state.queueDepth > 0 && state.kind !== "idle") { + label += ` (${state.queueDepth})`; + } + this.text.setText(label); + this.el.setAttr( + "aria-label", + state.detail ? `${label} — ${state.detail}` : label, + ); + this.el.setAttr("title", state.detail ?? label); + } +} diff --git a/surfsense_obsidian/src/sync-engine.ts b/surfsense_obsidian/src/sync-engine.ts new file mode 100644 index 000000000..ce22b69c1 --- /dev/null +++ b/surfsense_obsidian/src/sync-engine.ts @@ -0,0 +1,505 @@ +import { Notice, TFile, type App, type CachedMetadata, type TAbstractFile } from "obsidian"; +import { + AuthError, + PermanentError, + type SurfSenseApiClient, + TransientError, +} from "./api-client"; +import { isExcluded } from "./excludes"; +import { buildNotePayload, computeContentHash } from "./payload"; +import { type BatchResult, PersistentQueue } from "./queue"; +import type { + HealthResponse, + NotePayload, + QueueItem, + StatusKind, + StatusState, +} from "./types"; + +/** + * Owner of "what does the vault look like vs the server" reasoning. + * + * Onload sequence (per plan §p4_plugin_sync_engine, in this exact order): + * 1. apiClient.health() — proves connectivity and pulls the capabilities + * handshake before we issue any sync traffic. + * 2. Cache health.capabilities + api_version on the plugin instance + * so feature gating (e.g. "attachments_v2" before syncing binaries) + * reads from local state instead of round-tripping. + * 3. Drain queue — items persisted from the previous session land first. + * 4. Reconcile — GET /manifest, diff against vault, queue uploads/deletes. + * 5. Subscribe events — only after the above so the user's first edit + * after launching Obsidian doesn't race with the manifest diff. + * + * Reconcile skips itself if last successful reconcile is < RECONCILE_MIN_INTERVAL_MS + * ago. ConnectResponse already carries handshake fields so first connect + * does not need a separate /health round-trip. + */ + +export interface SyncEngineDeps { + app: App; + apiClient: SurfSenseApiClient; + queue: PersistentQueue; + getSettings: () => SyncEngineSettings; + saveSettings: (mut: (s: SyncEngineSettings) => void) => Promise; + setStatus: (s: StatusState) => void; + onCapabilities: (caps: string[], apiVersion: string) => void; +} + +export interface SyncEngineSettings { + vaultId: string; + vaultName: string; + connectorId: number | null; + searchSpaceId: number | null; + deviceId: string; + deviceLabel: string; + excludePatterns: string[]; + includeAttachments: boolean; + syncMode: "auto" | "manual"; + lastReconcileAt: number | null; + lastSyncAt: number | null; + filesSynced: number; + tombstones: Record; +} + +export const RECONCILE_MIN_INTERVAL_MS = 5 * 60 * 1000; +const TOMBSTONE_TTL_MS = 24 * 60 * 60 * 1000; // 1 day +const PENDING_DEBOUNCE_MS = 1500; + +export class SyncEngine { + private readonly deps: SyncEngineDeps; + private capabilities: string[] = []; + private apiVersion: string | null = null; + private pendingMdEdits = new Map>(); + + constructor(deps: SyncEngineDeps) { + this.deps = deps; + } + + getCapabilities(): readonly string[] { + return this.capabilities; + } + + supports(capability: string): boolean { + return this.capabilities.includes(capability); + } + + /** Run the onload sequence described in this file's docstring. */ + async start(): Promise { + this.setStatus("syncing", "Connecting to SurfSense…"); + try { + const health = await this.deps.apiClient.health(); + this.applyHealth(health); + } catch (err) { + this.handleStartupError(err); + return; + } + + const settings = this.deps.getSettings(); + if (!settings.connectorId || !settings.searchSpaceId) { + // No connector yet — settings tab will trigger ensureConnect once + // the user picks a search space, then re-call start(). + this.setStatus("idle", "Pick a search space in settings to start syncing."); + return; + } + + await this.flushQueue(); + await this.maybeReconcile(); + this.setStatus(this.queueStatusKind(), undefined); + } + + /** Public entry point used after settings save to (re)connect the vault. */ + async ensureConnected(): Promise { + const settings = this.deps.getSettings(); + if (!settings.searchSpaceId) { + this.setStatus("idle", "Pick a search space in settings."); + return; + } + try { + const resp = await this.deps.apiClient.connect({ + searchSpaceId: settings.searchSpaceId, + vaultId: settings.vaultId, + vaultName: settings.vaultName, + deviceId: settings.deviceId, + deviceLabel: settings.deviceLabel, + }); + this.applyHealth(resp); + await this.deps.saveSettings((s) => { + s.connectorId = resp.connector_id; + }); + } catch (err) { + this.handleStartupError(err); + } + } + + applyHealth(h: HealthResponse): void { + this.capabilities = Array.isArray(h.capabilities) ? [...h.capabilities] : []; + this.apiVersion = h.api_version ?? null; + this.deps.onCapabilities(this.capabilities, this.apiVersion ?? "?"); + } + + // ---- vault event handlers -------------------------------------------- + + onCreate(file: TAbstractFile): void { + if (!this.shouldTrack(file)) return; + const settings = this.deps.getSettings(); + if (this.isExcluded(file.path, settings)) return; + if (this.isMarkdown(file)) { + this.scheduleMdUpsert(file.path); + return; + } + this.deps.queue.enqueueUpsert(file.path); + } + + onModify(file: TAbstractFile): void { + if (!this.shouldTrack(file)) return; + const settings = this.deps.getSettings(); + if (this.isExcluded(file.path, settings)) return; + if (this.isMarkdown(file)) { + // Defer to metadataCache.changed so payload fields are fresh. + this.scheduleMdUpsert(file.path); + return; + } + this.deps.queue.enqueueUpsert(file.path); + } + + onDelete(file: TAbstractFile): void { + if (!this.shouldTrack(file)) return; + this.deps.queue.enqueueDelete(file.path); + void this.deps.saveSettings((s) => { + s.tombstones[file.path] = Date.now(); + }); + } + + onRename(file: TAbstractFile, oldPath: string): void { + if (!this.shouldTrack(file)) return; + const settings = this.deps.getSettings(); + if (this.isExcluded(file.path, settings)) { + this.deps.queue.enqueueDelete(oldPath); + void this.deps.saveSettings((s) => { + s.tombstones[oldPath] = Date.now(); + }); + return; + } + this.deps.queue.enqueueRename(oldPath, file.path); + } + + onMetadataChanged(file: TFile, _data: string, _cache: CachedMetadata): void { + if (!this.shouldTrack(file)) return; + const settings = this.deps.getSettings(); + if (this.isExcluded(file.path, settings)) return; + if (!this.isMarkdown(file)) return; + // Cancel any deferred upsert and enqueue with fresh metadata now. + const pending = this.pendingMdEdits.get(file.path); + if (pending) { + clearTimeout(pending); + this.pendingMdEdits.delete(file.path); + } + this.deps.queue.enqueueUpsert(file.path); + } + + private scheduleMdUpsert(path: string): void { + const existing = this.pendingMdEdits.get(path); + if (existing) clearTimeout(existing); + this.pendingMdEdits.set( + path, + setTimeout(() => { + this.pendingMdEdits.delete(path); + this.deps.queue.enqueueUpsert(path); + }, PENDING_DEBOUNCE_MS), + ); + } + + // ---- queue draining --------------------------------------------------- + + async flushQueue(): Promise { + if (this.deps.queue.size === 0) return; + this.setStatus("syncing", `Syncing ${this.deps.queue.size} item(s)…`); + const summary = await this.deps.queue.drain({ + processBatch: (batch) => this.processBatch(batch), + }); + if (summary.acked > 0) { + await this.deps.saveSettings((s) => { + s.lastSyncAt = Date.now(); + s.filesSynced = (s.filesSynced ?? 0) + summary.acked; + }); + } + this.setStatus(this.queueStatusKind(), this.statusDetail()); + } + + private async processBatch(batch: QueueItem[]): Promise { + const settings = this.deps.getSettings(); + const upserts = batch.filter((b): b is QueueItem & { op: "upsert" } => b.op === "upsert"); + const renames = batch.filter((b): b is QueueItem & { op: "rename" } => b.op === "rename"); + const deletes = batch.filter((b): b is QueueItem & { op: "delete" } => b.op === "delete"); + + const acked: QueueItem[] = []; + const retry: QueueItem[] = []; + const dropped: QueueItem[] = []; + + // Renames first so paths line up server-side before content upserts. + if (renames.length > 0) { + try { + await this.deps.apiClient.renameBatch({ + vaultId: settings.vaultId, + renames: renames.map((r) => ({ oldPath: r.oldPath, newPath: r.newPath })), + }); + acked.push(...renames); + } catch (err) { + const verdict = this.classify(err); + if (verdict === "stop") return { acked, retry: [...retry, ...renames], dropped, stop: true }; + if (verdict === "retry") retry.push(...renames); + else dropped.push(...renames); + } + } + + if (deletes.length > 0) { + try { + await this.deps.apiClient.deleteBatch({ + vaultId: settings.vaultId, + paths: deletes.map((d) => d.path), + }); + acked.push(...deletes); + } catch (err) { + const verdict = this.classify(err); + if (verdict === "stop") return { acked, retry: [...retry, ...deletes], dropped, stop: true }; + if (verdict === "retry") retry.push(...deletes); + else dropped.push(...deletes); + } + } + + if (upserts.length > 0) { + const payloads: NotePayload[] = []; + for (const item of upserts) { + const file = this.deps.app.vault.getAbstractFileByPath(item.path); + if (!file || !isTFile(file)) { + // File vanished; treat as ack (delete will follow if user removed it). + acked.push(item); + continue; + } + try { + const payload = this.isMarkdown(file) + ? await buildNotePayload(this.deps.app, file, settings.vaultId) + : await this.buildBinaryPayload(file, settings.vaultId); + payloads.push(payload); + } catch (err) { + console.error("SurfSense: failed to build payload", item.path, err); + retry.push(item); + } + } + + if (payloads.length > 0) { + try { + const resp = await this.deps.apiClient.syncBatch({ + vaultId: settings.vaultId, + notes: payloads, + }); + const rejected = new Set(resp.rejected ?? []); + for (const item of upserts) { + if (retry.find((r) => r === item)) continue; + if (rejected.has(item.path)) dropped.push(item); + else acked.push(item); + } + } catch (err) { + const verdict = this.classify(err); + if (verdict === "stop") + return { acked, retry: [...retry, ...upserts], dropped, stop: true }; + if (verdict === "retry") retry.push(...upserts); + else dropped.push(...upserts); + } + } + } + + return { acked, retry, dropped, stop: false }; + } + + private async buildBinaryPayload(file: TFile, vaultId: string): Promise { + // Plain attachments don't go through buildNotePayload (no markdown + // metadata to extract). We still need a stable hash + file stat so + // the backend can de-dupe and the manifest diff still works. + const buf = await this.deps.app.vault.readBinary(file); + const digest = await crypto.subtle.digest("SHA-256", buf); + const hash = bufferToHex(digest); + return { + vault_id: vaultId, + path: file.path, + name: file.basename, + extension: file.extension, + content: "", + frontmatter: {}, + tags: [], + headings: [], + resolved_links: [], + unresolved_links: [], + embeds: [], + aliases: [], + content_hash: hash, + mtime: file.stat.mtime, + ctime: file.stat.ctime, + is_binary: true, + }; + } + + // ---- reconcile -------------------------------------------------------- + + async maybeReconcile(force = false): Promise { + const settings = this.deps.getSettings(); + if (!settings.connectorId) return; + if (!force && settings.lastReconcileAt) { + if (Date.now() - settings.lastReconcileAt < RECONCILE_MIN_INTERVAL_MS) return; + } + + this.setStatus("syncing", "Reconciling vault with server…"); + try { + const manifest = await this.deps.apiClient.getManifest(settings.vaultId); + const remote = manifest.entries ?? {}; + await this.diffAndQueue(settings, remote); + await this.deps.saveSettings((s) => { + s.lastReconcileAt = Date.now(); + s.tombstones = pruneTombstones(s.tombstones); + }); + await this.flushQueue(); + } catch (err) { + this.classifyAndStatus(err, "Reconcile failed"); + } + } + + private async diffAndQueue( + settings: SyncEngineSettings, + remote: Record, + ): Promise { + const localFiles = this.deps.app.vault.getFiles().filter((f) => { + if (!this.shouldTrack(f)) return false; + if (this.isExcluded(f.path, settings)) return false; + return true; + }); + const localPaths = new Set(localFiles.map((f) => f.path)); + + // Local-only or content-changed → upsert. + for (const file of localFiles) { + const remoteEntry = remote[file.path]; + if (!remoteEntry) { + this.deps.queue.enqueueUpsert(file.path); + continue; + } + if (file.stat.mtime > remoteEntry.mtime + 1000) { + this.deps.queue.enqueueUpsert(file.path); + continue; + } + if (this.isMarkdown(file)) { + const content = await this.deps.app.vault.cachedRead(file); + const hash = await computeContentHash(content); + if (hash !== remoteEntry.hash) { + this.deps.queue.enqueueUpsert(file.path); + } + } + } + + // Remote-only → delete, but only if NOT a fresh tombstone (which + // the queue will deliver) and NOT a path we already plan to upsert. + for (const path of Object.keys(remote)) { + if (localPaths.has(path)) continue; + const tombstone = settings.tombstones[path]; + if (tombstone && Date.now() - tombstone < TOMBSTONE_TTL_MS) continue; + this.deps.queue.enqueueDelete(path); + } + } + + // ---- status helpers --------------------------------------------------- + + private setStatus(kind: StatusKind, detail?: string): void { + this.deps.setStatus({ kind, detail, queueDepth: this.deps.queue.size }); + } + + private queueStatusKind(): StatusKind { + if (this.deps.queue.size > 0) return "queued"; + return "idle"; + } + + private statusDetail(): string | undefined { + const settings = this.deps.getSettings(); + if (settings.lastSyncAt) { + return `Last sync ${formatRelative(settings.lastSyncAt)}`; + } + return undefined; + } + + private handleStartupError(err: unknown): void { + if (err instanceof AuthError) { + this.setStatus("auth-error", err.message); + return; + } + if (err instanceof TransientError) { + this.setStatus("offline", err.message); + return; + } + this.setStatus("error", (err as Error).message ?? "Unknown error"); + } + + private classify(err: unknown): "ack" | "retry" | "drop" | "stop" { + if (err instanceof AuthError) { + this.setStatus("auth-error", err.message); + return "stop"; + } + if (err instanceof TransientError) { + this.setStatus("offline", err.message); + return "stop"; + } + if (err instanceof PermanentError) { + console.warn("SurfSense: permanent error, dropping batch", err); + new Notice(`SurfSense: ${err.message}`); + return "drop"; + } + console.error("SurfSense: unknown error", err); + return "retry"; + } + + private classifyAndStatus(err: unknown, prefix: string): void { + this.classify(err); + this.setStatus(this.queueStatusKind(), `${prefix}: ${(err as Error).message}`); + } + + // ---- predicates ------------------------------------------------------- + + private shouldTrack(file: TAbstractFile): boolean { + if (!isTFile(file)) return false; + const settings = this.deps.getSettings(); + if (!settings.includeAttachments && !this.isMarkdown(file)) return false; + return true; + } + + private isExcluded(path: string, settings: SyncEngineSettings): boolean { + return isExcluded(path, settings.excludePatterns); + } + + private isMarkdown(file: TAbstractFile): boolean { + return isTFile(file) && file.extension.toLowerCase() === "md"; + } +} + +function isTFile(f: TAbstractFile): f is TFile { + return f instanceof TFile; +} + +function bufferToHex(buf: ArrayBuffer): string { + const view = new Uint8Array(buf); + let hex = ""; + for (let i = 0; i < view.length; i++) hex += (view[i] ?? 0).toString(16).padStart(2, "0"); + return hex; +} + +function formatRelative(ts: number): string { + const diff = Date.now() - ts; + if (diff < 60_000) return "just now"; + if (diff < 3600_000) return `${Math.round(diff / 60_000)}m ago`; + if (diff < 86_400_000) return `${Math.round(diff / 3600_000)}h ago`; + return `${Math.round(diff / 86_400_000)}d ago`; +} + +function pruneTombstones(tombstones: Record): Record { + const out: Record = {}; + const cutoff = Date.now() - TOMBSTONE_TTL_MS; + for (const [k, v] of Object.entries(tombstones)) { + if (v >= cutoff) out[k] = v; + } + return out; +} diff --git a/surfsense_obsidian/src/types.ts b/surfsense_obsidian/src/types.ts new file mode 100644 index 000000000..8b353c2f4 --- /dev/null +++ b/surfsense_obsidian/src/types.ts @@ -0,0 +1,145 @@ +/** + * Shared types for the SurfSense Obsidian plugin. + * + * Kept in a leaf module with no other src/ imports so it can be imported + * from anywhere (settings, api-client, sync-engine, status-bar, main) + * without creating cycles. + */ + +export interface SurfsensePluginSettings { + serverUrl: string; + apiToken: string; + searchSpaceId: number | null; + connectorId: number | null; + vaultId: string; + vaultName: string; + deviceId: string; + deviceLabel: string; + syncMode: "auto" | "manual"; + excludePatterns: string[]; + includeAttachments: boolean; + lastSyncAt: number | null; + lastReconcileAt: number | null; + filesSynced: number; + queue: QueueItem[]; + tombstones: Record; +} + +export const DEFAULT_SETTINGS: SurfsensePluginSettings = { + serverUrl: "https://api.surfsense.com", + apiToken: "", + searchSpaceId: null, + connectorId: null, + vaultId: "", + vaultName: "", + deviceId: "", + deviceLabel: "", + syncMode: "auto", + excludePatterns: [".trash", "_attachments", "templates"], + includeAttachments: false, + lastSyncAt: null, + lastReconcileAt: null, + filesSynced: 0, + queue: [], + tombstones: {}, +}; + +export type QueueOp = "upsert" | "delete" | "rename"; + +export interface UpsertItem { + op: "upsert"; + path: string; + enqueuedAt: number; + attempt: number; +} + +export interface DeleteItem { + op: "delete"; + path: string; + enqueuedAt: number; + attempt: number; +} + +export interface RenameItem { + op: "rename"; + oldPath: string; + newPath: string; + enqueuedAt: number; + attempt: number; +} + +export type QueueItem = UpsertItem | DeleteItem | RenameItem; + +export interface NotePayload { + vault_id: string; + path: string; + name: string; + extension: string; + content: string; + frontmatter: Record; + tags: string[]; + headings: HeadingRef[]; + resolved_links: string[]; + unresolved_links: string[]; + embeds: string[]; + aliases: string[]; + content_hash: string; + mtime: number; + ctime: number; + [key: string]: unknown; +} + +export interface HeadingRef { + heading: string; + level: number; +} + +export interface SearchSpace { + id: number; + name: string; + description?: string; + [key: string]: unknown; +} + +export interface ConnectResponse { + connector_id: number; + vault_id: string; + search_space_id: number; + api_version: string; + capabilities: string[]; + server_time_utc: string; + [key: string]: unknown; +} + +export interface HealthResponse { + api_version: string; + capabilities: string[]; + server_time_utc: string; + [key: string]: unknown; +} + +export interface ManifestEntry { + hash: string; + mtime: number; + [key: string]: unknown; +} + +export interface ManifestResponse { + vault_id: string; + entries: Record; + [key: string]: unknown; +} + +export type StatusKind = + | "idle" + | "syncing" + | "queued" + | "offline" + | "auth-error" + | "error"; + +export interface StatusState { + kind: StatusKind; + detail?: string; + queueDepth: number; +} diff --git a/surfsense_obsidian/styles.css b/surfsense_obsidian/styles.css index 71cc60fd4..6ad450091 100644 --- a/surfsense_obsidian/styles.css +++ b/surfsense_obsidian/styles.css @@ -1,8 +1,66 @@ /* + * SurfSense Obsidian plugin styles. Kept tiny on purpose — Obsidian + * theming should drive most of the look; we only add the bits we + * cannot express via the standard PluginSettingTab/Setting components. + */ -This CSS file will be included with your plugin, and -available in the app when your plugin is enabled. +.surfsense-status { + display: inline-flex; + align-items: center; + gap: 6px; + padding: 0 6px; + cursor: default; +} -If your plugin does not need CSS, delete this file. +.surfsense-status__icon { + display: inline-flex; + width: 14px; + height: 14px; +} -*/ +.surfsense-status__icon svg { + width: 14px; + height: 14px; +} + +.surfsense-status__text { + font-size: var(--font-ui-smaller); +} + +.surfsense-status--ok .surfsense-status__icon { + color: var(--color-green); +} + +.surfsense-status--syncing .surfsense-status__icon { + color: var(--color-blue); +} + +.surfsense-status--warn .surfsense-status__icon { + color: var(--color-yellow); +} + +.surfsense-status--err .surfsense-status__icon { + color: var(--color-red); +} + +.surfsense-settings__status { + display: grid; + grid-template-columns: minmax(120px, max-content) 1fr; + row-gap: 4px; + column-gap: 12px; + margin: 8px 0 16px; +} + +.surfsense-settings__status-row { + display: contents; +} + +.surfsense-settings__status-label { + color: var(--text-muted); + font-size: var(--font-ui-smaller); +} + +.surfsense-settings__status-value { + font-size: var(--font-ui-smaller); + word-break: break-word; +} diff --git a/surfsense_obsidian/versions.json b/surfsense_obsidian/versions.json index 26382a157..8b02889bb 100644 --- a/surfsense_obsidian/versions.json +++ b/surfsense_obsidian/versions.json @@ -1,3 +1,3 @@ { - "1.0.0": "0.15.0" + "0.1.0": "1.4.0" } diff --git a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts index caa85ba2d..e5233a20d 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts @@ -1,5 +1,5 @@ import { format } from "date-fns"; -import { useAtom, useAtomValue, useSetAtom } from "jotai"; +import { useAtom, useAtomValue } from "jotai"; import { useCallback, useEffect, useRef, useState } from "react"; import { toast } from "sonner"; import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms"; @@ -10,17 +10,11 @@ import { updateConnectorMutationAtom, } from "@/atoms/connectors/connector-mutation.atoms"; import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; -import { - folderWatchDialogOpenAtom, - folderWatchInitialFolderAtom, -} from "@/atoms/folder-sync/folder-sync.atoms"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { EnumConnectorName } from "@/contracts/enums/connector"; import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { searchSourceConnector } from "@/contracts/types/connector.types"; -import { usePlatform } from "@/hooks/use-platform"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { isSelfHosted } from "@/lib/env-config"; import { trackConnectorConnected, trackConnectorDeleted, @@ -68,10 +62,6 @@ export const useConnectorDialog = () => { const { mutateAsync: updateConnector } = useAtomValue(updateConnectorMutationAtom); const { mutateAsync: deleteConnector } = useAtomValue(deleteConnectorMutationAtom); const { mutateAsync: createConnector } = useAtomValue(createConnectorMutationAtom); - const setFolderWatchOpen = useSetAtom(folderWatchDialogOpenAtom); - const setFolderWatchInitialFolder = useSetAtom(folderWatchInitialFolderAtom); - const { isDesktop } = usePlatform(); - const selfHosted = isSelfHosted(); // Use global atom for dialog open state so it can be controlled from anywhere const [isOpen, setIsOpen] = useAtom(connectorDialogOpenAtom); @@ -447,29 +437,13 @@ export const useConnectorDialog = () => { } }, [searchSpaceId, createConnector, refetchAllConnectors, setIsOpen]); - // Handle connecting non-OAuth connectors (like Tavily API) + // Handle connecting non-OAuth connectors (like Tavily API, Obsidian plugin, etc.) const handleConnectNonOAuth = useCallback( (connectorType: string) => { if (!searchSpaceId) return; - - // Handle Obsidian specifically on Desktop & Cloud - if (connectorType === EnumConnectorName.OBSIDIAN_CONNECTOR && !selfHosted && isDesktop) { - setIsOpen(false); - setFolderWatchInitialFolder(null); - setFolderWatchOpen(true); - return; - } - setConnectingConnectorType(connectorType); }, - [ - searchSpaceId, - selfHosted, - isDesktop, - setIsOpen, - setFolderWatchOpen, - setFolderWatchInitialFolder, - ] + [searchSpaceId] ); // Handle submitting connect form diff --git a/versions.json b/versions.json new file mode 100644 index 000000000..8b02889bb --- /dev/null +++ b/versions.json @@ -0,0 +1,3 @@ +{ + "0.1.0": "1.4.0" +} From b5c9388c8acdcc8ffdf5bea6996502a733dee617 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 20 Apr 2026 18:19:30 +0530 Subject: [PATCH 007/299] feat: refine Obsidian plugin routes and schemas for improved device management and API stability --- .../app/routes/obsidian_plugin_routes.py | 119 ++++------ .../app/schemas/obsidian_plugin.py | 63 +---- surfsense_obsidian/src/api-client.ts | 2 - surfsense_obsidian/src/main.ts | 50 ++-- surfsense_obsidian/src/settings.ts | 45 +--- surfsense_obsidian/src/sync-engine.ts | 48 ++-- surfsense_obsidian/src/types.ts | 15 +- .../components/obsidian-config.tsx | 219 +++++------------- .../views/connector-edit-view.tsx | 6 +- 9 files changed, 182 insertions(+), 385 deletions(-) diff --git a/surfsense_backend/app/routes/obsidian_plugin_routes.py b/surfsense_backend/app/routes/obsidian_plugin_routes.py index c7656332d..0d2ce703d 100644 --- a/surfsense_backend/app/routes/obsidian_plugin_routes.py +++ b/surfsense_backend/app/routes/obsidian_plugin_routes.py @@ -1,31 +1,8 @@ -""" -Obsidian plugin ingestion routes. +"""Obsidian plugin ingestion routes (``/api/v1/obsidian/*``). -This is the public surface that the SurfSense Obsidian plugin -(``surfsense_obsidian/``) speaks to. It is a separate router from the -legacy server-path Obsidian connector — the legacy code stays in place -until the ``obsidian-legacy-cleanup`` plan ships. - -Endpoints ---------- - -- ``GET /api/v1/obsidian/health`` — version handshake -- ``POST /api/v1/obsidian/connect`` — register or get a vault row -- ``POST /api/v1/obsidian/sync`` — batch upsert -- ``POST /api/v1/obsidian/rename`` — batch rename -- ``DELETE /api/v1/obsidian/notes`` — batch soft-delete -- ``GET /api/v1/obsidian/manifest`` — reconcile manifest - -Auth contract -------------- - -Every endpoint requires ``Depends(current_active_user)`` — the same JWT -bearer the rest of the API uses; future PAT migration is transparent. - -API stability is provided by the ``/api/v1/...`` URL prefix and the -``capabilities`` array advertised on ``/health`` (additive only). There -is no plugin-version gate; "your plugin is out of date" notices are -delegated to Obsidian's built-in community-store updater. +Wire surface for the ``surfsense_obsidian/`` plugin. API stability is the +``/api/v1/`` prefix plus the additive ``capabilities`` array on /health; +no plugin-version gate. """ from __future__ import annotations @@ -67,14 +44,10 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/obsidian", tags=["obsidian-plugin"]) -# Bumped manually whenever the wire contract gains a non-additive change. -# Additive (extra='ignore'-safe) changes do NOT bump this. +# Bumped only on non-additive wire changes; additive ones ride extra='ignore'. OBSIDIAN_API_VERSION = "1" -# Capabilities advertised on /health and /connect. Plugins use this list -# for feature gating ("does this server understand attachments_v2?"). Add -# new strings, never rename/remove existing ones — older plugins ignore -# unknown entries safely. +# Plugins feature-gate on these. Add entries, never rename or remove. OBSIDIAN_CAPABILITIES: list[str] = ["sync", "rename", "delete", "manifest"] @@ -90,18 +63,41 @@ def _build_handshake() -> dict[str, object]: } +def _upsert_device( + existing_devices: object, + device_id: str, + now_iso: str, +) -> dict[str, dict[str, str]]: + """Upsert ``device_id`` into ``{device_id: {first_seen_at, last_seen_at}}``. + + Keyed by device_id for O(1) dedup; ``len(devices)`` is the count. + Timestamps are kept for a future stale-device pruner. + """ + devices: dict[str, dict[str, str]] = {} + if isinstance(existing_devices, dict): + for key, val in existing_devices.items(): + if not isinstance(key, str) or not key or not isinstance(val, dict): + continue + devices[key] = { + "first_seen_at": str(val.get("first_seen_at") or now_iso), + "last_seen_at": str(val.get("last_seen_at") or now_iso), + } + + prev = devices.get(device_id) + devices[device_id] = { + "first_seen_at": prev["first_seen_at"] if prev else now_iso, + "last_seen_at": now_iso, + } + return devices + + async def _resolve_vault_connector( session: AsyncSession, *, user: User, vault_id: str, ) -> SearchSourceConnector: - """Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user. - - Looked up by the (user_id, connector_type, config['vault_id']) tuple - so users can have multiple vaults each backed by its own connector - row (one per search space). - """ + """Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user.""" result = await session.execute( select(SearchSourceConnector).where( and_( @@ -136,12 +132,7 @@ async def _ensure_search_space_access( user: User, search_space_id: int, ) -> SearchSpace: - """Confirm the user owns the requested search space. - - Plugin currently does not support shared search spaces (RBAC roles) - — that's a follow-up. Restricting to owner-only here keeps the - surface narrow and avoids leaking other members' connectors. - """ + """Owner-only access to the search space (shared spaces are a follow-up).""" result = await session.execute( select(SearchSpace).where( and_(SearchSpace.id == search_space_id, SearchSpace.user_id == user.id) @@ -168,11 +159,7 @@ async def _ensure_search_space_access( async def obsidian_health( user: User = Depends(current_active_user), ) -> HealthResponse: - """Return the API contract handshake. - - The plugin calls this once per ``onload`` and caches the result for - capability-gating decisions. - """ + """Return the API contract handshake; plugin caches it per onload.""" return HealthResponse( **_build_handshake(), server_time_utc=datetime.now(UTC), @@ -187,9 +174,9 @@ async def obsidian_connect( ) -> ConnectResponse: """Register a vault, or return the existing connector row. - Idempotent on the (user_id, OBSIDIAN_CONNECTOR, vault_id) tuple so - re-installing the plugin or reconnecting from a new device picks up - the same connector — and therefore the same documents. + Idempotent on (user_id, OBSIDIAN_CONNECTOR, vault_id). Called on every + plugin onload as a heartbeat — upserts ``device_id`` into + ``config['devices']`` so the web UI can show a "Devices: N" tile. """ await _ensure_search_space_access( session, user=user, search_space_id=payload.search_space_id @@ -215,27 +202,31 @@ async def obsidian_connect( if existing is not None: cfg = dict(existing.config or {}) + devices = _upsert_device(cfg.get("devices"), payload.device_id, now_iso) cfg.update( { "vault_id": payload.vault_id, "vault_name": payload.vault_name, "source": "plugin", "plugin_version": payload.plugin_version, - "device_id": payload.device_id, + "devices": devices, + "device_count": len(devices), "last_connect_at": now_iso, } ) - if payload.device_label: - cfg["device_label"] = payload.device_label cfg.pop("legacy", None) cfg.pop("vault_path", None) existing.config = cfg + # Re-stamp on every connect so vault renames in Obsidian propagate; + # the web UI hides the Name input for Obsidian connectors. + existing.name = f"Obsidian — {payload.vault_name}" existing.is_indexable = False existing.search_space_id = payload.search_space_id await session.commit() await session.refresh(existing) connector = existing else: + devices = _upsert_device(None, payload.device_id, now_iso) connector = SearchSourceConnector( name=f"Obsidian — {payload.vault_name}", connector_type=SearchSourceConnectorType.OBSIDIAN_CONNECTOR, @@ -245,8 +236,8 @@ async def obsidian_connect( "vault_name": payload.vault_name, "source": "plugin", "plugin_version": payload.plugin_version, - "device_id": payload.device_id, - "device_label": payload.device_label, + "devices": devices, + "device_count": len(devices), "files_synced": 0, "last_connect_at": now_iso, }, @@ -271,11 +262,7 @@ async def obsidian_sync( user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session), ) -> dict[str, object]: - """Batch-upsert notes pushed by the plugin. - - Returns per-note ack so the plugin can dequeue successes and retry - failures. - """ + """Batch-upsert notes; returns per-note ack so the plugin can dequeue/retry.""" connector = await _resolve_vault_connector( session, user=user, vault_id=payload.vault_id ) @@ -439,11 +426,7 @@ async def obsidian_manifest( user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session), ) -> ManifestResponse: - """Return the server-side ``{path: {hash, mtime}}`` manifest. - - Used by the plugin's ``onload`` reconcile to find files that were - edited or deleted while the plugin was offline. - """ + """Return ``{path: {hash, mtime}}`` for the plugin's onload reconcile diff.""" connector = await _resolve_vault_connector( session, user=user, vault_id=vault_id ) diff --git a/surfsense_backend/app/schemas/obsidian_plugin.py b/surfsense_backend/app/schemas/obsidian_plugin.py index c4c3cd8d4..5de0a093a 100644 --- a/surfsense_backend/app/schemas/obsidian_plugin.py +++ b/surfsense_backend/app/schemas/obsidian_plugin.py @@ -1,23 +1,8 @@ -""" -Obsidian Plugin connector schemas. +"""Wire schemas spoken between the SurfSense Obsidian plugin and the backend. -Wire format spoken between the SurfSense Obsidian plugin -(``surfsense_obsidian/``) and the FastAPI backend. - -Stability contract ------------------- -Every request and response schema sets ``model_config = ConfigDict(extra='ignore')``. -This is the API stability contract — not just hygiene: - -- Old plugins talking to a newer backend silently drop any new response fields - they don't understand instead of failing validation. -- New plugins talking to an older backend can include forward-looking request - fields (e.g. attachments metadata) without the older backend rejecting them. - -Hard breaking changes are reserved for the URL prefix (``/api/v2/...``). -Additive evolution is signaled via the ``capabilities`` array on -``HealthResponse`` / ``ConnectResponse`` — older plugins ignore unknown -capability strings safely. +All schemas inherit ``extra='ignore'`` from :class:`_PluginBase` so additive +field changes never break either side; hard breaks live behind a new URL +prefix (``/api/v2/...``). """ from __future__ import annotations @@ -31,22 +16,13 @@ _PLUGIN_MODEL_CONFIG = ConfigDict(extra="ignore") class _PluginBase(BaseModel): - """Base class for all plugin payload schemas. - - Carries the forward-compatibility config so subclasses don't have to - repeat it. - """ + """Base schema carrying the shared forward-compatibility config.""" model_config = _PLUGIN_MODEL_CONFIG class NotePayload(_PluginBase): - """One Obsidian note as pushed by the plugin. - - The plugin is the source of truth: ``content`` is the post-frontmatter - body, ``frontmatter``/``tags``/``headings``/etc. are precomputed by the - plugin via ``app.metadataCache`` so the backend doesn't have to re-parse. - """ + """One Obsidian note as pushed by the plugin (the source of truth).""" vault_id: str = Field(..., description="Stable plugin-generated UUID for this vault") path: str = Field(..., description="Vault-relative path, e.g. 'notes/foo.md'") @@ -68,7 +44,7 @@ class NotePayload(_PluginBase): class SyncBatchRequest(_PluginBase): - """Batch upsert. Plugin sends 10-20 notes per request to amortize HTTP overhead.""" + """Batch upsert; plugin sends 10-20 notes per request.""" vault_id: str notes: list[NotePayload] = Field(default_factory=list, max_length=100) @@ -90,8 +66,6 @@ class DeleteBatchRequest(_PluginBase): class ManifestEntry(_PluginBase): - """One row of the server-side manifest used by the plugin to reconcile.""" - hash: str mtime: datetime @@ -104,26 +78,18 @@ class ManifestResponse(_PluginBase): class ConnectRequest(_PluginBase): - """First-call handshake to register or look up a vault connector row.""" + """Vault registration / heartbeat. Replayed on every plugin onload.""" vault_id: str vault_name: str search_space_id: int plugin_version: str device_id: str - device_label: str | None = Field( - default=None, - description="User-friendly device name shown in the web UI (e.g. 'iPad Pro').", - ) class ConnectResponse(_PluginBase): - """Returned from POST /connect. - - Carries the same handshake fields as ``HealthResponse`` so the plugin - learns the contract on its very first call without an extra round-trip - to ``GET /health``. - """ + """Carries the same handshake fields as ``HealthResponse`` so the plugin + learns the contract without a separate ``GET /health`` round-trip.""" connector_id: int vault_id: str @@ -133,14 +99,7 @@ class ConnectResponse(_PluginBase): class HealthResponse(_PluginBase): - """API contract handshake. - - The plugin calls ``GET /health`` once per ``onload`` and caches the - result. ``capabilities`` is a forward-extensible string list: future - additions (``'pat_auth'``, ``'scoped_pat'``, ``'attachments_v2'``, - ``'shared_search_spaces'``...) ship without breaking older plugins - because they only enable extra behavior, never gate existing endpoints. - """ + """API contract handshake. ``capabilities`` is additive-only string list.""" api_version: str capabilities: list[str] diff --git a/surfsense_obsidian/src/api-client.ts b/surfsense_obsidian/src/api-client.ts index d686f661f..4b5ae0e33 100644 --- a/surfsense_obsidian/src/api-client.ts +++ b/surfsense_obsidian/src/api-client.ts @@ -105,7 +105,6 @@ export class SurfSenseApiClient { vaultId: string; vaultName: string; deviceId: string; - deviceLabel: string; }): Promise { return await this.request( "POST", @@ -117,7 +116,6 @@ export class SurfSenseApiClient { vault_name: input.vaultName, plugin_version: this.opts.pluginVersion, device_id: input.deviceId, - device_label: input.deviceLabel, } ); } diff --git a/surfsense_obsidian/src/main.ts b/surfsense_obsidian/src/main.ts index 34e5715a1..262886e55 100644 --- a/surfsense_obsidian/src/main.ts +++ b/surfsense_obsidian/src/main.ts @@ -11,28 +11,18 @@ import { type SurfsensePluginSettings, } from "./types"; -/** - * SurfSense plugin entry point. - * - * Replaces the obsidian-sample-plugin SampleModal/ribbon stub. Lifecycle: - * - * onload(): - * load settings → seed identity (vault_id, device_id) → - * wire api client + queue + sync engine + status bar → - * register settings tab → register vault + metadataCache events → - * register commands (resync, sync current note, open settings) → - * register status bar item → - * kick off engine.start() (health → drain → reconcile). - * - * onunload(): - * stop the queue's debounce timer; unregistered events and DOM - * handles auto-clean via the Plugin base class. - */ +/** SurfSense plugin entry point. */ export default class SurfSensePlugin extends Plugin { settings!: SurfsensePluginSettings; api!: SurfSenseApiClient; queue!: PersistentQueue; engine!: SyncEngine; + /** + * Per-install identifier kept in `app.saveLocalStorage` rather than + * `data.json`, so it does NOT travel through Obsidian Sync — each + * machine on a synced vault stays distinguishable. + */ + deviceId = ""; private statusBar: StatusBar | null = null; lastStatus: StatusState = { kind: "idle", queueDepth: 0 }; serverCapabilities: string[] = []; @@ -69,6 +59,7 @@ export default class SurfSensePlugin extends Plugin { await this.saveSettings(); this.settingTab?.renderStatus(); }, + getDeviceId: () => this.deviceId, setStatus: (s) => { this.lastStatus = s; this.statusBar?.update(s); @@ -143,8 +134,7 @@ export default class SurfSensePlugin extends Plugin { id: "open-settings", name: "Open settings", callback: () => { - // Obsidian exposes this through the Setting host on the workspace; - // fall back silently if the API moves so we never throw. + // `app.setting` isn't in the d.ts; fall back silently if it moves. type SettingHost = { open?: () => void; openTabById?: (id: string) => void; @@ -155,8 +145,7 @@ export default class SurfSensePlugin extends Plugin { }, }); - // Kick off the start sequence after Obsidian finishes its own - // startup work, so the metadataCache is warm before reconcile. + // Wait for layout so the metadataCache is warm before reconcile. this.app.workspace.onLayoutReady(() => { void this.engine.start(); }); @@ -188,13 +177,28 @@ export default class SurfSensePlugin extends Plugin { await this.saveData(this.settings); } + /** + * Mint vault_id (in data.json, travels with the vault) and device_id + * (in `app.saveLocalStorage`, stays per-install) on first run. + */ private seedIdentity(): void { if (!this.settings.vaultId) { this.settings.vaultId = generateUuid(); } - if (!this.settings.deviceId) { - this.settings.deviceId = generateUuid(); + + // loadLocalStorage / saveLocalStorage aren't in the d.ts; cast at the boundary. + const localStore = this.app as unknown as { + loadLocalStorage: (key: string) => string | null; + saveLocalStorage: (key: string, value: string | null) => void; + }; + const storageKey = "surfsense:deviceId"; + let deviceId = localStore.loadLocalStorage(storageKey); + if (!deviceId) { + deviceId = generateUuid(); + localStore.saveLocalStorage(storageKey, deviceId); } + this.deviceId = deviceId; + if (!this.settings.vaultName) { this.settings.vaultName = this.app.vault.getName(); } diff --git a/surfsense_obsidian/src/settings.ts b/surfsense_obsidian/src/settings.ts index d22b66384..224959f95 100644 --- a/surfsense_obsidian/src/settings.ts +++ b/surfsense_obsidian/src/settings.ts @@ -9,22 +9,7 @@ import { parseExcludePatterns } from "./excludes"; import type SurfSensePlugin from "./main"; import type { SearchSpace } from "./types"; -/** - * Plugin settings tab. - * - * Replaces the obsidian-sample-plugin SampleSettingTab stub. Same module - * path so existing imports from main.ts keep resolving. - * - * Surface mirrors the per-plan list: - * server URL · api token · search space · vault name · sync mode · - * exclude patterns · include attachments · status panel. - * - * Vault id, device id, and device label are auto-generated UUIDs the - * first time settings load — they're displayed (read-only) so users can - * audit them, but never editable. Vault id is decoupled from the OS - * folder name so renaming the vault doesn't invalidate the connector - * (edge case #5 from the plan). - */ +/** Plugin settings tab. */ export class SurfSenseSettingTab extends PluginSettingTab { private readonly plugin: SurfSensePlugin; @@ -151,21 +136,6 @@ export class SurfSenseSettingTab extends PluginSettingTab { }), ); - new Setting(containerEl) - .setName("Device label") - .setDesc( - "Optional human-readable label shown next to the device ID in the Surfsense web app.", - ) - .addText((text) => - text - .setPlaceholder("My laptop") - .setValue(settings.deviceLabel) - .onChange(async (value) => { - this.plugin.settings.deviceLabel = value.trim(); - await this.plugin.saveSettings(); - }), - ); - new Setting(containerEl) .setName("Sync mode") .setDesc("Auto syncs on every edit. Manual only syncs when you trigger it via the command palette.") @@ -214,19 +184,16 @@ export class SurfSenseSettingTab extends PluginSettingTab { new Setting(containerEl) .setName("Vault ID") - .setDesc("Stable identifier for this vault. Used by the backend to keep separate vaults distinct even if their folder names change.") + .setDesc( + "Stable identifier for this vault. Used by the backend to keep separate vaults distinct even if their folder names change.", + ) .addText((text) => { text.inputEl.disabled = true; text.setValue(settings.vaultId); }); - new Setting(containerEl) - .setName("Device ID") - .setDesc("Stable identifier for this install. Used by the backend so you can revoke a single device without disconnecting the others.") - .addText((text) => { - text.inputEl.disabled = true; - text.setValue(settings.deviceId); - }); + // Device ID is deliberately not exposed: it's an opaque per-install UUID + // (see seedIdentity in main.ts) and the web UI only shows a device count. new Setting(containerEl).setName("Status").setHeading(); this.statusEl = containerEl.createDiv({ cls: "surfsense-settings__status" }); diff --git a/surfsense_obsidian/src/sync-engine.ts b/surfsense_obsidian/src/sync-engine.ts index ce22b69c1..b2c1b0a5a 100644 --- a/surfsense_obsidian/src/sync-engine.ts +++ b/surfsense_obsidian/src/sync-engine.ts @@ -19,20 +19,8 @@ import type { /** * Owner of "what does the vault look like vs the server" reasoning. * - * Onload sequence (per plan §p4_plugin_sync_engine, in this exact order): - * 1. apiClient.health() — proves connectivity and pulls the capabilities - * handshake before we issue any sync traffic. - * 2. Cache health.capabilities + api_version on the plugin instance - * so feature gating (e.g. "attachments_v2" before syncing binaries) - * reads from local state instead of round-tripping. - * 3. Drain queue — items persisted from the previous session land first. - * 4. Reconcile — GET /manifest, diff against vault, queue uploads/deletes. - * 5. Subscribe events — only after the above so the user's first edit - * after launching Obsidian doesn't race with the manifest diff. - * - * Reconcile skips itself if last successful reconcile is < RECONCILE_MIN_INTERVAL_MS - * ago. ConnectResponse already carries handshake fields so first connect - * does not need a separate /health round-trip. + * Start order: connect (or fall back to /health) → drain queue → reconcile → + * subscribe events. Reconcile no-ops if last run was < RECONCILE_MIN_INTERVAL_MS ago. */ export interface SyncEngineDeps { @@ -41,6 +29,8 @@ export interface SyncEngineDeps { queue: PersistentQueue; getSettings: () => SyncEngineSettings; saveSettings: (mut: (s: SyncEngineSettings) => void) => Promise; + /** Per-install id sourced from app.saveLocalStorage (not synced data.json). */ + getDeviceId: () => string; setStatus: (s: StatusState) => void; onCapabilities: (caps: string[], apiVersion: string) => void; } @@ -50,8 +40,6 @@ export interface SyncEngineSettings { vaultName: string; connectorId: number | null; searchSpaceId: number | null; - deviceId: string; - deviceLabel: string; excludePatterns: string[]; includeAttachments: boolean; syncMode: "auto" | "manual"; @@ -86,22 +74,27 @@ export class SyncEngine { /** Run the onload sequence described in this file's docstring. */ async start(): Promise { this.setStatus("syncing", "Connecting to SurfSense…"); - try { - const health = await this.deps.apiClient.health(); - this.applyHealth(health); - } catch (err) { - this.handleStartupError(err); - return; - } const settings = this.deps.getSettings(); - if (!settings.connectorId || !settings.searchSpaceId) { - // No connector yet — settings tab will trigger ensureConnect once - // the user picks a search space, then re-call start(). + if (!settings.searchSpaceId) { + // No target yet — bare /health probe still surfaces auth/network errors. + try { + const health = await this.deps.apiClient.health(); + this.applyHealth(health); + } catch (err) { + this.handleStartupError(err); + return; + } this.setStatus("idle", "Pick a search space in settings to start syncing."); return; } + // Re-announce on every load: /connect doubles as the device heartbeat + // that bumps last_seen_at and powers the "Devices: N" tile in the web UI. + await this.ensureConnected(); + + if (!this.deps.getSettings().connectorId) return; + await this.flushQueue(); await this.maybeReconcile(); this.setStatus(this.queueStatusKind(), undefined); @@ -119,8 +112,7 @@ export class SyncEngine { searchSpaceId: settings.searchSpaceId, vaultId: settings.vaultId, vaultName: settings.vaultName, - deviceId: settings.deviceId, - deviceLabel: settings.deviceLabel, + deviceId: this.deps.getDeviceId(), }); this.applyHealth(resp); await this.deps.saveSettings((s) => { diff --git a/surfsense_obsidian/src/types.ts b/surfsense_obsidian/src/types.ts index 8b353c2f4..33b0d01a7 100644 --- a/surfsense_obsidian/src/types.ts +++ b/surfsense_obsidian/src/types.ts @@ -1,20 +1,15 @@ -/** - * Shared types for the SurfSense Obsidian plugin. - * - * Kept in a leaf module with no other src/ imports so it can be imported - * from anywhere (settings, api-client, sync-engine, status-bar, main) - * without creating cycles. - */ +/** Shared types for the SurfSense Obsidian plugin. Leaf module — no src/ imports. */ export interface SurfsensePluginSettings { serverUrl: string; apiToken: string; searchSpaceId: number | null; connectorId: number | null; + /** UUID for the vault — lives here so Obsidian Sync replicates it across devices. */ vaultId: string; vaultName: string; - deviceId: string; - deviceLabel: string; + // Per-install deviceId is NOT in this interface on purpose: it lives in + // app.saveLocalStorage so it stays distinct on each device. See seedIdentity(). syncMode: "auto" | "manual"; excludePatterns: string[]; includeAttachments: boolean; @@ -32,8 +27,6 @@ export const DEFAULT_SETTINGS: SurfsensePluginSettings = { connectorId: null, vaultId: "", vaultName: "", - deviceId: "", - deviceLabel: "", syncMode: "auto", excludePatterns: [".trash", "_attachments", "templates"], includeAttachments: false, diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx index acea1c51b..feca9c35e 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx @@ -1,19 +1,11 @@ "use client"; -import { AlertTriangle, Check, Copy, Download, Info } from "lucide-react"; -import { type FC, useCallback, useMemo, useRef, useState } from "react"; +import { AlertTriangle, Download, Info } from "lucide-react"; +import { type FC, useMemo } from "react"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { useApiKey } from "@/hooks/use-api-key"; -import { copyToClipboard as copyToClipboardUtil } from "@/lib/utils"; import type { ConnectorConfigProps } from "../index"; -export interface ObsidianConfigProps extends ConnectorConfigProps { - onNameChange?: (name: string) => void; -} - const PLUGIN_RELEASES_URL = "https://github.com/MODSetter/SurfSense/releases?q=obsidian&expanded=true"; @@ -27,55 +19,32 @@ function formatTimestamp(value: unknown): string { /** * Obsidian connector config view. * - * Renders one of two modes depending on the connector's `config`: + * Read-only on purpose: the plugin owns vault identity, so the connector's + * display name is auto-derived from `payload.vault_name` server-side on + * every `/connect` (see `obsidian_plugin_routes.obsidian_connect`). The + * web UI doesn't expose a Name input or a Save button for Obsidian (the + * latter is suppressed in `connector-edit-view.tsx`). + * + * Renders one of three modes depending on the connector's `config`: * * 1. **Plugin connector** (`config.source === "plugin"`) — read-only stats * panel showing what the plugin most recently reported. * 2. **Legacy server-path connector** (`config.legacy === true`, set by the - * Phase 3 alembic) — migration banner plus an "Install Plugin" CTA. - * The user's existing notes stay searchable; only background sync stops. + * Phase 3 alembic) — migration banner, an "Install Plugin" CTA, and a + * short "how to migrate" checklist that ends with the user pressing the + * standard Disconnect button (which deletes this connector along with + * every document it previously indexed). + * 3. **Unknown** — fallback for rows that escaped the alembic; suggests a + * clean re-install. */ -export const ObsidianConfig: FC = ({ - connector, - onNameChange, -}) => { - const [name, setName] = useState(connector.name || ""); +export const ObsidianConfig: FC = ({ connector }) => { const config = (connector.config ?? {}) as Record; const isLegacy = config.legacy === true; const isPlugin = config.source === "plugin"; - const handleNameChange = (value: string) => { - setName(value); - onNameChange?.(value); - }; - - return ( -
- {/* Connector name (always editable) */} -
-
- - handleNameChange(e.target.value)} - placeholder="My Obsidian Vault" - className="border-slate-400/20 focus-visible:border-slate-400/40" - /> -

- A friendly name to identify this connector. -

-
-
- - {isLegacy ? ( - - ) : isPlugin ? ( - - ) : ( - - )} -
- ); + if (isLegacy) return ; + if (isPlugin) return ; + return ; }; const LegacyBanner: FC = () => { @@ -84,14 +53,12 @@ const LegacyBanner: FC = () => { - This connector has been migrated + Sync stopped — install the plugin to migrate - This Obsidian connector used the legacy server-path method, which has - been removed. To resume syncing, install the SurfSense Obsidian - plugin and connect with this account. Your existing notes remain - searchable. After the plugin re-indexes your vault, you can delete - this connector to remove older copies. + This Obsidian connector used the legacy server-path scanner, which has been removed. The + notes already indexed remain searchable, but they no longer reflect changes made in your + vault. @@ -107,7 +74,25 @@ const LegacyBanner: FC = () => { - +
+

How to migrate

+
    +
  1. Install the SurfSense Obsidian plugin using the button above.
  2. +
  3. + In Obsidian, open Settings → SurfSense, sign in, pick a search space, and wait for the + first sync to finish. +
  4. +
  5. + Confirm the new "Obsidian — <vault>" connector shows your notes, then return here + and use the Disconnect button below to remove this legacy connector. +
  6. +
+

+ Heads up: Disconnect also deletes every document this connector previously indexed. Make + sure the plugin has finished its first sync before you disconnect, otherwise your Obsidian + notes will disappear from search until the plugin re-indexes them. +

+
); }; @@ -115,6 +100,14 @@ const LegacyBanner: FC = () => { const PluginStats: FC<{ config: Record }> = ({ config }) => { const stats: { label: string; value: string }[] = useMemo(() => { const filesSynced = config.files_synced; + // Prefer the stamped count; fall back to len(devices) for rows the + // backend hasn't re-stamped yet. + const deviceCount = + typeof config.device_count === "number" + ? config.device_count + : config.devices && typeof config.devices === "object" + ? Object.keys(config.devices as Record).length + : null; return [ { label: "Vault", value: (config.vault_name as string) || "—" }, { @@ -122,11 +115,8 @@ const PluginStats: FC<{ config: Record }> = ({ config }) => { value: (config.plugin_version as string) || "—", }, { - label: "Device", - value: - (config.device_label as string) || - (config.device_id as string) || - "—", + label: "Devices", + value: deviceCount !== null ? deviceCount.toLocaleString() : "—", }, { label: "Last sync", @@ -134,8 +124,7 @@ const PluginStats: FC<{ config: Record }> = ({ config }) => { }, { label: "Files synced", - value: - typeof filesSynced === "number" ? filesSynced.toLocaleString() : "—", + value: typeof filesSynced === "number" ? filesSynced.toLocaleString() : "—", }, ]; }, [config]); @@ -146,8 +135,8 @@ const PluginStats: FC<{ config: Record }> = ({ config }) => { Plugin connected - Edits in Obsidian sync over HTTPS. To stop syncing, disable or - uninstall the plugin in Obsidian, or delete this connector. + Edits in Obsidian sync over HTTPS. To stop syncing, disable or uninstall the plugin in + Obsidian, or delete this connector. @@ -162,9 +151,7 @@ const PluginStats: FC<{ config: Record }> = ({ config }) => {
{stat.label}
-
- {stat.value} -
+
{stat.value}
))} @@ -178,98 +165,8 @@ const UnknownConnectorState: FC = () => ( Unrecognized config - This connector has neither plugin metadata nor a legacy marker. It may - predate the migration — you can safely delete it and re-install the - SurfSense Obsidian plugin to resume syncing. + This connector has neither plugin metadata nor a legacy marker. It may predate the migration — + you can safely delete it and re-install the SurfSense Obsidian plugin to resume syncing. ); - -const ApiKeyReminder: FC = () => { - const { apiKey, isLoading, copied, copyToClipboard } = useApiKey(); - const [copiedUrl, setCopiedUrl] = useState(false); - const urlCopyTimerRef = useRef | undefined>( - undefined - ); - - const backendUrl = - process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL ?? "https://api.surfsense.com"; - - const copyServerUrl = useCallback(async () => { - const ok = await copyToClipboardUtil(backendUrl); - if (!ok) return; - setCopiedUrl(true); - if (urlCopyTimerRef.current) clearTimeout(urlCopyTimerRef.current); - urlCopyTimerRef.current = setTimeout(() => setCopiedUrl(false), 2000); - }, [backendUrl]); - - return ( -
-

- Plugin connection details -

-

- Paste these into the plugin's settings inside Obsidian. -

- -
- - {isLoading ? ( -
- ) : ( -
-
-

- {apiKey || "No API key available"} -

-
- -
- )} -

- Token expires after 24 hours; long-lived tokens are coming in a - future release. -

-
- -
- -
-
-

- {backendUrl} -

-
- -
-
-
- ); -}; diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx index e19600ab2..256e9a4e7 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx @@ -87,6 +87,10 @@ export const ConnectorEditView: FC = ({ const isAuthExpired = connector.config?.auth_expired === true; const reauthEndpoint = REAUTH_ENDPOINTS[connector.connector_type]; const [reauthing, setReauthing] = useState(false); + // Obsidian is plugin-driven: name + config are owned by the plugin, so + // the web edit view has nothing the user can persist back. Hide Save + // (and re-auth, which Obsidian never uses) entirely for that type. + const isPluginManagedReadOnly = connector.connector_type === EnumConnectorName.OBSIDIAN_CONNECTOR; const handleReauth = useCallback(async () => { const spaceId = searchSpaceId ?? searchSpaceIdAtom; @@ -412,7 +416,7 @@ export const ConnectorEditView: FC = ({ Disconnect )} - {isAuthExpired && reauthEndpoint ? ( + {isPluginManagedReadOnly ? null : isAuthExpired && reauthEndpoint ? ( - - -
-

How to migrate

-
    -
  1. Install the SurfSense Obsidian plugin using the button above.
  2. -
  3. - In Obsidian, open Settings → SurfSense, sign in, pick a search space, and wait for the - first sync to finish. -
  4. -
  5. - Confirm the new "Obsidian — <vault>" connector shows your notes, then return here - and use the Disconnect button below to remove this legacy connector. -
  6. -
-

- Heads up: Disconnect also deletes every document this connector previously indexed. Make - sure the plugin has finished its first sync before you disconnect, otherwise your Obsidian - notes will disappear from search until the plugin re-indexes them. -

-
-
- ); -}; - const PluginStats: FC<{ config: Record }> = ({ config }) => { const vaultId = typeof config.vault_id === "string" ? config.vault_id : null; const [stats, setStats] = useState(null); @@ -179,8 +114,8 @@ const UnknownConnectorState: FC = () => ( Unrecognized config - This connector has neither plugin metadata nor a legacy marker. It may predate the migration — - you can safely delete it and re-install the SurfSense Obsidian plugin to resume syncing. + This connector is missing plugin metadata. Delete it, then reconnect your vault from the + SurfSense Obsidian plugin so sync can resume. ); diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-connect-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-connect-view.tsx index 8a0ef5ae1..e58542923 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-connect-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-connect-view.tsx @@ -111,7 +111,9 @@ export const ConnectorConnectView: FC = ({ : getConnectorTypeDisplay(connectorType)}

- Enter your connection details + {connectorType === "OBSIDIAN_CONNECTOR" + ? "Follow the plugin setup steps below" + : "Enter your connection details"}

diff --git a/surfsense_web/content/docs/connectors/obsidian.mdx b/surfsense_web/content/docs/connectors/obsidian.mdx index c8475c97f..c4d50cf34 100644 --- a/surfsense_web/content/docs/connectors/obsidian.mdx +++ b/surfsense_web/content/docs/connectors/obsidian.mdx @@ -1,143 +1,60 @@ --- title: Obsidian -description: Connect your Obsidian vault to SurfSense +description: Sync your Obsidian vault with the SurfSense plugin --- -# Obsidian Integration Setup Guide +# Obsidian Plugin Setup Guide -This guide walks you through connecting your Obsidian vault to SurfSense for note search and AI-powered insights. - - - This connector requires direct file system access and only works with self-hosted SurfSense installations. - +SurfSense integrates with Obsidian through the SurfSense Obsidian plugin. +The old server-side vault path scanner is no longer supported. ## How it works -The Obsidian connector scans your local Obsidian vault directory and indexes all Markdown files. It preserves your note structure and extracts metadata from YAML frontmatter. +The plugin runs inside your Obsidian app and pushes note updates to SurfSense over HTTPS. +This works for cloud and self-hosted deployments, including desktop and mobile clients. -- For follow-up indexing runs, the connector uses content hashing to skip unchanged files for faster sync. -- Indexing should be configured to run periodically, so updates should appear in your search results within minutes. - ---- - -## What Gets Indexed +## What gets indexed | Content Type | Description | |--------------|-------------| -| Markdown Files | All `.md` files in your vault | -| Frontmatter | YAML metadata (title, tags, aliases, dates) | -| Wiki Links | Links between notes (`[[note]]`) | -| Inline Tags | Tags throughout your notes (`#tag`) | -| Note Content | Full content with intelligent chunking | +| Markdown files | Note content (`.md`) | +| Frontmatter | YAML metadata like title, tags, aliases, dates | +| Wiki links | Linked notes (`[[note]]`) | +| Tags | Inline and frontmatter tags | +| Vault metadata | Vault and path metadata used for deep links and sync state | - - Binary files and attachments are not indexed by default. Enable "Include Attachments" to index embedded files. - +## Quick start ---- - -## Quick Start (Local Installation) - -1. Navigate to **Connectors** → **Add Connector** → **Obsidian** -2. Enter your vault path: `/Users/yourname/Documents/MyVault` -3. Enter a vault name (e.g., `Personal Notes`) -4. Click **Connect Obsidian** +1. Open **Connectors** in SurfSense and choose **Obsidian**. +2. Click **Open plugin releases** and install the latest SurfSense Obsidian plugin. +3. In Obsidian, open **Settings → SurfSense**. +4. Paste your SurfSense API token from the connector setup panel. +5. Paste your SurfSense backend URL in the plugin's **Server URL** setting. +6. Choose the Search Space in the plugin, then run the first sync. +7. Confirm the connector appears as **Obsidian — ** in SurfSense. - Find your vault path: In Obsidian, right-click any note → "Reveal in Finder" (macOS) or "Show in Explorer" (Windows). + You do not create or configure a vault path in the web UI. The connector row is created automatically when the plugin calls `/api/v1/obsidian/connect`. - -Enable periodic sync to automatically re-index notes when content changes. Available frequencies: Every 5 minutes, 15 minutes, hourly, every 6 hours, daily, or weekly. - +## Self-hosted notes ---- - -## Docker Setup - -For Docker deployments, you need to mount your Obsidian vault as a volume. - -### Step 1: Update docker-compose.yml - -Add your vault as a volume mount to the SurfSense backend service: - -```yaml -services: - surfsense: - # ... other config - volumes: - - /path/to/your/obsidian/vault:/app/obsidian_vaults/my-vault:ro -``` - - - The `:ro` flag mounts the vault as read-only, which is recommended for security. - - -### Step 2: Configure the Connector - -Use the **container path** (not your local path) when setting up the connector: - -| Your Local Path | Container Path (use this) | -|-----------------|---------------------------| -| `/Users/john/Documents/MyVault` | `/app/obsidian_vaults/my-vault` | -| `C:\Users\john\Documents\MyVault` | `/app/obsidian_vaults/my-vault` | - -### Example: Multiple Vaults - -```yaml -volumes: - - /Users/john/Documents/PersonalNotes:/app/obsidian_vaults/personal:ro - - /Users/john/Documents/WorkNotes:/app/obsidian_vaults/work:ro -``` - -Then create separate connectors for each vault using `/app/obsidian_vaults/personal` and `/app/obsidian_vaults/work`. - ---- - -## Connector Configuration - -| Field | Description | Required | -|-------|-------------|----------| -| **Connector Name** | A friendly name to identify this connector | Yes | -| **Vault Path** | Absolute path to your vault (container path for Docker) | Yes | -| **Vault Name** | Display name for your vault in search results | Yes | -| **Exclude Folders** | Comma-separated folder names to skip | No | -| **Include Attachments** | Index embedded files (images, PDFs) | No | - ---- - -## Recommended Exclusions - -Common folders to exclude from indexing: - -| Folder | Reason | -|--------|--------| -| `.obsidian` | Obsidian config files (always exclude) | -| `.trash` | Obsidian's trash folder | -| `templates` | Template files you don't want searchable | -| `daily-notes` | If you want to exclude daily notes | -| `attachments` | If not using "Include Attachments" | - -Default exclusions: `.obsidian,.trash` - ---- +- Use your public or LAN backend URL that your Obsidian device can reach. +- No Docker bind mount for the vault is required. +- If your instance is behind TLS, ensure the URL/certificate is valid for the device running Obsidian. ## Troubleshooting -**Vault not found / Permission denied** -- Verify the path exists and is accessible -- For Docker: ensure the volume is mounted correctly in `docker-compose.yml` -- Check file permissions: SurfSense needs read access to the vault directory +**Plugin connects but no files appear** +- Verify the plugin is pointed to the correct Search Space. +- Trigger a manual sync from the plugin settings. +- Confirm your API token is valid and not expired. -**No notes indexed** -- Ensure your vault contains `.md` files -- Check that notes aren't in excluded folders -- Verify the path points to the vault root (contains `.obsidian` folder) +**Unauthorized / 401 errors** +- Regenerate and paste a fresh API token from SurfSense. +- Ensure the token belongs to the same account and workspace you are syncing into. -**Changes not appearing** -- Wait for the next sync cycle, or manually trigger re-indexing -- For Docker: restart the container if you modified volume mounts - -**Docker: "path not found" error** -- Use the container path (`/app/obsidian_vaults/...`), not your local path -- Verify the volume mount in `docker-compose.yml` matches +**Cannot reach server URL** +- Check that the backend URL is reachable from the Obsidian device. +- For self-hosted setups, verify firewall and reverse proxy rules. +- Avoid using localhost unless SurfSense and Obsidian run on the same machine. From d2cb778c08bf6f8dbc81d06b2d422ab8f5f51b44 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 20:48:40 +0200 Subject: [PATCH 042/299] add Gmail search and read email tools --- .../agents/new_chat/tools/gmail/__init__.py | 8 + .../agents/new_chat/tools/gmail/read_email.py | 87 ++++++++++ .../new_chat/tools/gmail/search_emails.py | 148 ++++++++++++++++++ 3 files changed, 243 insertions(+) create mode 100644 surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py create mode 100644 surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/__init__.py b/surfsense_backend/app/agents/new_chat/tools/gmail/__init__.py index efb2fb0fa..294840122 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/__init__.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/__init__.py @@ -1,6 +1,12 @@ from app.agents.new_chat.tools.gmail.create_draft import ( create_create_gmail_draft_tool, ) +from app.agents.new_chat.tools.gmail.read_email import ( + create_read_gmail_email_tool, +) +from app.agents.new_chat.tools.gmail.search_emails import ( + create_search_gmail_tool, +) from app.agents.new_chat.tools.gmail.send_email import ( create_send_gmail_email_tool, ) @@ -13,6 +19,8 @@ from app.agents.new_chat.tools.gmail.update_draft import ( __all__ = [ "create_create_gmail_draft_tool", + "create_read_gmail_email_tool", + "create_search_gmail_tool", "create_send_gmail_email_tool", "create_trash_gmail_email_tool", "create_update_gmail_draft_tool", diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py new file mode 100644 index 000000000..9071f129a --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py @@ -0,0 +1,87 @@ +import logging +from typing import Any + +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.db import SearchSourceConnector, SearchSourceConnectorType + +logger = logging.getLogger(__name__) + +_GMAIL_TYPES = [ + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, +] + + +def create_read_gmail_email_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def read_gmail_email(message_id: str) -> dict[str, Any]: + """Read the full content of a specific Gmail email by its message ID. + + Use after search_gmail to get the complete body of an email. + + Args: + message_id: The Gmail message ID (from search_gmail results). + + Returns: + Dictionary with status and the full email content formatted as markdown. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Gmail tool not properly configured."} + + try: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", + } + + from app.agents.new_chat.tools.gmail.search_emails import _build_credentials + + creds = _build_credentials(connector) + + from app.connectors.google_gmail_connector import GoogleGmailConnector + + gmail = GoogleGmailConnector( + credentials=creds, + session=db_session, + user_id=user_id, + connector_id=connector.id, + ) + + detail, error = await gmail.get_message_details(message_id) + if error: + if "re-authenticate" in error.lower() or "authentication failed" in error.lower(): + return {"status": "auth_error", "message": error, "connector_type": "gmail"} + return {"status": "error", "message": error} + + if not detail: + return {"status": "not_found", "message": f"Email with ID '{message_id}' not found."} + + content = gmail.format_message_to_markdown(detail) + + return {"status": "success", "message_id": message_id, "content": content} + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error reading Gmail email: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to read email. Please try again."} + + return read_gmail_email diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py new file mode 100644 index 000000000..bfc328389 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py @@ -0,0 +1,148 @@ +import logging +from datetime import datetime +from typing import Any + +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.db import SearchSourceConnector, SearchSourceConnectorType + +logger = logging.getLogger(__name__) + +_GMAIL_TYPES = [ + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, +] + + +def _build_credentials(connector: SearchSourceConnector): + """Build Google OAuth Credentials from a Gmail connector's config.""" + if connector.connector_type == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: + from app.utils.google_credentials import build_composio_credentials + + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + raise ValueError("Composio connected account ID not found.") + return build_composio_credentials(cca_id) + + from google.oauth2.credentials import Credentials + + from app.config import config + from app.utils.oauth_security import TokenEncryption + + cfg = dict(connector.config) + if cfg.get("_token_encrypted") and config.SECRET_KEY: + enc = TokenEncryption(config.SECRET_KEY) + for key in ("token", "refresh_token", "client_secret"): + if cfg.get(key): + cfg[key] = enc.decrypt_token(cfg[key]) + + exp = (cfg.get("expiry") or "").replace("Z", "") + return Credentials( + token=cfg.get("token"), + refresh_token=cfg.get("refresh_token"), + token_uri=cfg.get("token_uri"), + client_id=cfg.get("client_id"), + client_secret=cfg.get("client_secret"), + scopes=cfg.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + +def create_search_gmail_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def search_gmail( + query: str, + max_results: int = 10, + ) -> dict[str, Any]: + """Search emails in the user's Gmail inbox using Gmail search syntax. + + Args: + query: Gmail search query, same syntax as the Gmail search bar. + Examples: "from:alice@example.com", "subject:meeting", + "is:unread", "after:2024/01/01 before:2024/02/01", + "has:attachment", "in:sent". + max_results: Number of emails to return (default 10, max 20). + + Returns: + Dictionary with status and a list of email summaries including + message_id, subject, from, date, snippet. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Gmail tool not properly configured."} + + max_results = min(max_results, 20) + + try: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", + } + + creds = _build_credentials(connector) + + from app.connectors.google_gmail_connector import GoogleGmailConnector + + gmail = GoogleGmailConnector( + credentials=creds, + session=db_session, + user_id=user_id, + connector_id=connector.id, + ) + + messages_list, error = await gmail.get_messages_list( + max_results=max_results, query=query + ) + if error: + if "re-authenticate" in error.lower() or "authentication failed" in error.lower(): + return {"status": "auth_error", "message": error, "connector_type": "gmail"} + return {"status": "error", "message": error} + + if not messages_list: + return {"status": "success", "emails": [], "total": 0, "message": "No emails found."} + + emails = [] + for msg in messages_list: + detail, err = await gmail.get_message_details(msg["id"]) + if err: + continue + headers = { + h["name"].lower(): h["value"] + for h in detail.get("payload", {}).get("headers", []) + } + emails.append({ + "message_id": detail.get("id"), + "thread_id": detail.get("threadId"), + "subject": headers.get("subject", "No Subject"), + "from": headers.get("from", "Unknown"), + "to": headers.get("to", ""), + "date": headers.get("date", ""), + "snippet": detail.get("snippet", ""), + "labels": detail.get("labelIds", []), + }) + + return {"status": "success", "emails": emails, "total": len(emails)} + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error searching Gmail: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to search Gmail. Please try again."} + + return search_gmail From 07a5fac15d5f5a10722c9febed527bd2632e3023 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 20:48:47 +0200 Subject: [PATCH 043/299] add Calendar search events tool --- .../tools/google_calendar/__init__.py | 4 + .../tools/google_calendar/search_events.py | 148 ++++++++++++++++++ 2 files changed, 152 insertions(+) create mode 100644 surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/__init__.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/__init__.py index d1ce4e795..13d4c06cb 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/__init__.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/__init__.py @@ -4,6 +4,9 @@ from app.agents.new_chat.tools.google_calendar.create_event import ( from app.agents.new_chat.tools.google_calendar.delete_event import ( create_delete_calendar_event_tool, ) +from app.agents.new_chat.tools.google_calendar.search_events import ( + create_search_calendar_events_tool, +) from app.agents.new_chat.tools.google_calendar.update_event import ( create_update_calendar_event_tool, ) @@ -11,5 +14,6 @@ from app.agents.new_chat.tools.google_calendar.update_event import ( __all__ = [ "create_create_calendar_event_tool", "create_delete_calendar_event_tool", + "create_search_calendar_events_tool", "create_update_calendar_event_tool", ] diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py new file mode 100644 index 000000000..ad66775ef --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py @@ -0,0 +1,148 @@ +import logging +from datetime import datetime +from typing import Any + +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.db import SearchSourceConnector, SearchSourceConnectorType + +logger = logging.getLogger(__name__) + +_CALENDAR_TYPES = [ + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, +] + + +def _build_credentials(connector: SearchSourceConnector): + """Build Google OAuth Credentials from a Calendar connector's config.""" + if connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: + from app.utils.google_credentials import build_composio_credentials + + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + raise ValueError("Composio connected account ID not found.") + return build_composio_credentials(cca_id) + + from google.oauth2.credentials import Credentials + + from app.config import config + from app.utils.oauth_security import TokenEncryption + + cfg = dict(connector.config) + if cfg.get("_token_encrypted") and config.SECRET_KEY: + enc = TokenEncryption(config.SECRET_KEY) + for key in ("token", "refresh_token", "client_secret"): + if cfg.get(key): + cfg[key] = enc.decrypt_token(cfg[key]) + + exp = (cfg.get("expiry") or "").replace("Z", "") + return Credentials( + token=cfg.get("token"), + refresh_token=cfg.get("refresh_token"), + token_uri=cfg.get("token_uri"), + client_id=cfg.get("client_id"), + client_secret=cfg.get("client_secret"), + scopes=cfg.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + +def create_search_calendar_events_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def search_calendar_events( + start_date: str, + end_date: str, + max_results: int = 25, + ) -> dict[str, Any]: + """Search Google Calendar events within a date range. + + Args: + start_date: Start date in YYYY-MM-DD format (e.g. "2026-04-01"). + end_date: End date in YYYY-MM-DD format (e.g. "2026-04-30"). + max_results: Maximum number of events to return (default 25, max 50). + + Returns: + Dictionary with status and a list of events including + event_id, summary, start, end, location, attendees. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Calendar tool not properly configured."} + + max_results = min(max_results, 50) + + try: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.", + } + + creds = _build_credentials(connector) + + from app.connectors.google_calendar_connector import GoogleCalendarConnector + + cal = GoogleCalendarConnector( + credentials=creds, + session=db_session, + user_id=user_id, + connector_id=connector.id, + ) + + events_raw, error = await cal.get_all_primary_calendar_events( + start_date=start_date, + end_date=end_date, + max_results=max_results, + ) + + if error: + if "re-authenticate" in error.lower() or "authentication failed" in error.lower(): + return {"status": "auth_error", "message": error, "connector_type": "google_calendar"} + if "no events found" in error.lower(): + return {"status": "success", "events": [], "total": 0, "message": error} + return {"status": "error", "message": error} + + events = [] + for ev in events_raw: + start = ev.get("start", {}) + end = ev.get("end", {}) + attendees_raw = ev.get("attendees", []) + events.append({ + "event_id": ev.get("id"), + "summary": ev.get("summary", "No Title"), + "start": start.get("dateTime") or start.get("date", ""), + "end": end.get("dateTime") or end.get("date", ""), + "location": ev.get("location", ""), + "description": ev.get("description", ""), + "html_link": ev.get("htmlLink", ""), + "attendees": [ + a.get("email", "") for a in attendees_raw[:10] + ], + "status": ev.get("status", ""), + }) + + return {"status": "success", "events": events, "total": len(events)} + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error searching calendar events: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to search calendar events. Please try again."} + + return search_calendar_events From 1de2517eae9b381d6fec4dd8a7ffa21f3de7ce18 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 20:49:02 +0200 Subject: [PATCH 044/299] add Discord list channels, read messages, send message tools --- .../agents/new_chat/tools/discord/__init__.py | 15 +++ .../agents/new_chat/tools/discord/_auth.py | 46 +++++++++ .../new_chat/tools/discord/list_channels.py | 67 +++++++++++++ .../new_chat/tools/discord/read_messages.py | 80 ++++++++++++++++ .../new_chat/tools/discord/send_message.py | 96 +++++++++++++++++++ 5 files changed, 304 insertions(+) create mode 100644 surfsense_backend/app/agents/new_chat/tools/discord/__init__.py create mode 100644 surfsense_backend/app/agents/new_chat/tools/discord/_auth.py create mode 100644 surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py create mode 100644 surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py create mode 100644 surfsense_backend/app/agents/new_chat/tools/discord/send_message.py diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/__init__.py b/surfsense_backend/app/agents/new_chat/tools/discord/__init__.py new file mode 100644 index 000000000..b4eaec1f0 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/discord/__init__.py @@ -0,0 +1,15 @@ +from app.agents.new_chat.tools.discord.list_channels import ( + create_list_discord_channels_tool, +) +from app.agents.new_chat.tools.discord.read_messages import ( + create_read_discord_messages_tool, +) +from app.agents.new_chat.tools.discord.send_message import ( + create_send_discord_message_tool, +) + +__all__ = [ + "create_list_discord_channels_tool", + "create_read_discord_messages_tool", + "create_send_discord_message_tool", +] diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py b/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py new file mode 100644 index 000000000..b369c10f1 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py @@ -0,0 +1,46 @@ +"""Shared auth helper for Discord agent tools (REST API, not gateway bot).""" + +import logging + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.utils.oauth_security import TokenEncryption + +logger = logging.getLogger(__name__) + +DISCORD_API = "https://discord.com/api/v10" + + +async def get_discord_connector( + db_session: AsyncSession, + search_space_id: int, + user_id: str, +) -> SearchSourceConnector | None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR, + ) + ) + return result.scalars().first() + + +def get_bot_token(connector: SearchSourceConnector) -> str: + """Extract and decrypt the bot token from connector config.""" + cfg = dict(connector.config) + if cfg.get("_token_encrypted") and config.SECRET_KEY: + enc = TokenEncryption(config.SECRET_KEY) + if cfg.get("bot_token"): + cfg["bot_token"] = enc.decrypt_token(cfg["bot_token"]) + token = cfg.get("bot_token") + if not token: + raise ValueError("Discord bot token not found in connector config.") + return token + + +def get_guild_id(connector: SearchSourceConnector) -> str | None: + return connector.config.get("guild_id") diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py new file mode 100644 index 000000000..a33b88aa0 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py @@ -0,0 +1,67 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id + +logger = logging.getLogger(__name__) + + +def create_list_discord_channels_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def list_discord_channels() -> dict[str, Any]: + """List text channels in the connected Discord server. + + Returns: + Dictionary with status and a list of channels (id, name). + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Discord tool not properly configured."} + + try: + connector = await get_discord_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Discord connector found."} + + guild_id = get_guild_id(connector) + if not guild_id: + return {"status": "error", "message": "No guild ID in Discord connector config."} + + token = get_bot_token(connector) + + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{DISCORD_API}/guilds/{guild_id}/channels", + headers={"Authorization": f"Bot {token}"}, + timeout=15.0, + ) + + if resp.status_code == 401: + return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"} + if resp.status_code != 200: + return {"status": "error", "message": f"Discord API error: {resp.status_code}"} + + # Type 0 = text channel + channels = [ + {"id": ch["id"], "name": ch["name"]} + for ch in resp.json() + if ch.get("type") == 0 + ] + return {"status": "success", "guild_id": guild_id, "channels": channels, "total": len(channels)} + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error listing Discord channels: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to list Discord channels."} + + return list_discord_channels diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py new file mode 100644 index 000000000..852a9297b --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py @@ -0,0 +1,80 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from ._auth import DISCORD_API, get_bot_token, get_discord_connector + +logger = logging.getLogger(__name__) + + +def create_read_discord_messages_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def read_discord_messages( + channel_id: str, + limit: int = 25, + ) -> dict[str, Any]: + """Read recent messages from a Discord text channel. + + Args: + channel_id: The Discord channel ID (from list_discord_channels). + limit: Number of messages to fetch (default 25, max 50). + + Returns: + Dictionary with status and a list of messages including + id, author, content, timestamp. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Discord tool not properly configured."} + + limit = min(limit, 50) + + try: + connector = await get_discord_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Discord connector found."} + + token = get_bot_token(connector) + + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{DISCORD_API}/channels/{channel_id}/messages", + headers={"Authorization": f"Bot {token}"}, + params={"limit": limit}, + timeout=15.0, + ) + + if resp.status_code == 401: + return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"} + if resp.status_code == 403: + return {"status": "error", "message": "Bot lacks permission to read this channel."} + if resp.status_code != 200: + return {"status": "error", "message": f"Discord API error: {resp.status_code}"} + + messages = [ + { + "id": m["id"], + "author": m.get("author", {}).get("username", "Unknown"), + "content": m.get("content", ""), + "timestamp": m.get("timestamp", ""), + } + for m in resp.json() + ] + + return {"status": "success", "channel_id": channel_id, "messages": messages, "total": len(messages)} + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error reading Discord messages: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to read Discord messages."} + + return read_discord_messages diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py new file mode 100644 index 000000000..be4e6fdb2 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py @@ -0,0 +1,96 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.tools.hitl import request_approval + +from ._auth import DISCORD_API, get_bot_token, get_discord_connector + +logger = logging.getLogger(__name__) + + +def create_send_discord_message_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def send_discord_message( + channel_id: str, + content: str, + ) -> dict[str, Any]: + """Send a message to a Discord text channel. + + Args: + channel_id: The Discord channel ID (from list_discord_channels). + content: The message text (max 2000 characters). + + Returns: + Dictionary with status, message_id on success. + + IMPORTANT: + - If status is "rejected", the user explicitly declined. Do NOT retry. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Discord tool not properly configured."} + + if len(content) > 2000: + return {"status": "error", "message": "Message exceeds Discord's 2000-character limit."} + + try: + connector = await get_discord_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Discord connector found."} + + result = request_approval( + action_type="discord_send_message", + tool_name="send_discord_message", + params={"channel_id": channel_id, "content": content}, + context={"connector_id": connector.id}, + ) + + if result.rejected: + return {"status": "rejected", "message": "User declined. Message was not sent."} + + final_content = result.params.get("content", content) + final_channel = result.params.get("channel_id", channel_id) + + token = get_bot_token(connector) + + async with httpx.AsyncClient() as client: + resp = await client.post( + f"{DISCORD_API}/channels/{final_channel}/messages", + headers={ + "Authorization": f"Bot {token}", + "Content-Type": "application/json", + }, + json={"content": final_content}, + timeout=15.0, + ) + + if resp.status_code == 401: + return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"} + if resp.status_code == 403: + return {"status": "error", "message": "Bot lacks permission to send messages in this channel."} + if resp.status_code not in (200, 201): + return {"status": "error", "message": f"Discord API error: {resp.status_code}"} + + msg_data = resp.json() + return { + "status": "success", + "message_id": msg_data.get("id"), + "message": f"Message sent to channel {final_channel}.", + } + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error sending Discord message: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to send Discord message."} + + return send_discord_message From 49f8d1abd449d4eb24a5db4e9de93ec850fefa32 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 20:49:50 +0200 Subject: [PATCH 045/299] add Teams list channels, read messages, send message tools --- .../agents/new_chat/tools/teams/__init__.py | 15 +++ .../app/agents/new_chat/tools/teams/_auth.py | 43 ++++++++ .../new_chat/tools/teams/list_channels.py | 77 +++++++++++++ .../new_chat/tools/teams/read_messages.py | 91 ++++++++++++++++ .../new_chat/tools/teams/send_message.py | 101 ++++++++++++++++++ .../app/routes/teams_add_connector_route.py | 1 + 6 files changed, 328 insertions(+) create mode 100644 surfsense_backend/app/agents/new_chat/tools/teams/__init__.py create mode 100644 surfsense_backend/app/agents/new_chat/tools/teams/_auth.py create mode 100644 surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py create mode 100644 surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py create mode 100644 surfsense_backend/app/agents/new_chat/tools/teams/send_message.py diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/__init__.py b/surfsense_backend/app/agents/new_chat/tools/teams/__init__.py new file mode 100644 index 000000000..60e2add49 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/teams/__init__.py @@ -0,0 +1,15 @@ +from app.agents.new_chat.tools.teams.list_channels import ( + create_list_teams_channels_tool, +) +from app.agents.new_chat.tools.teams.read_messages import ( + create_read_teams_messages_tool, +) +from app.agents.new_chat.tools.teams.send_message import ( + create_send_teams_message_tool, +) + +__all__ = [ + "create_list_teams_channels_tool", + "create_read_teams_messages_tool", + "create_send_teams_message_tool", +] diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py b/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py new file mode 100644 index 000000000..989fce7c6 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py @@ -0,0 +1,43 @@ +"""Shared auth helper for Teams agent tools (Microsoft Graph REST API).""" + +import logging + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.utils.oauth_security import TokenEncryption + +logger = logging.getLogger(__name__) + +GRAPH_API = "https://graph.microsoft.com/v1.0" + + +async def get_teams_connector( + db_session: AsyncSession, + search_space_id: int, + user_id: str, +) -> SearchSourceConnector | None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR, + ) + ) + return result.scalars().first() + + +async def get_access_token( + db_session: AsyncSession, + connector: SearchSourceConnector, +) -> str: + """Get a valid Microsoft Graph access token, refreshing if expired.""" + from app.connectors.teams_connector import TeamsConnector + + tc = TeamsConnector( + session=db_session, + connector_id=connector.id, + ) + return await tc._get_valid_token() diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py new file mode 100644 index 000000000..a676595c1 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py @@ -0,0 +1,77 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from ._auth import GRAPH_API, get_access_token, get_teams_connector + +logger = logging.getLogger(__name__) + + +def create_list_teams_channels_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def list_teams_channels() -> dict[str, Any]: + """List all Microsoft Teams and their channels the user has access to. + + Returns: + Dictionary with status and a list of teams, each containing + team_id, team_name, and a list of channels (id, name). + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Teams tool not properly configured."} + + try: + connector = await get_teams_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Teams connector found."} + + token = await get_access_token(db_session, connector) + headers = {"Authorization": f"Bearer {token}"} + + async with httpx.AsyncClient(timeout=20.0) as client: + teams_resp = await client.get(f"{GRAPH_API}/me/joinedTeams", headers=headers) + + if teams_resp.status_code == 401: + return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"} + if teams_resp.status_code != 200: + return {"status": "error", "message": f"Graph API error: {teams_resp.status_code}"} + + teams_data = teams_resp.json().get("value", []) + result_teams = [] + + async with httpx.AsyncClient(timeout=20.0) as client: + for team in teams_data: + team_id = team["id"] + ch_resp = await client.get( + f"{GRAPH_API}/teams/{team_id}/channels", + headers=headers, + ) + channels = [] + if ch_resp.status_code == 200: + channels = [ + {"id": ch["id"], "name": ch.get("displayName", "")} + for ch in ch_resp.json().get("value", []) + ] + result_teams.append({ + "team_id": team_id, + "team_name": team.get("displayName", ""), + "channels": channels, + }) + + return {"status": "success", "teams": result_teams, "total_teams": len(result_teams)} + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error listing Teams channels: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to list Teams channels."} + + return list_teams_channels diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py new file mode 100644 index 000000000..90896cb95 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py @@ -0,0 +1,91 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from ._auth import GRAPH_API, get_access_token, get_teams_connector + +logger = logging.getLogger(__name__) + + +def create_read_teams_messages_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def read_teams_messages( + team_id: str, + channel_id: str, + limit: int = 25, + ) -> dict[str, Any]: + """Read recent messages from a Microsoft Teams channel. + + Args: + team_id: The team ID (from list_teams_channels). + channel_id: The channel ID (from list_teams_channels). + limit: Number of messages to fetch (default 25, max 50). + + Returns: + Dictionary with status and a list of messages including + id, sender, content, timestamp. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Teams tool not properly configured."} + + limit = min(limit, 50) + + try: + connector = await get_teams_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Teams connector found."} + + token = await get_access_token(db_session, connector) + + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.get( + f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages", + headers={"Authorization": f"Bearer {token}"}, + params={"$top": limit}, + ) + + if resp.status_code == 401: + return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"} + if resp.status_code == 403: + return {"status": "error", "message": "Insufficient permissions to read this channel."} + if resp.status_code != 200: + return {"status": "error", "message": f"Graph API error: {resp.status_code}"} + + raw_msgs = resp.json().get("value", []) + messages = [] + for m in raw_msgs: + sender = m.get("from", {}) + user_info = sender.get("user", {}) if sender else {} + body = m.get("body", {}) + messages.append({ + "id": m.get("id"), + "sender": user_info.get("displayName", "Unknown"), + "content": body.get("content", ""), + "content_type": body.get("contentType", "text"), + "timestamp": m.get("createdDateTime", ""), + }) + + return { + "status": "success", + "team_id": team_id, + "channel_id": channel_id, + "messages": messages, + "total": len(messages), + } + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error reading Teams messages: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to read Teams messages."} + + return read_teams_messages diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py new file mode 100644 index 000000000..ba3a515d9 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py @@ -0,0 +1,101 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.tools.hitl import request_approval + +from ._auth import GRAPH_API, get_access_token, get_teams_connector + +logger = logging.getLogger(__name__) + + +def create_send_teams_message_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def send_teams_message( + team_id: str, + channel_id: str, + content: str, + ) -> dict[str, Any]: + """Send a message to a Microsoft Teams channel. + + Requires the ChannelMessage.Send OAuth scope. If the user gets a + permission error, they may need to re-authenticate with updated scopes. + + Args: + team_id: The team ID (from list_teams_channels). + channel_id: The channel ID (from list_teams_channels). + content: The message text (HTML supported). + + Returns: + Dictionary with status, message_id on success. + + IMPORTANT: + - If status is "rejected", the user explicitly declined. Do NOT retry. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Teams tool not properly configured."} + + try: + connector = await get_teams_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Teams connector found."} + + result = request_approval( + action_type="teams_send_message", + tool_name="send_teams_message", + params={"team_id": team_id, "channel_id": channel_id, "content": content}, + context={"connector_id": connector.id}, + ) + + if result.rejected: + return {"status": "rejected", "message": "User declined. Message was not sent."} + + final_content = result.params.get("content", content) + final_team = result.params.get("team_id", team_id) + final_channel = result.params.get("channel_id", channel_id) + + token = await get_access_token(db_session, connector) + + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.post( + f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages", + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + json={"body": {"content": final_content}}, + ) + + if resp.status_code == 401: + return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"} + if resp.status_code == 403: + return { + "status": "insufficient_permissions", + "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.", + } + if resp.status_code not in (200, 201): + return {"status": "error", "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}"} + + msg_data = resp.json() + return { + "status": "success", + "message_id": msg_data.get("id"), + "message": f"Message sent to Teams channel.", + } + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error sending Teams message: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to send Teams message."} + + return send_teams_message diff --git a/surfsense_backend/app/routes/teams_add_connector_route.py b/surfsense_backend/app/routes/teams_add_connector_route.py index 4442307ba..bbaae3a5f 100644 --- a/surfsense_backend/app/routes/teams_add_connector_route.py +++ b/surfsense_backend/app/routes/teams_add_connector_route.py @@ -45,6 +45,7 @@ SCOPES = [ "Team.ReadBasic.All", # Read basic team information "Channel.ReadBasic.All", # Read basic channel information "ChannelMessage.Read.All", # Read messages in channels + "ChannelMessage.Send", # Send messages in channels ] # Initialize security utilities From ba8e3133b9281c07ab366039a1eb36c8e231afe8 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 20:50:31 +0200 Subject: [PATCH 046/299] add Luma list events, read event, create event tools --- .../agents/new_chat/tools/luma/__init__.py | 15 +++ .../app/agents/new_chat/tools/luma/_auth.py | 42 +++++++ .../new_chat/tools/luma/create_event.py | 116 ++++++++++++++++++ .../agents/new_chat/tools/luma/list_events.py | 100 +++++++++++++++ .../agents/new_chat/tools/luma/read_event.py | 82 +++++++++++++ 5 files changed, 355 insertions(+) create mode 100644 surfsense_backend/app/agents/new_chat/tools/luma/__init__.py create mode 100644 surfsense_backend/app/agents/new_chat/tools/luma/_auth.py create mode 100644 surfsense_backend/app/agents/new_chat/tools/luma/create_event.py create mode 100644 surfsense_backend/app/agents/new_chat/tools/luma/list_events.py create mode 100644 surfsense_backend/app/agents/new_chat/tools/luma/read_event.py diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/__init__.py b/surfsense_backend/app/agents/new_chat/tools/luma/__init__.py new file mode 100644 index 000000000..255119bee --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/luma/__init__.py @@ -0,0 +1,15 @@ +from app.agents.new_chat.tools.luma.create_event import ( + create_create_luma_event_tool, +) +from app.agents.new_chat.tools.luma.list_events import ( + create_list_luma_events_tool, +) +from app.agents.new_chat.tools.luma.read_event import ( + create_read_luma_event_tool, +) + +__all__ = [ + "create_create_luma_event_tool", + "create_list_luma_events_tool", + "create_read_luma_event_tool", +] diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py b/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py new file mode 100644 index 000000000..ef2fa8540 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py @@ -0,0 +1,42 @@ +"""Shared auth helper for Luma agent tools.""" + +import logging + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.db import SearchSourceConnector, SearchSourceConnectorType + +logger = logging.getLogger(__name__) + +LUMA_API = "https://public-api.luma.com/v1" + + +async def get_luma_connector( + db_session: AsyncSession, + search_space_id: int, + user_id: str, +) -> SearchSourceConnector | None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR, + ) + ) + return result.scalars().first() + + +def get_api_key(connector: SearchSourceConnector) -> str: + """Extract the API key from connector config (handles both key names).""" + key = connector.config.get("api_key") or connector.config.get("LUMA_API_KEY") + if not key: + raise ValueError("Luma API key not found in connector config.") + return key + + +def luma_headers(api_key: str) -> dict[str, str]: + return { + "Content-Type": "application/json", + "x-luma-api-key": api_key, + } diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py new file mode 100644 index 000000000..2217d29e6 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py @@ -0,0 +1,116 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.tools.hitl import request_approval + +from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers + +logger = logging.getLogger(__name__) + + +def create_create_luma_event_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def create_luma_event( + name: str, + start_at: str, + end_at: str, + description: str | None = None, + timezone: str = "UTC", + ) -> dict[str, Any]: + """Create a new event on Luma. + + Args: + name: The event title. + start_at: Start time in ISO 8601 format (e.g. "2026-05-01T18:00:00"). + end_at: End time in ISO 8601 format (e.g. "2026-05-01T20:00:00"). + description: Optional event description (markdown supported). + timezone: Timezone string (default "UTC", e.g. "America/New_York"). + + Returns: + Dictionary with status, event_id on success. + + IMPORTANT: + - If status is "rejected", the user explicitly declined. Do NOT retry. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Luma tool not properly configured."} + + try: + connector = await get_luma_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Luma connector found."} + + result = request_approval( + action_type="luma_create_event", + tool_name="create_luma_event", + params={ + "name": name, + "start_at": start_at, + "end_at": end_at, + "description": description, + "timezone": timezone, + }, + context={"connector_id": connector.id}, + ) + + if result.rejected: + return {"status": "rejected", "message": "User declined. Event was not created."} + + final_name = result.params.get("name", name) + final_start = result.params.get("start_at", start_at) + final_end = result.params.get("end_at", end_at) + final_desc = result.params.get("description", description) + final_tz = result.params.get("timezone", timezone) + + api_key = get_api_key(connector) + headers = luma_headers(api_key) + + body: dict[str, Any] = { + "name": final_name, + "start_at": final_start, + "end_at": final_end, + "timezone": final_tz, + } + if final_desc: + body["description_md"] = final_desc + + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.post( + f"{LUMA_API}/event/create", + headers=headers, + json=body, + ) + + if resp.status_code == 401: + return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"} + if resp.status_code == 403: + return {"status": "error", "message": "Luma Plus subscription required to create events via API."} + if resp.status_code not in (200, 201): + return {"status": "error", "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}"} + + data = resp.json() + event_id = data.get("api_id") or data.get("event", {}).get("api_id") + + return { + "status": "success", + "event_id": event_id, + "message": f"Event '{final_name}' created on Luma.", + } + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error creating Luma event: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to create Luma event."} + + return create_luma_event diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py new file mode 100644 index 000000000..cd4721758 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py @@ -0,0 +1,100 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers + +logger = logging.getLogger(__name__) + + +def create_list_luma_events_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def list_luma_events( + max_results: int = 25, + ) -> dict[str, Any]: + """List upcoming and recent Luma events. + + Args: + max_results: Maximum events to return (default 25, max 50). + + Returns: + Dictionary with status and a list of events including + event_id, name, start_at, end_at, location, url. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Luma tool not properly configured."} + + max_results = min(max_results, 50) + + try: + connector = await get_luma_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Luma connector found."} + + api_key = get_api_key(connector) + headers = luma_headers(api_key) + + all_entries: list[dict] = [] + cursor = None + + async with httpx.AsyncClient(timeout=20.0) as client: + while len(all_entries) < max_results: + params: dict[str, Any] = {"limit": min(100, max_results - len(all_entries))} + if cursor: + params["cursor"] = cursor + + resp = await client.get( + f"{LUMA_API}/calendar/list-events", + headers=headers, + params=params, + ) + + if resp.status_code == 401: + return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"} + if resp.status_code != 200: + return {"status": "error", "message": f"Luma API error: {resp.status_code}"} + + data = resp.json() + entries = data.get("entries", []) + if not entries: + break + all_entries.extend(entries) + + next_cursor = data.get("next_cursor") + if not next_cursor: + break + cursor = next_cursor + + events = [] + for entry in all_entries[:max_results]: + ev = entry.get("event", {}) + geo = ev.get("geo_info", {}) + events.append({ + "event_id": entry.get("api_id"), + "name": ev.get("name", "Untitled"), + "start_at": ev.get("start_at", ""), + "end_at": ev.get("end_at", ""), + "timezone": ev.get("timezone", ""), + "location": geo.get("name", ""), + "url": ev.get("url", ""), + "visibility": ev.get("visibility", ""), + }) + + return {"status": "success", "events": events, "total": len(events)} + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error listing Luma events: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to list Luma events."} + + return list_luma_events diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py new file mode 100644 index 000000000..eb3ac55c6 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py @@ -0,0 +1,82 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers + +logger = logging.getLogger(__name__) + + +def create_read_luma_event_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def read_luma_event(event_id: str) -> dict[str, Any]: + """Read detailed information about a specific Luma event. + + Args: + event_id: The Luma event API ID (from list_luma_events). + + Returns: + Dictionary with status and full event details including + description, attendees count, meeting URL. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Luma tool not properly configured."} + + try: + connector = await get_luma_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Luma connector found."} + + api_key = get_api_key(connector) + headers = luma_headers(api_key) + + async with httpx.AsyncClient(timeout=15.0) as client: + resp = await client.get( + f"{LUMA_API}/events/{event_id}", + headers=headers, + ) + + if resp.status_code == 401: + return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"} + if resp.status_code == 404: + return {"status": "not_found", "message": f"Event '{event_id}' not found."} + if resp.status_code != 200: + return {"status": "error", "message": f"Luma API error: {resp.status_code}"} + + data = resp.json() + ev = data.get("event", data) + geo = ev.get("geo_info", {}) + + event_detail = { + "event_id": event_id, + "name": ev.get("name", ""), + "description": ev.get("description", ""), + "start_at": ev.get("start_at", ""), + "end_at": ev.get("end_at", ""), + "timezone": ev.get("timezone", ""), + "location_name": geo.get("name", ""), + "address": geo.get("address", ""), + "url": ev.get("url", ""), + "meeting_url": ev.get("meeting_url", ""), + "visibility": ev.get("visibility", ""), + "cover_url": ev.get("cover_url", ""), + } + + return {"status": "success", "event": event_detail} + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error reading Luma event: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to read Luma event."} + + return read_luma_event From 575b2c64d7a20f1a4673f7e2866515dca240e138 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 20:50:42 +0200 Subject: [PATCH 047/299] register all new live connector tools in registry --- .../app/agents/new_chat/tools/registry.py | 166 +++++++++++++++++- 1 file changed, 164 insertions(+), 2 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index 6f7a5a03f..f74b4271f 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -50,6 +50,11 @@ from .confluence import ( create_delete_confluence_page_tool, create_update_confluence_page_tool, ) +from .discord import ( + create_list_discord_channels_tool, + create_read_discord_messages_tool, + create_send_discord_message_tool, +) from .dropbox import ( create_create_dropbox_file_tool, create_delete_dropbox_file_tool, @@ -57,6 +62,8 @@ from .dropbox import ( from .generate_image import create_generate_image_tool from .gmail import ( create_create_gmail_draft_tool, + create_read_gmail_email_tool, + create_search_gmail_tool, create_send_gmail_email_tool, create_trash_gmail_email_tool, create_update_gmail_draft_tool, @@ -64,6 +71,7 @@ from .gmail import ( from .google_calendar import ( create_create_calendar_event_tool, create_delete_calendar_event_tool, + create_search_calendar_events_tool, create_update_calendar_event_tool, ) from .google_drive import ( @@ -80,6 +88,11 @@ from .linear import ( create_delete_linear_issue_tool, create_update_linear_issue_tool, ) +from .luma import ( + create_create_luma_event_tool, + create_list_luma_events_tool, + create_read_luma_event_tool, +) from .mcp_tool import load_mcp_tools from .notion import ( create_create_notion_page_tool, @@ -95,6 +108,11 @@ from .report import create_generate_report_tool from .resume import create_generate_resume_tool from .scrape_webpage import create_scrape_webpage_tool from .search_surfsense_docs import create_search_surfsense_docs_tool +from .teams import ( + create_list_teams_channels_tool, + create_read_teams_messages_tool, + create_send_teams_message_tool, +) from .update_memory import create_update_memory_tool, create_update_team_memory_tool from .video_presentation import create_generate_video_presentation_tool from .web_search import create_web_search_tool @@ -403,9 +421,20 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ required_connector="ONEDRIVE_FILE", ), # ========================================================================= - # GOOGLE CALENDAR TOOLS - create, update, delete events + # GOOGLE CALENDAR TOOLS - search, create, update, delete events # Auto-disabled when no Google Calendar connector is configured # ========================================================================= + ToolDefinition( + name="search_calendar_events", + description="Search Google Calendar events within a date range", + factory=lambda deps: create_search_calendar_events_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_CALENDAR_CONNECTOR", + ), ToolDefinition( name="create_calendar_event", description="Create a new event on Google Calendar", @@ -440,9 +469,31 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ required_connector="GOOGLE_CALENDAR_CONNECTOR", ), # ========================================================================= - # GMAIL TOOLS - create drafts, update drafts, send emails, trash emails + # GMAIL TOOLS - search, read, create drafts, update drafts, send, trash # Auto-disabled when no Gmail connector is configured # ========================================================================= + ToolDefinition( + name="search_gmail", + description="Search emails in Gmail using Gmail search syntax", + factory=lambda deps: create_search_gmail_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_GMAIL_CONNECTOR", + ), + ToolDefinition( + name="read_gmail_email", + description="Read the full content of a specific Gmail email", + factory=lambda deps: create_read_gmail_email_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_GMAIL_CONNECTOR", + ), ToolDefinition( name="create_gmail_draft", description="Create a draft email in Gmail", @@ -561,6 +612,117 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ requires=["db_session", "search_space_id", "user_id"], required_connector="CONFLUENCE_CONNECTOR", ), + # ========================================================================= + # DISCORD TOOLS - list channels, read messages, send messages + # Auto-disabled when no Discord connector is configured + # ========================================================================= + ToolDefinition( + name="list_discord_channels", + description="List text channels in the connected Discord server", + factory=lambda deps: create_list_discord_channels_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="DISCORD_CONNECTOR", + ), + ToolDefinition( + name="read_discord_messages", + description="Read recent messages from a Discord text channel", + factory=lambda deps: create_read_discord_messages_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="DISCORD_CONNECTOR", + ), + ToolDefinition( + name="send_discord_message", + description="Send a message to a Discord text channel", + factory=lambda deps: create_send_discord_message_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="DISCORD_CONNECTOR", + ), + # ========================================================================= + # TEAMS TOOLS - list channels, read messages, send messages + # Auto-disabled when no Teams connector is configured + # ========================================================================= + ToolDefinition( + name="list_teams_channels", + description="List Microsoft Teams and their channels", + factory=lambda deps: create_list_teams_channels_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="TEAMS_CONNECTOR", + ), + ToolDefinition( + name="read_teams_messages", + description="Read recent messages from a Microsoft Teams channel", + factory=lambda deps: create_read_teams_messages_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="TEAMS_CONNECTOR", + ), + ToolDefinition( + name="send_teams_message", + description="Send a message to a Microsoft Teams channel", + factory=lambda deps: create_send_teams_message_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="TEAMS_CONNECTOR", + ), + # ========================================================================= + # LUMA TOOLS - list events, read event details, create events + # Auto-disabled when no Luma connector is configured + # ========================================================================= + ToolDefinition( + name="list_luma_events", + description="List upcoming and recent Luma events", + factory=lambda deps: create_list_luma_events_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="LUMA_CONNECTOR", + ), + ToolDefinition( + name="read_luma_event", + description="Read detailed information about a specific Luma event", + factory=lambda deps: create_read_luma_event_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="LUMA_CONNECTOR", + ), + ToolDefinition( + name="create_luma_event", + description="Create a new event on Luma", + factory=lambda deps: create_create_luma_event_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="LUMA_CONNECTOR", + ), ] From 22f8cb2cd31cee6a82aca0c1f2dc258bc4e0f870 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 22 Apr 2026 00:24:26 +0530 Subject: [PATCH 048/299] feat: enhance obsidian connector doc and add notes for migration from legacy obsidian connector --- .../components/obsidian-config.tsx | 56 +++++++++++++++++-- .../content/docs/connectors/index.mdx | 2 +- .../content/docs/connectors/obsidian.mdx | 31 +++++----- 3 files changed, 70 insertions(+), 19 deletions(-) diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx index cfe6f0574..33a7110c0 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx @@ -1,11 +1,13 @@ "use client"; -import { Info } from "lucide-react"; +import { AlertTriangle, Info } from "lucide-react"; import { type FC, useEffect, useMemo, useState } from "react"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { connectorsApiService, type ObsidianStats } from "@/lib/apis/connectors-api.service"; import type { ConnectorConfigProps } from "../index"; +const OBSIDIAN_SETUP_DOCS_URL = "/docs/connectors/obsidian"; + function formatTimestamp(value: unknown): string { if (typeof value !== "string" || !value) return "—"; const d = new Date(value); @@ -22,17 +24,61 @@ function formatTimestamp(value: unknown): string { * web UI doesn't expose a Name input or a Save button for Obsidian (the * latter is suppressed in `connector-edit-view.tsx`). * - * Renders plugin stats when connector metadata comes from the plugin. - * If metadata is missing or malformed, we show a recovery hint. + * Renders one of three modes depending on the connector's `config`: + * + * 1. **Plugin connector** (`config.source === "plugin"`) — read-only stats + * panel showing what the plugin most recently reported. + * 2. **Legacy server-path connector** (`config.legacy === true`, set by the + * migration) — migration warning + docs link + explicit disconnect data-loss + * warning so users move to the plugin flow safely. + * 3. **Unknown** — fallback for rows that escaped migration; suggests a + * clean re-install. */ export const ObsidianConfig: FC = ({ connector }) => { const config = (connector.config ?? {}) as Record; + const isLegacy = config.legacy === true; const isPlugin = config.source === "plugin"; + if (isLegacy) return ; if (isPlugin) return ; return ; }; +const LegacyBanner: FC = () => { + return ( +
+ + + + Sync stopped, install the plugin to migrate + + + This Obsidian connector used the legacy server-path scanner, which has been removed. The + notes already indexed remain searchable, but they no longer reflect changes made in your + vault. + + + +
+

Migration required

+

+ Follow the{" "} + + Obsidian setup guide + {" "} + to reconnect this vault through the plugin. +

+

+ Heads up: Disconnect also deletes every document this connector previously indexed. +

+
+
+ ); +}; + const PluginStats: FC<{ config: Record }> = ({ config }) => { const vaultId = typeof config.vault_id === "string" ? config.vault_id : null; const [stats, setStats] = useState(null); @@ -114,8 +160,8 @@ const UnknownConnectorState: FC = () => ( Unrecognized config - This connector is missing plugin metadata. Delete it, then reconnect your vault from the - SurfSense Obsidian plugin so sync can resume. + This connector has neither plugin metadata nor a legacy marker. It may predate migration — + you can safely delete it and re-install the SurfSense Obsidian plugin to resume syncing.
); diff --git a/surfsense_web/content/docs/connectors/index.mdx b/surfsense_web/content/docs/connectors/index.mdx index e3d06aa3c..ef8d214ef 100644 --- a/surfsense_web/content/docs/connectors/index.mdx +++ b/surfsense_web/content/docs/connectors/index.mdx @@ -105,7 +105,7 @@ Connect SurfSense to your favorite tools and services. Browse the available inte /> ** in SurfSense. +4. Paste your SurfSense API token from the user settings section. +5. Paste your Server URL in the plugin setting: either your SurfSense main domain (if `/api/v1` rewrites are enabled) or your direct backend URL. +6. Choose the Search Space in the plugin, then the first sync should run automatically. +7. Confirm the connector appears as **Obsidian — <vault>** in SurfSense. - - You do not create or configure a vault path in the web UI. The connector row is created automatically when the plugin calls `/api/v1/obsidian/connect`. +## Migrating from the legacy connector + +If you previously used the legacy Obsidian connector architecture, migrate to the plugin flow: + +1. Delete the old legacy Obsidian connector from SurfSense. +2. Install and configure the SurfSense Obsidian plugin using the quick start above. +3. Run the first plugin sync and verify the new **Obsidian — <vault>** connector is active. + + + Deleting the legacy connector also deletes all documents that were indexed by that connector. Always finish and verify plugin sync before deleting the old connector. -## Self-hosted notes - -- Use your public or LAN backend URL that your Obsidian device can reach. -- No Docker bind mount for the vault is required. -- If your instance is behind TLS, ensure the URL/certificate is valid for the device running Obsidian. - ## Troubleshooting **Plugin connects but no files appear** @@ -50,6 +51,10 @@ This works for cloud and self-hosted deployments, including desktop and mobile c - Trigger a manual sync from the plugin settings. - Confirm your API token is valid and not expired. +**Self-hosted URL issues** +- Use a public or LAN backend URL that your Obsidian device can reach. +- If your instance is behind TLS, ensure the URL/certificate is valid for the device running Obsidian. + **Unauthorized / 401 errors** - Regenerate and paste a fresh API token from SurfSense. - Ensure the token belongs to the same account and workspace you are syncing into. From 08489dbd5a5cda9030185940a6148de5cf52ef48 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 22 Apr 2026 00:48:07 +0530 Subject: [PATCH 049/299] chore: update obsidian GitHub Actions workflows to use latest action versions --- .github/workflows/obsidian-plugin-lint.yml | 4 +-- .github/workflows/release-obsidian-plugin.yml | 34 +++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/.github/workflows/obsidian-plugin-lint.yml b/.github/workflows/obsidian-plugin-lint.yml index 237087d39..80a49c3f7 100644 --- a/.github/workflows/obsidian-plugin-lint.yml +++ b/.github/workflows/obsidian-plugin-lint.yml @@ -31,9 +31,9 @@ jobs: node-version: [20.x, 22.x] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - - uses: actions/setup-node@v4 + - uses: actions/setup-node@v6 with: node-version: ${{ matrix.node-version }} cache: npm diff --git a/.github/workflows/release-obsidian-plugin.yml b/.github/workflows/release-obsidian-plugin.yml index c97d45023..198b87611 100644 --- a/.github/workflows/release-obsidian-plugin.yml +++ b/.github/workflows/release-obsidian-plugin.yml @@ -1,9 +1,6 @@ name: Release Obsidian Plugin -# Triggered on tags of the form `obsidian-v0.1.0`. The version after the -# prefix MUST exactly equal `surfsense_obsidian/manifest.json`'s `version` -# (no leading `v`) — this is what BRAT and the Obsidian community plugin -# store both verify. +# Tag format: `obsidian-v` and `` must match `surfsense_obsidian/manifest.json` exactly. on: push: tags: @@ -26,14 +23,14 @@ jobs: working-directory: surfsense_obsidian steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: # Need write access for the manifest/versions.json mirror commit # back to main further down. fetch-depth: 0 token: ${{ secrets.GITHUB_TOKEN }} - - uses: actions/setup-node@v4 + - uses: actions/setup-node@v6 with: node-version: 20.x cache: npm @@ -42,7 +39,15 @@ jobs: - name: Resolve plugin version id: version run: | - tag="${GITHUB_REF_NAME:-${{ github.event.inputs.tag }}}" + if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then + tag="${{ github.event.inputs.tag }}" + else + tag="${GITHUB_REF_NAME}" + fi + if [ -z "$tag" ] || [[ "$tag" != obsidian-v* ]]; then + echo "::error::Invalid tag '$tag'. Expected format: obsidian-v" + exit 1 + fi version="${tag#obsidian-v}" manifest_version=$(node -p "require('./manifest.json').version") if [ "$version" != "$manifest_version" ]; then @@ -79,19 +84,14 @@ jobs: git add manifest.json versions.json git commit -m "chore(obsidian-plugin): mirror manifest+versions for ${{ steps.version.outputs.tag }}" # Push to the default branch so Obsidian can fetch raw files from HEAD. - git push origin HEAD:${{ github.event.repository.default_branch }} + if ! git push origin HEAD:${{ github.event.repository.default_branch }}; then + echo "::warning::Failed to push mirrored manifest/versions to default branch (likely branch protection). Continuing release." + fi - # IMPORTANT: BRAT and the Obsidian community plugin store look up the - # release by the bare manifest `version` (e.g. `0.1.0`), NOT by the - # build-trigger tag (`obsidian-v0.1.0`). So we publish the GitHub - # release with `tag_name: ` — `softprops/action-gh-release` - # will create that tag if it doesn't already exist, pointing at the - # commit referenced by the build-trigger tag. Verified against - # https://github.com/khoj-ai/khoj/releases (their tags are bare - # versions like `2.0.0-beta.28`, no prefix). + # Publish release under bare `manifest.json` version (no `obsidian-v` prefix) for BRAT/store compatibility. - name: Create GitHub release if: github.event_name == 'push' - uses: softprops/action-gh-release@v2 + uses: softprops/action-gh-release@v3 with: tag_name: ${{ steps.version.outputs.version }} name: SurfSense Obsidian Plugin ${{ steps.version.outputs.version }} From 7133655eebd04453366f6f416d417c075d0ca841 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 21:19:08 +0200 Subject: [PATCH 050/299] add MCP service registry for Linear, Jira, ClickUp --- .../app/services/mcp_oauth/__init__.py | 0 .../app/services/mcp_oauth/registry.py | 41 +++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 surfsense_backend/app/services/mcp_oauth/__init__.py create mode 100644 surfsense_backend/app/services/mcp_oauth/registry.py diff --git a/surfsense_backend/app/services/mcp_oauth/__init__.py b/surfsense_backend/app/services/mcp_oauth/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py new file mode 100644 index 000000000..93d5d5448 --- /dev/null +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -0,0 +1,41 @@ +"""Registry of MCP services with OAuth 2.1 support. + +Each entry maps a URL-safe service key to its MCP server endpoint and +authentication strategy. Services with ``supports_dcr=True`` will use +RFC 7591 Dynamic Client Registration; the rest require pre-configured +credentials via environment variables. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class MCPServiceConfig: + name: str + mcp_url: str + supports_dcr: bool = True + client_id_env: str | None = None + client_secret_env: str | None = None + scopes: list[str] = field(default_factory=list) + + +MCP_SERVICES: dict[str, MCPServiceConfig] = { + "linear": MCPServiceConfig( + name="Linear", + mcp_url="https://mcp.linear.app/mcp", + ), + "jira": MCPServiceConfig( + name="Jira", + mcp_url="https://mcp.atlassian.com/v1/mcp", + ), + "clickup": MCPServiceConfig( + name="ClickUp", + mcp_url="https://mcp.clickup.com/mcp", + ), +} + + +def get_service(key: str) -> MCPServiceConfig | None: + return MCP_SERVICES.get(key) From 4efdee5aed65bc41f61f24d37f1502fe4ece5bc4 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 21:19:15 +0200 Subject: [PATCH 051/299] add MCP OAuth discovery, DCR, and token exchange --- .../app/services/mcp_oauth/discovery.py | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 surfsense_backend/app/services/mcp_oauth/discovery.py diff --git a/surfsense_backend/app/services/mcp_oauth/discovery.py b/surfsense_backend/app/services/mcp_oauth/discovery.py new file mode 100644 index 000000000..e8bcd7076 --- /dev/null +++ b/surfsense_backend/app/services/mcp_oauth/discovery.py @@ -0,0 +1,111 @@ +"""MCP OAuth 2.1 metadata discovery, Dynamic Client Registration, and token exchange.""" + +from __future__ import annotations + +import base64 +import logging +from urllib.parse import urlparse + +import httpx + +logger = logging.getLogger(__name__) + + +async def discover_oauth_metadata(mcp_url: str, *, timeout: float = 15.0) -> dict: + """Fetch OAuth 2.1 metadata from the MCP server's well-known endpoint. + + Per the MCP spec the discovery document lives at the *origin* of the + MCP server URL, not at the MCP endpoint path. + """ + parsed = urlparse(mcp_url) + origin = f"{parsed.scheme}://{parsed.netloc}" + discovery_url = f"{origin}/.well-known/oauth-authorization-server" + + async with httpx.AsyncClient(follow_redirects=True) as client: + resp = await client.get(discovery_url, timeout=timeout) + resp.raise_for_status() + return resp.json() + + +async def register_client( + registration_endpoint: str, + redirect_uri: str, + *, + client_name: str = "SurfSense", + timeout: float = 15.0, +) -> dict: + """Perform Dynamic Client Registration (RFC 7591).""" + payload = { + "client_name": client_name, + "redirect_uris": [redirect_uri], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_basic", + } + + async with httpx.AsyncClient(follow_redirects=True) as client: + resp = await client.post( + registration_endpoint, json=payload, timeout=timeout, + ) + resp.raise_for_status() + return resp.json() + + +async def exchange_code_for_tokens( + token_endpoint: str, + code: str, + redirect_uri: str, + client_id: str, + client_secret: str, + code_verifier: str, + *, + timeout: float = 30.0, +) -> dict: + """Exchange an authorization code for access + refresh tokens.""" + creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() + + async with httpx.AsyncClient(follow_redirects=True) as client: + resp = await client.post( + token_endpoint, + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "code_verifier": code_verifier, + }, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": f"Basic {creds}", + }, + timeout=timeout, + ) + resp.raise_for_status() + return resp.json() + + +async def refresh_access_token( + token_endpoint: str, + refresh_token: str, + client_id: str, + client_secret: str, + *, + timeout: float = 30.0, +) -> dict: + """Refresh an expired access token.""" + creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() + + async with httpx.AsyncClient(follow_redirects=True) as client: + resp = await client.post( + token_endpoint, + data={ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + }, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": f"Basic {creds}", + }, + timeout=timeout, + ) + resp.raise_for_status() + return resp.json() From 45867e5c56a81bf0c308eae56bc8b274d87a980a Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 21:19:32 +0200 Subject: [PATCH 052/299] add generic MCP OAuth route with DCR + PKCE --- .../app/routes/mcp_oauth_route.py | 508 ++++++++++++++++++ 1 file changed, 508 insertions(+) create mode 100644 surfsense_backend/app/routes/mcp_oauth_route.py diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py new file mode 100644 index 000000000..689914ee8 --- /dev/null +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -0,0 +1,508 @@ +"""Generic MCP OAuth 2.1 route for services with official MCP servers. + +Handles the full flow: discovery → DCR → PKCE authorization → token exchange +→ MCP_CONNECTOR creation. Currently supports Linear, Jira, and ClickUp. +""" + +from __future__ import annotations + +import logging +from datetime import UTC, datetime, timedelta +from urllib.parse import urlencode +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import RedirectResponse +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm.attributes import flag_modified + +from app.config import config +from app.db import ( + SearchSourceConnector, + SearchSourceConnectorType, + User, + get_async_session, +) +from app.users import current_active_user +from app.utils.connector_naming import generate_unique_connector_name +from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_pkce_pair + +logger = logging.getLogger(__name__) + +router = APIRouter() + +_state_manager: OAuthStateManager | None = None +_token_encryption: TokenEncryption | None = None + + +def _get_state_manager() -> OAuthStateManager: + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise HTTPException(status_code=500, detail="SECRET_KEY not configured.") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def _get_token_encryption() -> TokenEncryption: + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise HTTPException(status_code=500, detail="SECRET_KEY not configured.") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + + +def _build_redirect_uri(service: str) -> str: + base = config.BACKEND_URL + if not base: + raise HTTPException(status_code=500, detail="BACKEND_URL not configured.") + return f"{base.rstrip('/')}/api/v1/auth/mcp/{service}/connector/callback" + + +def _frontend_redirect( + space_id: int | None, + *, + success: bool = False, + connector_id: int | None = None, + error: str | None = None, + service: str = "mcp", +) -> RedirectResponse: + if success and space_id: + qs = f"success=true&connector={service}-mcp-connector" + if connector_id: + qs += f"&connectorId={connector_id}" + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?{qs}" + ) + if error and space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error={error}" + ) + return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard") + + +# --------------------------------------------------------------------------- +# /add — start MCP OAuth flow +# --------------------------------------------------------------------------- + +@router.get("/auth/mcp/{service}/connector/add") +async def connect_mcp_service( + service: str, + space_id: int, + user: User = Depends(current_active_user), +): + from app.services.mcp_oauth.registry import get_service + + svc = get_service(service) + if not svc: + raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}") + + try: + from app.services.mcp_oauth.discovery import ( + discover_oauth_metadata, + register_client, + ) + + metadata = await discover_oauth_metadata(svc.mcp_url) + auth_endpoint = metadata.get("authorization_endpoint") + token_endpoint = metadata.get("token_endpoint") + registration_endpoint = metadata.get("registration_endpoint") + + if not auth_endpoint or not token_endpoint: + raise HTTPException( + status_code=502, + detail=f"{svc.name} MCP server returned incomplete OAuth metadata.", + ) + + redirect_uri = _build_redirect_uri(service) + + if svc.supports_dcr and registration_endpoint: + dcr = await register_client(registration_endpoint, redirect_uri) + client_id = dcr.get("client_id") + client_secret = dcr.get("client_secret", "") + if not client_id: + raise HTTPException( + status_code=502, + detail=f"DCR for {svc.name} did not return a client_id.", + ) + elif not svc.supports_dcr and svc.client_id_env: + client_id = getattr(config, svc.client_id_env, None) + client_secret = getattr(config, svc.client_secret_env or "", None) or "" + if not client_id: + raise HTTPException( + status_code=500, + detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).", + ) + else: + raise HTTPException( + status_code=502, + detail=f"{svc.name} MCP server has no DCR and no fallback credentials.", + ) + + verifier, challenge = generate_pkce_pair() + enc = _get_token_encryption() + + state = _get_state_manager().generate_secure_state( + space_id, + user.id, + service=service, + code_verifier=verifier, + mcp_client_id=client_id, + mcp_client_secret=enc.encrypt_token(client_secret) if client_secret else "", + mcp_token_endpoint=token_endpoint, + mcp_url=svc.mcp_url, + ) + + auth_params: dict[str, str] = { + "client_id": client_id, + "response_type": "code", + "redirect_uri": redirect_uri, + "code_challenge": challenge, + "code_challenge_method": "S256", + "state": state, + } + if svc.scopes: + auth_params["scope"] = " ".join(svc.scopes) + + auth_url = f"{auth_endpoint}?{urlencode(auth_params)}" + + logger.info( + "Generated %s MCP OAuth URL for user %s, space %s", + svc.name, user.id, space_id, + ) + return {"auth_url": auth_url} + + except HTTPException: + raise + except Exception as e: + logger.error("Failed to initiate %s MCP OAuth: %s", service, e, exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to initiate {service} MCP OAuth: {e!s}", + ) from e + + +# --------------------------------------------------------------------------- +# /callback — handle OAuth redirect +# --------------------------------------------------------------------------- + +@router.get("/auth/mcp/{service}/connector/callback") +async def mcp_oauth_callback( + service: str, + code: str | None = None, + error: str | None = None, + state: str | None = None, + session: AsyncSession = Depends(get_async_session), +): + if error: + logger.warning("%s MCP OAuth error: %s", service, error) + space_id = None + if state: + try: + data = _get_state_manager().validate_state(state) + space_id = data.get("space_id") + except Exception: + pass + return _frontend_redirect( + space_id, error=f"{service}_mcp_oauth_denied", service=service, + ) + + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + data = _get_state_manager().validate_state(state) + user_id = UUID(data["user_id"]) + space_id = data["space_id"] + svc_key = data.get("service", service) + + from app.services.mcp_oauth.registry import get_service + + svc = get_service(svc_key) + if not svc: + raise HTTPException(status_code=404, detail=f"Unknown MCP service: {svc_key}") + + try: + from app.services.mcp_oauth.discovery import exchange_code_for_tokens + + enc = _get_token_encryption() + client_id = data["mcp_client_id"] + client_secret = ( + enc.decrypt_token(data["mcp_client_secret"]) + if data.get("mcp_client_secret") + else "" + ) + token_endpoint = data["mcp_token_endpoint"] + code_verifier = data["code_verifier"] + mcp_url = data["mcp_url"] + redirect_uri = _build_redirect_uri(service) + + token_json = await exchange_code_for_tokens( + token_endpoint=token_endpoint, + code=code, + redirect_uri=redirect_uri, + client_id=client_id, + client_secret=client_secret, + code_verifier=code_verifier, + ) + + access_token = token_json.get("access_token") + if not access_token: + raise HTTPException( + status_code=400, + detail=f"No access token received from {svc.name}.", + ) + + refresh_token = token_json.get("refresh_token") + expires_at = None + if token_json.get("expires_in"): + expires_at = datetime.now(UTC) + timedelta( + seconds=int(token_json["expires_in"]) + ) + + connector_config = { + "server_config": { + "transport": "streamable-http", + "url": mcp_url, + "headers": {"Authorization": f"Bearer {access_token}"}, + }, + "mcp_service": svc_key, + "mcp_oauth": { + "client_id": client_id, + "client_secret": enc.encrypt_token(client_secret) if client_secret else "", + "token_endpoint": token_endpoint, + "access_token": enc.encrypt_token(access_token), + "refresh_token": enc.encrypt_token(refresh_token) if refresh_token else None, + "expires_at": expires_at.isoformat() if expires_at else None, + "scope": token_json.get("scope"), + }, + "_token_encrypted": True, + } + + # ---- Re-auth path ---- + reauth_connector_id = data.get("connector_id") + if reauth_connector_id: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == reauth_connector_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + ) + ) + db_connector = result.scalars().first() + if not db_connector: + raise HTTPException( + status_code=404, + detail="Connector not found during re-auth", + ) + + db_connector.config = connector_config + flag_modified(db_connector, "config") + await session.commit() + await session.refresh(db_connector) + + _invalidate_cache(space_id) + + logger.info( + "Re-authenticated %s MCP connector %s for user %s", + svc.name, db_connector.id, user_id, + ) + reauth_return_url = data.get("return_url") + if reauth_return_url and reauth_return_url.startswith("/"): + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}" + ) + return _frontend_redirect( + space_id, success=True, connector_id=db_connector.id, service=service, + ) + + # ---- New connector path ---- + connector_name = await generate_unique_connector_name( + session, + SearchSourceConnectorType.MCP_CONNECTOR, + space_id, + user_id, + f"{svc.name} MCP", + ) + + new_connector = SearchSourceConnector( + name=connector_name, + connector_type=SearchSourceConnectorType.MCP_CONNECTOR, + is_indexable=False, + config=connector_config, + search_space_id=space_id, + user_id=user_id, + ) + session.add(new_connector) + + try: + await session.commit() + except IntegrityError as e: + await session.rollback() + raise HTTPException( + status_code=409, detail=f"Database integrity error: {e!s}", + ) from e + + _invalidate_cache(space_id) + + logger.info( + "Created %s MCP connector %s for user %s in space %s", + svc.name, new_connector.id, user_id, space_id, + ) + return _frontend_redirect( + space_id, success=True, connector_id=new_connector.id, service=service, + ) + + except HTTPException: + raise + except Exception as e: + logger.error( + "Failed to complete %s MCP OAuth: %s", service, e, exc_info=True, + ) + raise HTTPException( + status_code=500, + detail=f"Failed to complete {service} MCP OAuth: {e!s}", + ) from e + + +# --------------------------------------------------------------------------- +# /reauth — re-authenticate an existing MCP connector +# --------------------------------------------------------------------------- + +@router.get("/auth/mcp/{service}/connector/reauth") +async def reauth_mcp_service( + service: str, + space_id: int, + connector_id: int, + return_url: str | None = None, + user: User = Depends(current_active_user), + session: AsyncSession = Depends(get_async_session), +): + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.user_id == user.id, + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + ) + ) + if not result.scalars().first(): + raise HTTPException( + status_code=404, detail="MCP connector not found or access denied", + ) + + from app.services.mcp_oauth.registry import get_service + + svc = get_service(service) + if not svc: + raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}") + + try: + from app.services.mcp_oauth.discovery import ( + discover_oauth_metadata, + register_client, + ) + + metadata = await discover_oauth_metadata(svc.mcp_url) + auth_endpoint = metadata.get("authorization_endpoint") + token_endpoint = metadata.get("token_endpoint") + registration_endpoint = metadata.get("registration_endpoint") + + if not auth_endpoint or not token_endpoint: + raise HTTPException( + status_code=502, + detail=f"{svc.name} MCP server returned incomplete OAuth metadata.", + ) + + redirect_uri = _build_redirect_uri(service) + + if svc.supports_dcr and registration_endpoint: + dcr = await register_client(registration_endpoint, redirect_uri) + client_id = dcr.get("client_id") + client_secret = dcr.get("client_secret", "") + if not client_id: + raise HTTPException( + status_code=502, + detail=f"DCR for {svc.name} did not return a client_id.", + ) + elif not svc.supports_dcr and svc.client_id_env: + client_id = getattr(config, svc.client_id_env, None) + client_secret = getattr(config, svc.client_secret_env or "", None) or "" + if not client_id: + raise HTTPException( + status_code=500, + detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).", + ) + else: + raise HTTPException( + status_code=502, + detail=f"{svc.name} MCP server has no DCR and no fallback credentials.", + ) + + verifier, challenge = generate_pkce_pair() + enc = _get_token_encryption() + + extra: dict = { + "service": service, + "code_verifier": verifier, + "mcp_client_id": client_id, + "mcp_client_secret": enc.encrypt_token(client_secret) if client_secret else "", + "mcp_token_endpoint": token_endpoint, + "mcp_url": svc.mcp_url, + "connector_id": connector_id, + } + if return_url and return_url.startswith("/"): + extra["return_url"] = return_url + + state = _get_state_manager().generate_secure_state( + space_id, user.id, **extra, + ) + + auth_params: dict[str, str] = { + "client_id": client_id, + "response_type": "code", + "redirect_uri": redirect_uri, + "code_challenge": challenge, + "code_challenge_method": "S256", + "state": state, + } + if svc.scopes: + auth_params["scope"] = " ".join(svc.scopes) + + auth_url = f"{auth_endpoint}?{urlencode(auth_params)}" + + logger.info( + "Initiating %s MCP re-auth for user %s, connector %s", + svc.name, user.id, connector_id, + ) + return {"auth_url": auth_url} + + except HTTPException: + raise + except Exception as e: + logger.error( + "Failed to initiate %s MCP re-auth: %s", service, e, exc_info=True, + ) + raise HTTPException( + status_code=500, + detail=f"Failed to initiate {service} MCP re-auth: {e!s}", + ) from e + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _invalidate_cache(space_id: int) -> None: + try: + from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache + + invalidate_mcp_tools_cache(space_id) + except Exception: + logger.debug("MCP cache invalidation skipped", exc_info=True) From 81711c9e5b168a9acc4aa5838fe77d3d8260a7ec Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 21:19:57 +0200 Subject: [PATCH 053/299] wire MCP OAuth route into app router --- surfsense_backend/app/routes/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index ad40666cd..925c207a6 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -30,6 +30,7 @@ from .jira_add_connector_route import router as jira_add_connector_router from .linear_add_connector_route import router as linear_add_connector_router from .logs_routes import router as logs_router from .luma_add_connector_route import router as luma_add_connector_router +from .mcp_oauth_route import router as mcp_oauth_router from .memory_routes import router as memory_router from .model_list_routes import router as model_list_router from .new_chat_routes import router as new_chat_router @@ -95,6 +96,7 @@ router.include_router(logs_router) router.include_router(circleback_webhook_router) # Circleback meeting webhooks router.include_router(surfsense_docs_router) # Surfsense documentation for citations router.include_router(notifications_router) # Notifications with Zero sync +router.include_router(mcp_oauth_router) # MCP OAuth 2.1 for Linear, Jira, ClickUp router.include_router(composio_router) # Composio OAuth and toolkit management router.include_router(public_chat_router) # Public chat sharing and cloning router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages From 9b78fbfe15c36c02e1ed0b958519958d4f93c555 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 21:20:12 +0200 Subject: [PATCH 054/299] add automatic token refresh for MCP OAuth connectors --- .../app/agents/new_chat/tools/mcp_tool.py | 124 +++++++++++++++++- 1 file changed, 121 insertions(+), 3 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 9743d049d..cf3e51166 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -377,6 +377,118 @@ async def _load_http_mcp_tools( return tools +_TOKEN_REFRESH_BUFFER_SECONDS = 300 # refresh 5 min before expiry + + +async def _maybe_refresh_mcp_oauth_token( + session: AsyncSession, + connector: "SearchSourceConnector", + cfg: dict[str, Any], + server_config: dict[str, Any], +) -> dict[str, Any]: + """Refresh the access token for an MCP OAuth connector if it is about to expire. + + Returns the (possibly updated) ``server_config``. + """ + from datetime import UTC, datetime, timedelta + + mcp_oauth = cfg.get("mcp_oauth", {}) + expires_at_str = mcp_oauth.get("expires_at") + if not expires_at_str: + return server_config + + try: + expires_at = datetime.fromisoformat(expires_at_str) + if expires_at.tzinfo is None: + from datetime import timezone + expires_at = expires_at.replace(tzinfo=timezone.utc) + + if datetime.now(UTC) < expires_at - timedelta(seconds=_TOKEN_REFRESH_BUFFER_SECONDS): + return server_config + except (ValueError, TypeError): + return server_config + + refresh_token = mcp_oauth.get("refresh_token") + if not refresh_token: + logger.warning( + "MCP connector %s token expired but no refresh_token available", + connector.id, + ) + return server_config + + try: + from app.config import config as app_config + from app.services.mcp_oauth.discovery import refresh_access_token + from app.utils.oauth_security import TokenEncryption + + enc = TokenEncryption(app_config.SECRET_KEY) + decrypted_refresh = enc.decrypt_token(refresh_token) + decrypted_secret = ( + enc.decrypt_token(mcp_oauth["client_secret"]) + if mcp_oauth.get("client_secret") + else "" + ) + + token_json = await refresh_access_token( + token_endpoint=mcp_oauth["token_endpoint"], + refresh_token=decrypted_refresh, + client_id=mcp_oauth["client_id"], + client_secret=decrypted_secret, + ) + + new_access = token_json.get("access_token") + if not new_access: + logger.warning( + "MCP connector %s token refresh returned no access_token", + connector.id, + ) + return server_config + + new_expires_at = None + if token_json.get("expires_in"): + new_expires_at = datetime.now(UTC) + timedelta( + seconds=int(token_json["expires_in"]) + ) + + updated_oauth = dict(mcp_oauth) + updated_oauth["access_token"] = enc.encrypt_token(new_access) + if token_json.get("refresh_token"): + updated_oauth["refresh_token"] = enc.encrypt_token( + token_json["refresh_token"] + ) + updated_oauth["expires_at"] = ( + new_expires_at.isoformat() if new_expires_at else None + ) + + updated_server_config = dict(server_config) + updated_server_config["headers"] = { + **server_config.get("headers", {}), + "Authorization": f"Bearer {new_access}", + } + + from sqlalchemy.orm.attributes import flag_modified + + connector.config = { + **cfg, + "server_config": updated_server_config, + "mcp_oauth": updated_oauth, + } + flag_modified(connector, "config") + await session.commit() + await session.refresh(connector) + + logger.info("Refreshed MCP OAuth token for connector %s", connector.id) + return updated_server_config + + except Exception: + logger.warning( + "Failed to refresh MCP OAuth token for connector %s", + connector.id, + exc_info=True, + ) + return server_config + + def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None: """Invalidate cached MCP tools. @@ -429,9 +541,9 @@ async def load_mcp_tools( tools: list[StructuredTool] = [] for connector in result.scalars(): try: - config = connector.config or {} - server_config = config.get("server_config", {}) - trusted_tools = config.get("trusted_tools", []) + cfg = connector.config or {} + server_config = cfg.get("server_config", {}) + trusted_tools = cfg.get("trusted_tools", []) if not server_config or not isinstance(server_config, dict): logger.warning( @@ -439,6 +551,12 @@ async def load_mcp_tools( ) continue + # Refresh OAuth token for MCP OAuth connectors before connecting + if cfg.get("mcp_oauth"): + server_config = await _maybe_refresh_mcp_oauth_token( + session, connector, cfg, server_config, + ) + transport = server_config.get("transport", "stdio") if transport in ("streamable-http", "http", "sse"): From c414cc257f392f84a82da6100e46701bf630404b Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 21:20:54 +0200 Subject: [PATCH 055/299] add frontend tiles for Linear, Jira, ClickUp MCP connectors --- .../constants/connector-constants.ts | 25 +++++++++++++++++++ .../tabs/all-connectors-tab.tsx | 23 ++++++++++++++++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index 5b61e8bdf..5ce94809a 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -105,6 +105,31 @@ export const OAUTH_CONNECTORS = [ }, ] as const; +// MCP OAuth Connectors (one-click connect via official MCP servers) +export const MCP_OAUTH_CONNECTORS = [ + { + id: "linear-mcp-connector", + title: "Linear (MCP)", + description: "Interact with Linear issues via MCP", + connectorType: EnumConnectorName.MCP_CONNECTOR, + authEndpoint: "/api/v1/auth/mcp/linear/connector/add/", + }, + { + id: "jira-mcp-connector", + title: "Jira (MCP)", + description: "Interact with Jira issues via MCP", + connectorType: EnumConnectorName.MCP_CONNECTOR, + authEndpoint: "/api/v1/auth/mcp/jira/connector/add/", + }, + { + id: "clickup-mcp-connector", + title: "ClickUp (MCP)", + description: "Interact with ClickUp tasks via MCP", + connectorType: EnumConnectorName.MCP_CONNECTOR, + authEndpoint: "/api/v1/auth/mcp/clickup/connector/add/", + }, +] as const; + // Content Sources (tools that extract and import content from external sources) export const CRAWLERS = [ { diff --git a/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx b/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx index 814959ec4..d4f5e2fc1 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx @@ -10,12 +10,14 @@ import { ConnectorCard } from "../components/connector-card"; import { COMPOSIO_CONNECTORS, CRAWLERS, + MCP_OAUTH_CONNECTORS, OAUTH_CONNECTORS, OTHER_CONNECTORS, } from "../constants/connector-constants"; import { getDocumentCountForConnector } from "../utils/connector-document-mapping"; type OAuthConnector = (typeof OAUTH_CONNECTORS)[number]; +type MCPOAuthConnector = (typeof MCP_OAUTH_CONNECTORS)[number]; type ComposioConnector = (typeof COMPOSIO_CONNECTORS)[number]; type OtherConnector = (typeof OTHER_CONNECTORS)[number]; type CrawlerConnector = (typeof CRAWLERS)[number]; @@ -128,6 +130,10 @@ export const AllConnectorsTab: FC = ({ (c) => c.connectorType === EnumConnectorName.AIRTABLE_CONNECTOR ); + const filteredMCPOAuth = MCP_OAUTH_CONNECTORS.filter( + (c) => matchesSearch(c.title, c.description), + ); + const moreIntegrationsComposio = filteredComposio.filter( (c) => !DOCUMENT_FILE_CONNECTOR_TYPES.has(c.connectorType) && @@ -279,6 +285,7 @@ export const AllConnectorsTab: FC = ({ nativeGoogleDriveConnectors.length > 0 || composioGoogleDriveConnectors.length > 0 || fileStorageConnectors.length > 0; + const hasMCPOAuth = filteredMCPOAuth.length > 0; const hasMoreIntegrations = otherDocumentYouTubeConnectors.length > 0 || otherDocumentNotionConnectors.length > 0 || @@ -288,7 +295,7 @@ export const AllConnectorsTab: FC = ({ moreIntegrationsOther.length > 0 || moreIntegrationsCrawlers.length > 0; - const hasAnyResults = hasDocumentFileConnectors || hasMoreIntegrations; + const hasAnyResults = hasDocumentFileConnectors || hasMCPOAuth || hasMoreIntegrations; if (!hasAnyResults && searchQuery) { return ( @@ -318,6 +325,20 @@ export const AllConnectorsTab: FC = ({ )} + {/* Live MCP Integrations */} + {hasMCPOAuth && ( +
+
+

+ Live MCP Integrations +

+
+
+ {filteredMCPOAuth.map((connector) => renderOAuthCard(connector as OAuthConnector | ComposioConnector))} +
+
+ )} + {/* More Integrations */} {hasMoreIntegrations && (
From 7c2d34283b90b5567403d183137d69ff26622c2d Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 22 Apr 2026 01:02:57 +0530 Subject: [PATCH 056/299] chore: bump version to 0.1.1-beta.1 in manifest and versions files for obsidian plugin --- manifest.json | 2 +- surfsense_obsidian/manifest.json | 2 +- surfsense_obsidian/versions.json | 2 +- versions.json | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/manifest.json b/manifest.json index f266e72b5..6578c0ab0 100644 --- a/manifest.json +++ b/manifest.json @@ -1,7 +1,7 @@ { "id": "surfsense", "name": "SurfSense", - "version": "0.1.0", + "version": "0.1.1-beta.1", "minAppVersion": "1.5.4", "description": "Turn your vault into a searchable second brain with SurfSense.", "author": "SurfSense", diff --git a/surfsense_obsidian/manifest.json b/surfsense_obsidian/manifest.json index f266e72b5..6578c0ab0 100644 --- a/surfsense_obsidian/manifest.json +++ b/surfsense_obsidian/manifest.json @@ -1,7 +1,7 @@ { "id": "surfsense", "name": "SurfSense", - "version": "0.1.0", + "version": "0.1.1-beta.1", "minAppVersion": "1.5.4", "description": "Turn your vault into a searchable second brain with SurfSense.", "author": "SurfSense", diff --git a/surfsense_obsidian/versions.json b/surfsense_obsidian/versions.json index 9a3c3429d..c44e23ca6 100644 --- a/surfsense_obsidian/versions.json +++ b/surfsense_obsidian/versions.json @@ -1,3 +1,3 @@ { - "0.1.0": "1.5.4" + "0.1.1-beta.1": "1.5.4" } diff --git a/versions.json b/versions.json index 9a3c3429d..c44e23ca6 100644 --- a/versions.json +++ b/versions.json @@ -1,3 +1,3 @@ { - "0.1.0": "1.5.4" + "0.1.1-beta.1": "1.5.4" } From 8b8c9b1f5dd8b8c88e0d351c91adbc1fda5030a0 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 21:38:24 +0200 Subject: [PATCH 057/299] add Slack and Airtable MCP OAuth support --- .../app/routes/mcp_oauth_route.py | 8 ++++++-- .../app/services/mcp_oauth/discovery.py | 18 ++++++++++++++---- .../app/services/mcp_oauth/registry.py | 13 +++++++++++++ .../constants/connector-constants.ts | 14 ++++++++++++++ 4 files changed, 47 insertions(+), 6 deletions(-) diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py index 689914ee8..e47dc0a62 100644 --- a/surfsense_backend/app/routes/mcp_oauth_route.py +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -106,7 +106,9 @@ async def connect_mcp_service( register_client, ) - metadata = await discover_oauth_metadata(svc.mcp_url) + metadata = await discover_oauth_metadata( + svc.mcp_url, origin_override=svc.oauth_discovery_origin, + ) auth_endpoint = metadata.get("authorization_endpoint") token_endpoint = metadata.get("token_endpoint") registration_endpoint = metadata.get("registration_endpoint") @@ -409,7 +411,9 @@ async def reauth_mcp_service( register_client, ) - metadata = await discover_oauth_metadata(svc.mcp_url) + metadata = await discover_oauth_metadata( + svc.mcp_url, origin_override=svc.oauth_discovery_origin, + ) auth_endpoint = metadata.get("authorization_endpoint") token_endpoint = metadata.get("token_endpoint") registration_endpoint = metadata.get("registration_endpoint") diff --git a/surfsense_backend/app/services/mcp_oauth/discovery.py b/surfsense_backend/app/services/mcp_oauth/discovery.py index e8bcd7076..b0f3fef2a 100644 --- a/surfsense_backend/app/services/mcp_oauth/discovery.py +++ b/surfsense_backend/app/services/mcp_oauth/discovery.py @@ -11,14 +11,24 @@ import httpx logger = logging.getLogger(__name__) -async def discover_oauth_metadata(mcp_url: str, *, timeout: float = 15.0) -> dict: +async def discover_oauth_metadata( + mcp_url: str, + *, + origin_override: str | None = None, + timeout: float = 15.0, +) -> dict: """Fetch OAuth 2.1 metadata from the MCP server's well-known endpoint. Per the MCP spec the discovery document lives at the *origin* of the - MCP server URL, not at the MCP endpoint path. + MCP server URL. ``origin_override`` can be used when the OAuth server + lives on a different domain (e.g. Airtable: MCP at ``mcp.airtable.com``, + OAuth at ``airtable.com``). """ - parsed = urlparse(mcp_url) - origin = f"{parsed.scheme}://{parsed.netloc}" + if origin_override: + origin = origin_override.rstrip("/") + else: + parsed = urlparse(mcp_url) + origin = f"{parsed.scheme}://{parsed.netloc}" discovery_url = f"{origin}/.well-known/oauth-authorization-server" async with httpx.AsyncClient(follow_redirects=True) as client: diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py index 93d5d5448..3f9a03fbc 100644 --- a/surfsense_backend/app/services/mcp_oauth/registry.py +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -16,6 +16,7 @@ class MCPServiceConfig: name: str mcp_url: str supports_dcr: bool = True + oauth_discovery_origin: str | None = None client_id_env: str | None = None client_secret_env: str | None = None scopes: list[str] = field(default_factory=list) @@ -34,6 +35,18 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { name="ClickUp", mcp_url="https://mcp.clickup.com/mcp", ), + "slack": MCPServiceConfig( + name="Slack", + mcp_url="https://mcp.slack.com/mcp", + supports_dcr=False, + client_id_env="SLACK_CLIENT_ID", + client_secret_env="SLACK_CLIENT_SECRET", + ), + "airtable": MCPServiceConfig( + name="Airtable", + mcp_url="https://mcp.airtable.com/mcp", + oauth_discovery_origin="https://airtable.com", + ), } diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index 5ce94809a..dcd63f525 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -128,6 +128,20 @@ export const MCP_OAUTH_CONNECTORS = [ connectorType: EnumConnectorName.MCP_CONNECTOR, authEndpoint: "/api/v1/auth/mcp/clickup/connector/add/", }, + { + id: "slack-mcp-connector", + title: "Slack (MCP)", + description: "Interact with Slack channels via MCP", + connectorType: EnumConnectorName.MCP_CONNECTOR, + authEndpoint: "/api/v1/auth/mcp/slack/connector/add/", + }, + { + id: "airtable-mcp-connector", + title: "Airtable (MCP)", + description: "Interact with Airtable bases via MCP", + connectorType: EnumConnectorName.MCP_CONNECTOR, + authEndpoint: "/api/v1/auth/mcp/airtable/connector/add/", + }, ] as const; // Content Sources (tools that extract and import content from external sources) From 5ff0ec5d5de7ab9d880cb9e6911ecebdf54fed14 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 21:51:40 +0200 Subject: [PATCH 058/299] disable periodic indexing for live connectors --- .../celery_tasks/schedule_checker_task.py | 22 +++---------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py index e6890b0a8..3aee5a4ca 100644 --- a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py @@ -51,43 +51,27 @@ async def _check_and_trigger_schedules(): logger.info(f"Found {len(due_connectors)} connectors due for indexing") - # Import all indexing tasks + # Import indexing tasks for KB connectors only. + # Live connectors (Linear, Slack, Jira, ClickUp, Airtable, Discord, + # Teams, Gmail, Calendar, Luma) use real-time tools instead. from app.tasks.celery_tasks.connector_tasks import ( - index_airtable_records_task, - index_clickup_tasks_task, index_confluence_pages_task, index_crawled_urls_task, - index_discord_messages_task, index_elasticsearch_documents_task, index_github_repos_task, index_google_calendar_events_task, index_google_drive_files_task, index_google_gmail_messages_task, - index_jira_issues_task, - index_linear_issues_task, - index_luma_events_task, index_notion_pages_task, - index_slack_messages_task, ) - # Map connector types to their tasks task_map = { - SearchSourceConnectorType.SLACK_CONNECTOR: index_slack_messages_task, SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task, SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task, - SearchSourceConnectorType.LINEAR_CONNECTOR: index_linear_issues_task, - SearchSourceConnectorType.JIRA_CONNECTOR: index_jira_issues_task, SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task, - SearchSourceConnectorType.CLICKUP_CONNECTOR: index_clickup_tasks_task, - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task, - SearchSourceConnectorType.AIRTABLE_CONNECTOR: index_airtable_records_task, - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: index_google_gmail_messages_task, - SearchSourceConnectorType.DISCORD_CONNECTOR: index_discord_messages_task, - SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task, SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task, SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task, SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task, - # Composio connector types (unified with native Google tasks) SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task, SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: index_google_gmail_messages_task, SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task, From 328219e46fdc9c0b88d194c1e47b5dbc9d4b5d91 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 21:52:17 +0200 Subject: [PATCH 059/299] disable first-run indexing for live connectors --- .../app/utils/periodic_scheduler.py | 30 ------------------- 1 file changed, 30 deletions(-) diff --git a/surfsense_backend/app/utils/periodic_scheduler.py b/surfsense_backend/app/utils/periodic_scheduler.py index 9ea45df63..923f969d5 100644 --- a/surfsense_backend/app/utils/periodic_scheduler.py +++ b/surfsense_backend/app/utils/periodic_scheduler.py @@ -18,19 +18,9 @@ logger = logging.getLogger(__name__) # Mapping of connector types to their corresponding Celery task names CONNECTOR_TASK_MAP = { - SearchSourceConnectorType.SLACK_CONNECTOR: "index_slack_messages", - SearchSourceConnectorType.TEAMS_CONNECTOR: "index_teams_messages", SearchSourceConnectorType.NOTION_CONNECTOR: "index_notion_pages", SearchSourceConnectorType.GITHUB_CONNECTOR: "index_github_repos", - SearchSourceConnectorType.LINEAR_CONNECTOR: "index_linear_issues", - SearchSourceConnectorType.JIRA_CONNECTOR: "index_jira_issues", SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "index_confluence_pages", - SearchSourceConnectorType.CLICKUP_CONNECTOR: "index_clickup_tasks", - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: "index_google_calendar_events", - SearchSourceConnectorType.AIRTABLE_CONNECTOR: "index_airtable_records", - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: "index_google_gmail_messages", - SearchSourceConnectorType.DISCORD_CONNECTOR: "index_discord_messages", - SearchSourceConnectorType.LUMA_CONNECTOR: "index_luma_events", SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: "index_elasticsearch_documents", SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: "index_crawled_urls", SearchSourceConnectorType.BOOKSTACK_CONNECTOR: "index_bookstack_pages", @@ -84,40 +74,20 @@ def create_periodic_schedule( f"(frequency: {frequency_minutes} minutes). Triggering first run..." ) - # Import all indexing tasks from app.tasks.celery_tasks.connector_tasks import ( - index_airtable_records_task, index_bookstack_pages_task, - index_clickup_tasks_task, index_confluence_pages_task, index_crawled_urls_task, - index_discord_messages_task, index_elasticsearch_documents_task, index_github_repos_task, - index_google_calendar_events_task, - index_google_gmail_messages_task, - index_jira_issues_task, - index_linear_issues_task, - index_luma_events_task, index_notion_pages_task, index_obsidian_vault_task, - index_slack_messages_task, ) - # Map connector type to task task_map = { - SearchSourceConnectorType.SLACK_CONNECTOR: index_slack_messages_task, SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task, SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task, - SearchSourceConnectorType.LINEAR_CONNECTOR: index_linear_issues_task, - SearchSourceConnectorType.JIRA_CONNECTOR: index_jira_issues_task, SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task, - SearchSourceConnectorType.CLICKUP_CONNECTOR: index_clickup_tasks_task, - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task, - SearchSourceConnectorType.AIRTABLE_CONNECTOR: index_airtable_records_task, - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: index_google_gmail_messages_task, - SearchSourceConnectorType.DISCORD_CONNECTOR: index_discord_messages_task, - SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task, SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task, SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task, SearchSourceConnectorType.BOOKSTACK_CONNECTOR: index_bookstack_pages_task, From 53a173a8fdc78a35889ceb028a5e102a11a7ecb8 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 21:52:23 +0200 Subject: [PATCH 060/299] guard manual indexing for live connectors --- .../routes/search_source_connectors_routes.py | 175 +++--------------- 1 file changed, 28 insertions(+), 147 deletions(-) diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index b87ce28c9..7ce3ca9a3 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -693,27 +693,10 @@ async def index_connector_content( user: User = Depends(current_active_user), ): """ - Index content from a connector to a search space. - Requires CONNECTORS_UPDATE permission (to trigger indexing). + Index content from a KB connector to a search space. - Currently supports: - - SLACK_CONNECTOR: Indexes messages from all accessible Slack channels - - TEAMS_CONNECTOR: Indexes messages from all accessible Microsoft Teams channels - - NOTION_CONNECTOR: Indexes pages from all accessible Notion pages - - GITHUB_CONNECTOR: Indexes code and documentation from GitHub repositories - - LINEAR_CONNECTOR: Indexes issues and comments from Linear - - JIRA_CONNECTOR: Indexes issues and comments from Jira - - DISCORD_CONNECTOR: Indexes messages from all accessible Discord channels - - LUMA_CONNECTOR: Indexes events from Luma - - ELASTICSEARCH_CONNECTOR: Indexes documents from Elasticsearch - - WEBCRAWLER_CONNECTOR: Indexes web pages from crawled websites - - Args: - connector_id: ID of the connector to use - search_space_id: ID of the search space to store indexed content - - Returns: - Dictionary with indexing status + Live connectors (Slack, Teams, Linear, Jira, ClickUp, Calendar, Airtable, + Gmail, Discord, Luma) use real-time agent tools instead. """ try: # Get the connector first @@ -770,9 +753,7 @@ async def index_connector_content( # For calendar connectors, default to today but allow future dates if explicitly provided if connector.connector_type in [ - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.LUMA_CONNECTOR, ]: # Default to today if no end_date provided (users can manually select future dates) indexing_to = today_str if end_date is None else end_date @@ -796,33 +777,32 @@ async def index_connector_content( # For non-calendar connectors, cap at today indexing_to = end_date if end_date else today_str - if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import ( - index_slack_messages_task, - ) + _LIVE_CONNECTOR_TYPES = { + SearchSourceConnectorType.SLACK_CONNECTOR, + SearchSourceConnectorType.TEAMS_CONNECTOR, + SearchSourceConnectorType.LINEAR_CONNECTOR, + SearchSourceConnectorType.JIRA_CONNECTOR, + SearchSourceConnectorType.CLICKUP_CONNECTOR, + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.AIRTABLE_CONNECTOR, + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.DISCORD_CONNECTOR, + SearchSourceConnectorType.LUMA_CONNECTOR, + } + if connector.connector_type in _LIVE_CONNECTOR_TYPES: + return { + "message": ( + f"{connector.connector_type.value} uses real-time agent tools; " + "background indexing is disabled." + ), + "indexing_started": False, + "connector_id": connector_id, + "search_space_id": search_space_id, + "indexing_from": indexing_from, + "indexing_to": indexing_to, + } - logger.info( - f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_slack_messages_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Slack indexing started in the background." - - elif connector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import ( - index_teams_messages_task, - ) - - logger.info( - f"Triggering Teams indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_teams_messages_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Teams indexing started in the background." - - elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR: + if connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR: from app.tasks.celery_tasks.connector_tasks import index_notion_pages_task logger.info( @@ -844,28 +824,6 @@ async def index_connector_content( ) response_message = "GitHub indexing started in the background." - elif connector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import index_linear_issues_task - - logger.info( - f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_linear_issues_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Linear indexing started in the background." - - elif connector.connector_type == SearchSourceConnectorType.JIRA_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import index_jira_issues_task - - logger.info( - f"Triggering Jira indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_jira_issues_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Jira indexing started in the background." - elif connector.connector_type == SearchSourceConnectorType.CONFLUENCE_CONNECTOR: from app.tasks.celery_tasks.connector_tasks import ( index_confluence_pages_task, @@ -892,59 +850,6 @@ async def index_connector_content( ) response_message = "BookStack indexing started in the background." - elif connector.connector_type == SearchSourceConnectorType.CLICKUP_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import index_clickup_tasks_task - - logger.info( - f"Triggering ClickUp indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_clickup_tasks_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "ClickUp indexing started in the background." - - elif ( - connector.connector_type - == SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR - ): - from app.tasks.celery_tasks.connector_tasks import ( - index_google_calendar_events_task, - ) - - logger.info( - f"Triggering Google Calendar indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_google_calendar_events_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Google Calendar indexing started in the background." - elif connector.connector_type == SearchSourceConnectorType.AIRTABLE_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import ( - index_airtable_records_task, - ) - - logger.info( - f"Triggering Airtable indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_airtable_records_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Airtable indexing started in the background." - elif ( - connector.connector_type == SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR - ): - from app.tasks.celery_tasks.connector_tasks import ( - index_google_gmail_messages_task, - ) - - logger.info( - f"Triggering Google Gmail indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_google_gmail_messages_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Google Gmail indexing started in the background." - elif ( connector.connector_type == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR ): @@ -1089,30 +994,6 @@ async def index_connector_content( ) response_message = "Dropbox indexing started in the background." - elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import ( - index_discord_messages_task, - ) - - logger.info( - f"Triggering Discord indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_discord_messages_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Discord indexing started in the background." - - elif connector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import index_luma_events_task - - logger.info( - f"Triggering Luma indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_luma_events_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Luma indexing started in the background." - elif ( connector.connector_type == SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR From 0ab7d6a5e385d071befec0c386121181288b0228 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 21:52:43 +0200 Subject: [PATCH 061/299] set is_indexable=False for all live connector add routes --- surfsense_backend/app/routes/airtable_add_connector_route.py | 2 +- surfsense_backend/app/routes/clickup_add_connector_route.py | 4 ++-- surfsense_backend/app/routes/discord_add_connector_route.py | 2 +- .../app/routes/google_calendar_add_connector_route.py | 2 +- .../app/routes/google_gmail_add_connector_route.py | 2 +- surfsense_backend/app/routes/jira_add_connector_route.py | 2 +- surfsense_backend/app/routes/linear_add_connector_route.py | 2 +- surfsense_backend/app/routes/luma_add_connector_route.py | 4 ++-- surfsense_backend/app/routes/slack_add_connector_route.py | 2 +- surfsense_backend/app/routes/teams_add_connector_route.py | 2 +- 10 files changed, 12 insertions(+), 12 deletions(-) diff --git a/surfsense_backend/app/routes/airtable_add_connector_route.py b/surfsense_backend/app/routes/airtable_add_connector_route.py index 1e0b1eb5d..f70b9166b 100644 --- a/surfsense_backend/app/routes/airtable_add_connector_route.py +++ b/surfsense_backend/app/routes/airtable_add_connector_route.py @@ -311,7 +311,7 @@ async def airtable_callback( new_connector = SearchSourceConnector( name=connector_name, connector_type=SearchSourceConnectorType.AIRTABLE_CONNECTOR, - is_indexable=True, + is_indexable=False, config=credentials_dict, search_space_id=space_id, user_id=user_id, diff --git a/surfsense_backend/app/routes/clickup_add_connector_route.py b/surfsense_backend/app/routes/clickup_add_connector_route.py index 2cd63eca2..f7b0876e5 100644 --- a/surfsense_backend/app/routes/clickup_add_connector_route.py +++ b/surfsense_backend/app/routes/clickup_add_connector_route.py @@ -301,7 +301,7 @@ async def clickup_callback( # Update existing connector existing_connector.config = connector_config existing_connector.name = "ClickUp Connector" - existing_connector.is_indexable = True + existing_connector.is_indexable = False logger.info( f"Updated existing ClickUp connector for user {user_id} in space {space_id}" ) @@ -310,7 +310,7 @@ async def clickup_callback( new_connector = SearchSourceConnector( name="ClickUp Connector", connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR, - is_indexable=True, + is_indexable=False, config=connector_config, search_space_id=space_id, user_id=user_id, diff --git a/surfsense_backend/app/routes/discord_add_connector_route.py b/surfsense_backend/app/routes/discord_add_connector_route.py index 27bfffc90..4ab48f544 100644 --- a/surfsense_backend/app/routes/discord_add_connector_route.py +++ b/surfsense_backend/app/routes/discord_add_connector_route.py @@ -326,7 +326,7 @@ async def discord_callback( new_connector = SearchSourceConnector( name=connector_name, connector_type=SearchSourceConnectorType.DISCORD_CONNECTOR, - is_indexable=True, + is_indexable=False, config=connector_config, search_space_id=space_id, user_id=user_id, diff --git a/surfsense_backend/app/routes/google_calendar_add_connector_route.py b/surfsense_backend/app/routes/google_calendar_add_connector_route.py index d7ccf62ca..a143fd50d 100644 --- a/surfsense_backend/app/routes/google_calendar_add_connector_route.py +++ b/surfsense_backend/app/routes/google_calendar_add_connector_route.py @@ -340,7 +340,7 @@ async def calendar_callback( config=creds_dict, search_space_id=space_id, user_id=user_id, - is_indexable=True, + is_indexable=False, ) session.add(db_connector) await session.commit() diff --git a/surfsense_backend/app/routes/google_gmail_add_connector_route.py b/surfsense_backend/app/routes/google_gmail_add_connector_route.py index dd8feb1c7..9b807a556 100644 --- a/surfsense_backend/app/routes/google_gmail_add_connector_route.py +++ b/surfsense_backend/app/routes/google_gmail_add_connector_route.py @@ -371,7 +371,7 @@ async def gmail_callback( config=creds_dict, search_space_id=space_id, user_id=user_id, - is_indexable=True, + is_indexable=False, ) session.add(db_connector) await session.commit() diff --git a/surfsense_backend/app/routes/jira_add_connector_route.py b/surfsense_backend/app/routes/jira_add_connector_route.py index 6cd6283d7..eeb4f91d9 100644 --- a/surfsense_backend/app/routes/jira_add_connector_route.py +++ b/surfsense_backend/app/routes/jira_add_connector_route.py @@ -386,7 +386,7 @@ async def jira_callback( new_connector = SearchSourceConnector( name=connector_name, connector_type=SearchSourceConnectorType.JIRA_CONNECTOR, - is_indexable=True, + is_indexable=False, config=connector_config, search_space_id=space_id, user_id=user_id, diff --git a/surfsense_backend/app/routes/linear_add_connector_route.py b/surfsense_backend/app/routes/linear_add_connector_route.py index 9345ae495..f59c17d25 100644 --- a/surfsense_backend/app/routes/linear_add_connector_route.py +++ b/surfsense_backend/app/routes/linear_add_connector_route.py @@ -399,7 +399,7 @@ async def linear_callback( new_connector = SearchSourceConnector( name=connector_name, connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR, - is_indexable=True, + is_indexable=False, config=connector_config, search_space_id=space_id, user_id=user_id, diff --git a/surfsense_backend/app/routes/luma_add_connector_route.py b/surfsense_backend/app/routes/luma_add_connector_route.py index 04d840a08..7040581bc 100644 --- a/surfsense_backend/app/routes/luma_add_connector_route.py +++ b/surfsense_backend/app/routes/luma_add_connector_route.py @@ -61,7 +61,7 @@ async def add_luma_connector( if existing_connector: # Update existing connector with new API key existing_connector.config = {"api_key": request.api_key} - existing_connector.is_indexable = True + existing_connector.is_indexable = False await session.commit() await session.refresh(existing_connector) @@ -82,7 +82,7 @@ async def add_luma_connector( config={"api_key": request.api_key}, search_space_id=request.space_id, user_id=user.id, - is_indexable=True, + is_indexable=False, ) session.add(db_connector) diff --git a/surfsense_backend/app/routes/slack_add_connector_route.py b/surfsense_backend/app/routes/slack_add_connector_route.py index 405ab2c4f..f6a1458a0 100644 --- a/surfsense_backend/app/routes/slack_add_connector_route.py +++ b/surfsense_backend/app/routes/slack_add_connector_route.py @@ -312,7 +312,7 @@ async def slack_callback( new_connector = SearchSourceConnector( name=connector_name, connector_type=SearchSourceConnectorType.SLACK_CONNECTOR, - is_indexable=True, + is_indexable=False, config=connector_config, search_space_id=space_id, user_id=user_id, diff --git a/surfsense_backend/app/routes/teams_add_connector_route.py b/surfsense_backend/app/routes/teams_add_connector_route.py index bbaae3a5f..9d0f5144f 100644 --- a/surfsense_backend/app/routes/teams_add_connector_route.py +++ b/surfsense_backend/app/routes/teams_add_connector_route.py @@ -321,7 +321,7 @@ async def teams_callback( new_connector = SearchSourceConnector( name=connector_name, connector_type=SearchSourceConnectorType.TEAMS_CONNECTOR, - is_indexable=True, + is_indexable=False, config=connector_config, search_space_id=space_id, user_id=user_id, From e676ebfabeb0584cee14232eb90575646dd8b040 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 21 Apr 2026 21:52:54 +0200 Subject: [PATCH 062/299] remove live connectors from AUTO_INDEX_DEFAULTS --- .../constants/connector-constants.ts | 54 ------------------- 1 file changed, 54 deletions(-) diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index dcd63f525..39e827d1a 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -297,66 +297,18 @@ export interface AutoIndexConfig { } export const AUTO_INDEX_DEFAULTS: Record = { - [EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: { - daysBack: 30, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 30 days of emails.", - }, [EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: { daysBack: 30, daysForward: 0, frequencyMinutes: 1440, syncDescription: "Syncing your last 30 days of emails.", }, - [EnumConnectorName.SLACK_CONNECTOR]: { - daysBack: 30, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 30 days of messages.", - }, - [EnumConnectorName.DISCORD_CONNECTOR]: { - daysBack: 30, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 30 days of messages.", - }, - [EnumConnectorName.TEAMS_CONNECTOR]: { - daysBack: 30, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 30 days of messages.", - }, - [EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: { - daysBack: 90, - daysForward: 90, - frequencyMinutes: 1440, - syncDescription: "Syncing 90 days of past and upcoming events.", - }, [EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: { daysBack: 90, daysForward: 90, frequencyMinutes: 1440, syncDescription: "Syncing 90 days of past and upcoming events.", }, - [EnumConnectorName.LINEAR_CONNECTOR]: { - daysBack: 90, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 90 days of issues.", - }, - [EnumConnectorName.JIRA_CONNECTOR]: { - daysBack: 90, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 90 days of issues.", - }, - [EnumConnectorName.CLICKUP_CONNECTOR]: { - daysBack: 90, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 90 days of tasks.", - }, [EnumConnectorName.NOTION_CONNECTOR]: { daysBack: 365, daysForward: 0, @@ -369,12 +321,6 @@ export const AUTO_INDEX_DEFAULTS: Record = { frequencyMinutes: 1440, syncDescription: "Syncing your documentation.", }, - [EnumConnectorName.AIRTABLE_CONNECTOR]: { - daysBack: 365, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your bases.", - }, }; export const AUTO_INDEX_CONNECTOR_TYPES = new Set(Object.keys(AUTO_INDEX_DEFAULTS)); From e86d279d5500aa27319e4e5e037bae66c2ea8511 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 22 Apr 2026 05:34:17 +0530 Subject: [PATCH 063/299] chore: update GitHub Actions workflow to include publish mode selection and improve version resolution logic --- .github/workflows/release-obsidian-plugin.yml | 47 ++++++++++++------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/.github/workflows/release-obsidian-plugin.yml b/.github/workflows/release-obsidian-plugin.yml index 198b87611..1560f1c41 100644 --- a/.github/workflows/release-obsidian-plugin.yml +++ b/.github/workflows/release-obsidian-plugin.yml @@ -7,10 +7,14 @@ on: - "obsidian-v*" workflow_dispatch: inputs: - tag: - description: "Tag to build (e.g. obsidian-v0.1.0). Dry-run only when run manually." + publish: + description: "Publish to GitHub Releases" required: true - default: "obsidian-v0.0.0-test" + type: choice + options: + - never + - always + default: "never" permissions: contents: write @@ -39,24 +43,35 @@ jobs: - name: Resolve plugin version id: version run: | + manifest_version=$(node -p "require('./manifest.json').version") if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then - tag="${{ github.event.inputs.tag }}" + # Manual runs derive the release version from manifest.json. + version="$manifest_version" + tag="obsidian-v$version" else tag="${GITHUB_REF_NAME}" - fi - if [ -z "$tag" ] || [[ "$tag" != obsidian-v* ]]; then - echo "::error::Invalid tag '$tag'. Expected format: obsidian-v" - exit 1 - fi - version="${tag#obsidian-v}" - manifest_version=$(node -p "require('./manifest.json').version") - if [ "$version" != "$manifest_version" ]; then - echo "::error::Tag version '$version' does not match manifest version '$manifest_version'" - exit 1 + if [ -z "$tag" ] || [[ "$tag" != obsidian-v* ]]; then + echo "::error::Invalid tag '$tag'. Expected format: obsidian-v" + exit 1 + fi + version="${tag#obsidian-v}" + if [ "$version" != "$manifest_version" ]; then + echo "::error::Tag version '$version' does not match manifest version '$manifest_version'" + exit 1 + fi fi echo "tag=$tag" >> "$GITHUB_OUTPUT" echo "version=$version" >> "$GITHUB_OUTPUT" + - name: Resolve publish mode + id: release_mode + run: | + if [ "${{ github.event_name }}" = "push" ] || [ "${{ inputs.publish }}" = "always" ]; then + echo "should_publish=true" >> "$GITHUB_OUTPUT" + else + echo "should_publish=false" >> "$GITHUB_OUTPUT" + fi + - run: npm ci - run: npm run lint @@ -70,7 +85,7 @@ jobs: done - name: Mirror manifest.json + versions.json to repo root - if: github.event_name == 'push' + if: steps.release_mode.outputs.should_publish == 'true' working-directory: ${{ github.workspace }} run: | cp surfsense_obsidian/manifest.json manifest.json @@ -90,7 +105,7 @@ jobs: # Publish release under bare `manifest.json` version (no `obsidian-v` prefix) for BRAT/store compatibility. - name: Create GitHub release - if: github.event_name == 'push' + if: steps.release_mode.outputs.should_publish == 'true' uses: softprops/action-gh-release@v3 with: tag_name: ${{ steps.version.outputs.version }} From 9ecccc5403ba19bb1c96c1ab9cdd617b34485295 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 22 Apr 2026 05:44:03 +0530 Subject: [PATCH 064/299] feat: implement dynamic API proxying to FastAPI backend - Added a new route handler for dynamic API requests, allowing proxying to the FastAPI backend. - Removed the previous rewrite configuration in next.config.ts for cleaner integration. - Updated .env.example to clarify backend URL usage. --- surfsense_web/.env.example | 1 - surfsense_web/app/api/v1/[...path]/route.ts | 65 +++++++++++++++++++++ surfsense_web/next.config.ts | 15 ----- 3 files changed, 65 insertions(+), 16 deletions(-) create mode 100644 surfsense_web/app/api/v1/[...path]/route.ts diff --git a/surfsense_web/.env.example b/surfsense_web/.env.example index 9b54edc13..b121daf0b 100644 --- a/surfsense_web/.env.example +++ b/surfsense_web/.env.example @@ -1,7 +1,6 @@ NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000 # Server-only. Internal backend URL used by Next.js server code. -# Falls back to NEXT_PUBLIC_FASTAPI_BACKEND_URL when unset. FASTAPI_BACKEND_INTERNAL_URL=https://your-internal-backend.example.com NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL or GOOGLE diff --git a/surfsense_web/app/api/v1/[...path]/route.ts b/surfsense_web/app/api/v1/[...path]/route.ts new file mode 100644 index 000000000..82c8e2a5d --- /dev/null +++ b/surfsense_web/app/api/v1/[...path]/route.ts @@ -0,0 +1,65 @@ +import type { NextRequest } from "next/server"; + +export const dynamic = "force-dynamic"; + +const HOP_BY_HOP_HEADERS = new Set([ + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailer", + "transfer-encoding", + "upgrade", +]); + +function getBackendBaseUrl() { + const base = process.env.FASTAPI_BACKEND_INTERNAL_URL || "http://localhost:8000"; + return base.endsWith("/") ? base.slice(0, -1) : base; +} + +function toUpstreamHeaders(headers: Headers) { + const nextHeaders = new Headers(headers); + nextHeaders.delete("host"); + nextHeaders.delete("content-length"); + return nextHeaders; +} + +function toClientHeaders(headers: Headers) { + const nextHeaders = new Headers(headers); + for (const header of HOP_BY_HOP_HEADERS) { + nextHeaders.delete(header); + } + return nextHeaders; +} + +async function proxy( + request: NextRequest, + context: { params: Promise<{ path?: string[] }> } +) { + const params = await context.params; + const path = params.path?.join("/") || ""; + const upstreamUrl = new URL(`${getBackendBaseUrl()}/api/v1/${path}`); + upstreamUrl.search = request.nextUrl.search; + + const hasBody = request.method !== "GET" && request.method !== "HEAD"; + + const response = await fetch(upstreamUrl, { + method: request.method, + headers: toUpstreamHeaders(request.headers), + body: hasBody ? request.body : undefined, + // `duplex: "half"` is required by the Fetch spec when streaming a + // ReadableStream as the request body. Avoids buffering uploads in heap. + // @ts-expect-error - `duplex` is not yet in lib.dom RequestInit types. + duplex: hasBody ? "half" : undefined, + redirect: "manual", + }); + + return new Response(response.body, { + status: response.status, + statusText: response.statusText, + headers: toClientHeaders(response.headers), + }); +} + +export { proxy as GET, proxy as POST, proxy as PUT, proxy as PATCH, proxy as DELETE, proxy as OPTIONS, proxy as HEAD }; diff --git a/surfsense_web/next.config.ts b/surfsense_web/next.config.ts index 6aed14d95..5414d548d 100644 --- a/surfsense_web/next.config.ts +++ b/surfsense_web/next.config.ts @@ -44,21 +44,6 @@ const nextConfig: NextConfig = { }, }, - // Proxy /api/v1/* to the FastAPI backend. Keeps the real backend host - // out of the client bundle. FASTAPI_BACKEND_INTERNAL_URL is server-only. - async rewrites() { - const target = - process.env.FASTAPI_BACKEND_INTERNAL_URL || - process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || - "http://localhost:8000"; - return [ - { - source: "/api/v1/:path*", - destination: `${target.replace(/\/+$/, "")}/api/v1/:path*`, - }, - ]; - }, - // Configure webpack (SVGR) webpack: (config) => { // SVGR: import *.svg as React components From ae264290d040defd8fcd447e9e7f31a8ea14a57f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 22 Apr 2026 06:07:38 +0530 Subject: [PATCH 065/299] feat: update Obsidian connector UI and improve user instructions --- .../components/obsidian-connect-form.tsx | 239 ++++++++---------- .../components/obsidian-config.tsx | 8 +- .../views/connector-connect-view.tsx | 4 +- 3 files changed, 117 insertions(+), 134 deletions(-) diff --git a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx index 49c68ba39..689684c51 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx @@ -1,6 +1,6 @@ "use client"; -import { Check, Copy, Download, Info, KeyRound, Settings2 } from "lucide-react"; +import { Check, Copy, Info } from "lucide-react"; import { type FC, useCallback, useRef, useState } from "react"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; @@ -55,145 +55,126 @@ export const ObsidianConnectForm: FC = ({ onBack }) => { that just closes the dialog (see component-level docstring). */}
- + Plugin-based sync SurfSense now syncs Obsidian via an official plugin that runs inside Obsidian itself. Works on desktop and mobile, in cloud and self-hosted - deployments — no server-side vault mounts required. + deployments. - {/* Step 1 — Install plugin */}
-
-
- 1 -
-

Install the plugin

-
-

- Grab the latest SurfSense plugin release. Once it's in the community - store, you'll also be able to install it from{" "} - Settings → Community plugins{" "} - inside Obsidian. -

- - - -
- - {/* Step 2 — Copy API key */} -
-
-
- 2 -
-

- Copy your API key -

- -
-

- Paste this into the plugin's API token{" "} - setting. The token expires after 24 hours; long-lived personal access - tokens are coming in a future release. -

- - {isLoading ? ( -
- ) : apiKey ? ( -
-
-

- {apiKey} -

-
- -
- ) : ( -

- No API key available — try refreshing the page. -

- )} -
- - {/* Step 3 — Server URL */} -
-
-
- 3 -
-

- Point the plugin at this server -

-
-

- Paste this URL into the plugin's Server URL{" "} - setting. We auto-detect it from your current dashboard origin. -

-
-
-

- {BACKEND_URL} +

+ {/* Step 1 — Install plugin */} +
+
+
+ 1 +
+

Install the plugin

+
+

+ Grab the latest SurfSense plugin release. Once it's in the community + store, you'll also be able to install it from{" "} + Settings → Community plugins{" "} + inside Obsidian.

-
- -
-
+ + + + - {/* Step 4 — Pick search space */} -
-
-
- 4 -
-

- Pick this search space -

- -
-

- In the plugin's Search space{" "} - setting, choose the search space you want this vault to sync into. - The connector will appear here automatically once the plugin makes - its first sync. -

+
+ + {/* Step 2 — Copy API key */} +
+
+
+ 2 +
+

Copy your API key

+
+

+ Paste this into the plugin's API token{" "} + setting. The token expires after 24 hours. Long-lived personal access + tokens are coming in a future release. +

+ + {isLoading ? ( +
+ ) : apiKey ? ( +
+
+

+ {apiKey} +

+
+ +
+ ) : ( +

+ No API key available — try refreshing the page. +

+ )} +
+ +
+ + {/* Step 3 — Server URL */} +
+
+
+ 3 +
+

Point the plugin at this server

+
+

+ For SurfSense Cloud, use the default surfsense.com. + If you are self-hosting, set the plugin's{" "} + Server URL to your frontend domain. +

+
+ +
+ + {/* Step 4 — Pick search space */} +
+
+
+ 4 +
+

Pick this search space

+
+

+ In the plugin's Search space{" "} + setting, choose the search space you want this vault to sync into. + The connector will appear here automatically once the plugin makes + its first sync. +

+
+
{getConnectorBenefits(EnumConnectorName.OBSIDIAN_CONNECTOR) && ( diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx index 33a7110c0..a9b98b76c 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx @@ -130,20 +130,20 @@ const PluginStats: FC<{ config: Record }> = ({ config }) => { Plugin connected - Edits in Obsidian sync over HTTPS. To stop syncing, disable or uninstall the plugin in + Your notes stay synced automatically. To stop syncing, disable or uninstall the plugin in Obsidian, or delete this connector.
-
+

Vault status

{tileRows.map((stat) => (
-
+
{stat.label}
{stat.value}
diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-connect-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-connect-view.tsx index e58542923..5b82a8e88 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-connect-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-connect-view.tsx @@ -151,7 +151,9 @@ export const ConnectorConnectView: FC = ({ {connectorType === "MCP_CONNECTOR" ? "Connect" - : `Connect ${getConnectorTypeDisplay(connectorType)}`} + : connectorType === "OBSIDIAN_CONNECTOR" + ? "Done" + : `Connect ${getConnectorTypeDisplay(connectorType)}`} {isSubmitting && } From a5e5f229d9a0f623e007f92a1235ef2026fd40f4 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 22 Apr 2026 06:13:01 +0530 Subject: [PATCH 066/299] feat: add connection status indicator to settings UI --- surfsense_obsidian/src/settings.ts | 50 +++++++++++++++++++++++++++++- surfsense_obsidian/styles.css | 33 ++++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/surfsense_obsidian/src/settings.ts b/surfsense_obsidian/src/settings.ts index 1191e5b7a..4dc6a732f 100644 --- a/surfsense_obsidian/src/settings.ts +++ b/surfsense_obsidian/src/settings.ts @@ -4,6 +4,7 @@ import { Platform, PluginSettingTab, Setting, + setIcon, } from "obsidian"; import { AuthError } from "./api-client"; import { normalizeFolder, parseExcludePatterns } from "./excludes"; @@ -29,7 +30,7 @@ export class SurfSenseSettingTab extends PluginSettingTab { const settings = this.plugin.settings; - new Setting(containerEl).setName("Connection").setHeading(); + this.renderConnectionHeading(containerEl); new Setting(containerEl) .setName("Server URL") @@ -262,6 +263,53 @@ export class SurfSenseSettingTab extends PluginSettingTab { ); } + private renderConnectionHeading(containerEl: HTMLElement): void { + const heading = new Setting(containerEl).setName("Connection").setHeading(); + const indicator = heading.nameEl.createSpan({ + cls: "surfsense-connection-indicator", + }); + const visual = this.getConnectionVisual(); + indicator.addClass(`surfsense-connection-indicator--${visual.tone}`); + setIcon(indicator, visual.icon); + indicator.setAttr("aria-label", visual.label); + indicator.setAttr("title", visual.label); + } + + private getConnectionVisual(): { + icon: string; + label: string; + tone: "ok" | "syncing" | "warn" | "err" | "muted"; + } { + const settings = this.plugin.settings; + const kind = this.plugin.lastStatus.kind; + + if (kind === "auth-error") { + return { icon: "lock", label: "Token invalid or expired", tone: "err" }; + } + if (kind === "error") { + return { icon: "alert-circle", label: "Connection error", tone: "err" }; + } + if (kind === "offline") { + return { icon: "wifi-off", label: "Server unreachable", tone: "warn" }; + } + + if (!settings.apiToken) { + return { icon: "circle", label: "Missing API token", tone: "muted" }; + } + if (!settings.searchSpaceId) { + return { icon: "circle", label: "Pick a search space", tone: "muted" }; + } + if (!settings.connectorId) { + return { icon: "circle", label: "Not connected yet", tone: "muted" }; + } + + if (kind === "syncing" || kind === "queued") { + return { icon: "refresh-ccw", label: "Connected and syncing", tone: "syncing" }; + } + + return { icon: "check-circle", label: "Connected", tone: "ok" }; + } + private async refreshSearchSpaces(): Promise { this.loadingSpaces = true; try { diff --git a/surfsense_obsidian/styles.css b/surfsense_obsidian/styles.css index 81b2203f3..586ddffa6 100644 --- a/surfsense_obsidian/styles.css +++ b/surfsense_obsidian/styles.css @@ -37,3 +37,36 @@ .surfsense-status--err .surfsense-status__icon { color: var(--color-red); } + +.surfsense-connection-indicator { + display: inline-flex; + margin-left: 8px; + vertical-align: middle; + width: 14px; + height: 14px; +} + +.surfsense-connection-indicator svg { + width: 14px; + height: 14px; +} + +.surfsense-connection-indicator--ok { + color: var(--color-green); +} + +.surfsense-connection-indicator--syncing { + color: var(--color-blue); +} + +.surfsense-connection-indicator--warn { + color: var(--color-yellow); +} + +.surfsense-connection-indicator--err { + color: var(--color-red); +} + +.surfsense-connection-indicator--muted { + color: var(--text-muted); +} From 26ed2a2ba1041c88aa293daed0510e99b9ac96fd Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 22 Apr 2026 06:16:38 +0530 Subject: [PATCH 067/299] feat: improve connection handling and status updates in SurfSense plugin --- surfsense_obsidian/src/settings.ts | 2 ++ surfsense_obsidian/src/sync-engine.ts | 2 ++ 2 files changed, 4 insertions(+) diff --git a/surfsense_obsidian/src/settings.ts b/surfsense_obsidian/src/settings.ts index 4dc6a732f..cc72da9c1 100644 --- a/surfsense_obsidian/src/settings.ts +++ b/surfsense_obsidian/src/settings.ts @@ -115,7 +115,9 @@ export class SurfSenseSettingTab extends PluginSettingTab { if (this.plugin.settings.searchSpaceId !== null) { try { await this.plugin.engine.ensureConnected(); + await this.plugin.engine.flushQueue(); new Notice("Surfsense: vault connected."); + this.display(); } catch (err) { this.handleApiError(err); } diff --git a/surfsense_obsidian/src/sync-engine.ts b/surfsense_obsidian/src/sync-engine.ts index d6f7fa91c..4ffd2a651 100644 --- a/surfsense_obsidian/src/sync-engine.ts +++ b/surfsense_obsidian/src/sync-engine.ts @@ -126,6 +126,7 @@ export class SyncEngine { this.setStatus("idle", "Pick a search space in settings."); return; } + this.setStatus("syncing", "Connecting to SurfSense"); try { const fingerprint = await computeVaultFingerprint(this.deps.app); const resp = await this.deps.apiClient.connect({ @@ -139,6 +140,7 @@ export class SyncEngine { s.vaultId = resp.vault_id; s.connectorId = resp.connector_id; }); + this.setStatus(this.queueStatusKind(), this.statusDetail()); } catch (err) { this.handleStartupError(err); } From 3b38daaca59a72a3a57396ddf887dd0e794926ce Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 22 Apr 2026 06:18:58 +0530 Subject: [PATCH 068/299] chore: update version from 0.1.1-beta.1 to 0.1.1 in manifest and versions files for SurfSense plugin --- manifest.json | 2 +- surfsense_obsidian/manifest.json | 2 +- surfsense_obsidian/versions.json | 2 +- versions.json | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/manifest.json b/manifest.json index 6578c0ab0..dee7a58db 100644 --- a/manifest.json +++ b/manifest.json @@ -1,7 +1,7 @@ { "id": "surfsense", "name": "SurfSense", - "version": "0.1.1-beta.1", + "version": "0.1.1", "minAppVersion": "1.5.4", "description": "Turn your vault into a searchable second brain with SurfSense.", "author": "SurfSense", diff --git a/surfsense_obsidian/manifest.json b/surfsense_obsidian/manifest.json index 6578c0ab0..dee7a58db 100644 --- a/surfsense_obsidian/manifest.json +++ b/surfsense_obsidian/manifest.json @@ -1,7 +1,7 @@ { "id": "surfsense", "name": "SurfSense", - "version": "0.1.1-beta.1", + "version": "0.1.1", "minAppVersion": "1.5.4", "description": "Turn your vault into a searchable second brain with SurfSense.", "author": "SurfSense", diff --git a/surfsense_obsidian/versions.json b/surfsense_obsidian/versions.json index c44e23ca6..b190f0f61 100644 --- a/surfsense_obsidian/versions.json +++ b/surfsense_obsidian/versions.json @@ -1,3 +1,3 @@ { - "0.1.1-beta.1": "1.5.4" + "0.1.1": "1.5.4" } diff --git a/versions.json b/versions.json index c44e23ca6..b190f0f61 100644 --- a/versions.json +++ b/versions.json @@ -1,3 +1,3 @@ { - "0.1.1-beta.1": "1.5.4" + "0.1.1": "1.5.4" } From 3b7f27cff9f7aac320950d0558b2e3cca7c8d7cd Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 22 Apr 2026 06:26:49 +0530 Subject: [PATCH 069/299] chore: update GitHub Actions workflows to use Node.js 22.x and enhance connection indicator styling in SurfSense plugin --- .github/workflows/obsidian-plugin-lint.yml | 7 +------ .github/workflows/release-obsidian-plugin.yml | 2 +- surfsense_obsidian/src/settings.ts | 3 ++- surfsense_obsidian/styles.css | 8 ++++++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/obsidian-plugin-lint.yml b/.github/workflows/obsidian-plugin-lint.yml index 80a49c3f7..42bd099b1 100644 --- a/.github/workflows/obsidian-plugin-lint.yml +++ b/.github/workflows/obsidian-plugin-lint.yml @@ -25,17 +25,12 @@ jobs: defaults: run: working-directory: surfsense_obsidian - strategy: - fail-fast: false - matrix: - node-version: [20.x, 22.x] - steps: - uses: actions/checkout@v6 - uses: actions/setup-node@v6 with: - node-version: ${{ matrix.node-version }} + node-version: 22.x cache: npm cache-dependency-path: surfsense_obsidian/package-lock.json diff --git a/.github/workflows/release-obsidian-plugin.yml b/.github/workflows/release-obsidian-plugin.yml index 1560f1c41..68cb0ad1b 100644 --- a/.github/workflows/release-obsidian-plugin.yml +++ b/.github/workflows/release-obsidian-plugin.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-node@v6 with: - node-version: 20.x + node-version: 22.x cache: npm cache-dependency-path: surfsense_obsidian/package-lock.json diff --git a/surfsense_obsidian/src/settings.ts b/surfsense_obsidian/src/settings.ts index cc72da9c1..8efea62fe 100644 --- a/surfsense_obsidian/src/settings.ts +++ b/surfsense_obsidian/src/settings.ts @@ -115,7 +115,7 @@ export class SurfSenseSettingTab extends PluginSettingTab { if (this.plugin.settings.searchSpaceId !== null) { try { await this.plugin.engine.ensureConnected(); - await this.plugin.engine.flushQueue(); + await this.plugin.engine.maybeReconcile(true); new Notice("Surfsense: vault connected."); this.display(); } catch (err) { @@ -267,6 +267,7 @@ export class SurfSenseSettingTab extends PluginSettingTab { private renderConnectionHeading(containerEl: HTMLElement): void { const heading = new Setting(containerEl).setName("Connection").setHeading(); + heading.nameEl.addClass("surfsense-connection-heading"); const indicator = heading.nameEl.createSpan({ cls: "surfsense-connection-indicator", }); diff --git a/surfsense_obsidian/styles.css b/surfsense_obsidian/styles.css index 586ddffa6..3d1ec6ab8 100644 --- a/surfsense_obsidian/styles.css +++ b/surfsense_obsidian/styles.css @@ -40,12 +40,16 @@ .surfsense-connection-indicator { display: inline-flex; - margin-left: 8px; - vertical-align: middle; width: 14px; height: 14px; } +.surfsense-connection-heading { + display: inline-flex; + align-items: center; + gap: 8px; +} + .surfsense-connection-indicator svg { width: 14px; height: 14px; From 4a75603d4f9b0d4875e7aba999ae8e8baebf5670 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 22 Apr 2026 06:38:51 +0530 Subject: [PATCH 070/299] feat: implement sync notifications for Obsidian plugin - Added functionality to create and update notifications during the Obsidian sync process. - Improved handling of sync completion and failure notifications. - Updated connector naming convention in various locations for consistency. --- .../app/routes/obsidian_plugin_routes.py | 128 +++++++++++++++++- .../test_obsidian_plugin_routes.py | 8 +- .../constants/connector-constants.ts | 2 +- .../content/docs/connectors/obsidian.mdx | 4 +- 4 files changed, 134 insertions(+), 8 deletions(-) diff --git a/surfsense_backend/app/routes/obsidian_plugin_routes.py b/surfsense_backend/app/routes/obsidian_plugin_routes.py index 08e0f7d50..096058d8a 100644 --- a/surfsense_backend/app/routes/obsidian_plugin_routes.py +++ b/surfsense_backend/app/routes/obsidian_plugin_routes.py @@ -41,6 +41,7 @@ from app.schemas.obsidian_plugin import ( SyncAckItem, SyncBatchRequest, ) +from app.services.notification_service import NotificationService from app.services.obsidian_plugin_indexer import ( delete_note, get_manifest, @@ -68,6 +69,103 @@ def _build_handshake() -> dict[str, object]: return {"capabilities": list(OBSIDIAN_CAPABILITIES)} +def _connector_type_value(connector: SearchSourceConnector) -> str: + connector_type = connector.connector_type + if hasattr(connector_type, "value"): + return str(connector_type.value) + return str(connector_type) + + +async def _start_obsidian_sync_notification( + session: AsyncSession, + *, + user: User, + connector: SearchSourceConnector, + total_count: int, +): + """Create/update the rolling inbox item for Obsidian plugin sync. + + Obsidian sync is continuous and batched, so we keep one stable + operation_id per connector instead of creating a new notification per batch. + """ + handler = NotificationService.connector_indexing + operation_id = f"obsidian_sync_connector_{connector.id}" + connector_name = connector.name or "Obsidian" + notification = await handler.find_or_create_notification( + session=session, + user_id=user.id, + operation_id=operation_id, + title=f"Syncing: {connector_name}", + message="Syncing from Obsidian plugin", + search_space_id=connector.search_space_id, + initial_metadata={ + "connector_id": connector.id, + "connector_name": connector_name, + "connector_type": _connector_type_value(connector), + "sync_stage": "processing", + "indexed_count": 0, + "failed_count": 0, + "total_count": total_count, + "source": "obsidian_plugin", + }, + ) + return await handler.update_notification( + session=session, + notification=notification, + status="in_progress", + metadata_updates={ + "sync_stage": "processing", + "total_count": total_count, + }, + ) + + +async def _finish_obsidian_sync_notification( + session: AsyncSession, + *, + notification, + indexed: int, + failed: int, +): + """Mark the rolling Obsidian sync inbox item complete or failed.""" + handler = NotificationService.connector_indexing + connector_name = notification.notification_metadata.get("connector_name", "Obsidian") + if failed > 0 and indexed == 0: + title = f"Failed: {connector_name}" + message = ( + f"Sync failed: {failed} file(s) failed" + if failed > 1 + else "Sync failed: 1 file failed" + ) + status_value = "failed" + stage = "failed" + else: + title = f"Ready: {connector_name}" + if failed > 0: + message = f"Partially synced: {indexed} file(s) synced, {failed} failed." + elif indexed == 0: + message = "Already up to date!" + elif indexed == 1: + message = "Now searchable! 1 file synced." + else: + message = f"Now searchable! {indexed} files synced." + status_value = "completed" + stage = "completed" + + await handler.update_notification( + session=session, + notification=notification, + title=title, + message=message, + status=status_value, + metadata_updates={ + "indexed_count": indexed, + "failed_count": failed, + "sync_stage": stage, + }, + ) + + async def _resolve_vault_connector( session: AsyncSession, *, @@ -188,7 +286,7 @@ def _build_config( def _display_name(vault_name: str) -> str: - return f"Obsidian \u2014 {vault_name}" + return f"Obsidian - {vault_name}" @router.post("/connect", response_model=ConnectResponse) @@ -335,6 +433,18 @@ async def obsidian_sync( connector = await _resolve_vault_connector( session, user=user, vault_id=payload.vault_id ) + notification = None + try: + notification = await _start_obsidian_sync_notification( + session, user=user, connector=connector, total_count=len(payload.notes) + ) + except Exception: + logger.warning( + "obsidian sync notification start failed connector=%s user=%s", + connector.id, + user.id, + exc_info=True, + ) items: list[SyncAckItem] = [] indexed = 0 @@ -362,6 +472,22 @@ async def obsidian_sync( SyncAckItem(path=note.path, status="error", error=str(exc)[:300]) ) + if notification is not None: + try: + await _finish_obsidian_sync_notification( + session, + notification=notification, + indexed=indexed, + failed=failed, + ) + except Exception: + logger.warning( + "obsidian sync notification finish failed connector=%s user=%s", + connector.id, + user.id, + exc_info=True, + ) + return SyncAck( vault_id=payload.vault_id, indexed=indexed, diff --git a/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py b/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py index 0ddb9d713..449e1473d 100644 --- a/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py +++ b/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py @@ -183,7 +183,7 @@ class TestConnectRace: async with AsyncSession(async_engine) as s: s.add( SearchSourceConnector( - name="Obsidian \u2014 First", + name="Obsidian - First", connector_type=SearchSourceConnectorType.OBSIDIAN_CONNECTOR, is_indexable=False, config={ @@ -202,7 +202,7 @@ class TestConnectRace: async with AsyncSession(async_engine) as s: s.add( SearchSourceConnector( - name="Obsidian \u2014 Second", + name="Obsidian - Second", connector_type=SearchSourceConnectorType.OBSIDIAN_CONNECTOR, is_indexable=False, config={ @@ -228,7 +228,7 @@ class TestConnectRace: async with AsyncSession(async_engine) as s: s.add( SearchSourceConnector( - name="Obsidian \u2014 Desktop", + name="Obsidian - Desktop", connector_type=SearchSourceConnectorType.OBSIDIAN_CONNECTOR, is_indexable=False, config={ @@ -247,7 +247,7 @@ class TestConnectRace: async with AsyncSession(async_engine) as s: s.add( SearchSourceConnector( - name="Obsidian \u2014 Mobile", + name="Obsidian - Mobile", connector_type=SearchSourceConnectorType.OBSIDIAN_CONNECTOR, is_indexable=False, config={ diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index c8d63f309..c897489ff 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -180,7 +180,7 @@ export const OTHER_CONNECTORS = [ { id: "obsidian-connector", title: "Obsidian", - description: "Sync your Obsidian vault on desktop or mobile via the SurfSense plugin", + description: "Sync your Obsidian vault on desktop or mobile", connectorType: EnumConnectorName.OBSIDIAN_CONNECTOR, }, ] as const; diff --git a/surfsense_web/content/docs/connectors/obsidian.mdx b/surfsense_web/content/docs/connectors/obsidian.mdx index 1efa4ff8f..5f939e277 100644 --- a/surfsense_web/content/docs/connectors/obsidian.mdx +++ b/surfsense_web/content/docs/connectors/obsidian.mdx @@ -30,7 +30,7 @@ This works for cloud and self-hosted deployments, including desktop and mobile c 4. Paste your SurfSense API token from the user settings section. 5. Paste your Server URL in the plugin setting: either your SurfSense main domain (if `/api/v1` rewrites are enabled) or your direct backend URL. 6. Choose the Search Space in the plugin, then the first sync should run automatically. -7. Confirm the connector appears as **Obsidian — <vault>** in SurfSense. +7. Confirm the connector appears as **Obsidian - <vault>** in SurfSense. ## Migrating from the legacy connector @@ -38,7 +38,7 @@ If you previously used the legacy Obsidian connector architecture, migrate to th 1. Delete the old legacy Obsidian connector from SurfSense. 2. Install and configure the SurfSense Obsidian plugin using the quick start above. -3. Run the first plugin sync and verify the new **Obsidian — <vault>** connector is active. +3. Run the first plugin sync and verify the new **Obsidian - <vault>** connector is active. Deleting the legacy connector also deletes all documents that were indexed by that connector. Always finish and verify plugin sync before deleting the old connector. From 3eb4d55ef51b956216b04733ee6b330d4176777d Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 22 Apr 2026 06:40:39 +0530 Subject: [PATCH 071/299] chore: ran linting --- .../129_obsidian_plugin_vault_identity.py | 3 +- .../app/routes/obsidian_plugin_routes.py | 24 +- .../app/schemas/obsidian_plugin.py | 12 +- .../app/services/obsidian_plugin_indexer.py | 8 +- .../test_obsidian_plugin_routes.py | 12 +- .../tests/unit/test_error_contract.py | 4 +- surfsense_web/app/api/v1/[...path]/route.ts | 15 +- .../components/ApiKeyContent.tsx | 2 +- .../components/DesktopContent.tsx | 7 +- .../components/PurchaseHistoryContent.tsx | 4 +- .../components/obsidian-connect-form.tsx | 55 ++-- .../components/obsidian-config.tsx | 13 +- .../constants/connector-constants.ts | 80 +++--- .../hooks/use-connector-dialog.ts | 12 +- .../components/free-chat/free-chat-page.tsx | 3 +- .../homepage/features-bento-grid.tsx | 262 ++++++++++++++++-- .../components/sources/DocumentUploadTab.tsx | 54 ++-- 17 files changed, 369 insertions(+), 201 deletions(-) diff --git a/surfsense_backend/alembic/versions/129_obsidian_plugin_vault_identity.py b/surfsense_backend/alembic/versions/129_obsidian_plugin_vault_identity.py index e716dfff1..0c0e3dbe5 100644 --- a/surfsense_backend/alembic/versions/129_obsidian_plugin_vault_identity.py +++ b/surfsense_backend/alembic/versions/129_obsidian_plugin_vault_identity.py @@ -91,8 +91,7 @@ def downgrade() -> None: ) conn.execute( sa.text( - "DROP INDEX IF EXISTS " - "search_source_connectors_obsidian_plugin_vault_uniq" + "DROP INDEX IF EXISTS search_source_connectors_obsidian_plugin_vault_uniq" ) ) conn.execute( diff --git a/surfsense_backend/app/routes/obsidian_plugin_routes.py b/surfsense_backend/app/routes/obsidian_plugin_routes.py index 096058d8a..8069f8265 100644 --- a/surfsense_backend/app/routes/obsidian_plugin_routes.py +++ b/surfsense_backend/app/routes/obsidian_plugin_routes.py @@ -129,7 +129,9 @@ async def _finish_obsidian_sync_notification( ): """Mark the rolling Obsidian sync inbox item complete or failed.""" handler = NotificationService.connector_indexing - connector_name = notification.notification_metadata.get("connector_name", "Obsidian") + connector_name = notification.notification_metadata.get( + "connector_name", "Obsidian" + ) if failed > 0 and indexed == 0: title = f"Failed: {connector_name}" message = ( @@ -273,9 +275,7 @@ async def _find_by_fingerprint( return (await session.execute(stmt)).scalars().first() -def _build_config( - payload: ConnectRequest, *, now_iso: str -) -> dict[str, object]: +def _build_config(payload: ConnectRequest, *, now_iso: str) -> dict[str, object]: return { "vault_id": payload.vault_id, "vault_name": payload.vault_name, @@ -456,9 +456,7 @@ async def obsidian_sync( session, connector=connector, payload=note, user_id=str(user.id) ) indexed += 1 - items.append( - SyncAckItem(path=note.path, status="ok", document_id=doc.id) - ) + items.append(SyncAckItem(path=note.path, status="ok", document_id=doc.id)) except HTTPException: raise except Exception as exc: @@ -597,9 +595,7 @@ async def obsidian_delete_notes( path, payload.vault_id, ) - items.append( - DeleteAckItem(path=path, status="error", error=str(exc)[:300]) - ) + items.append(DeleteAckItem(path=path, status="error", error=str(exc)[:300])) return DeleteAck( vault_id=payload.vault_id, @@ -616,9 +612,7 @@ async def obsidian_manifest( session: AsyncSession = Depends(get_async_session), ) -> ManifestResponse: """Return ``{path: {hash, mtime}}`` for the plugin's onload reconcile diff.""" - connector = await _resolve_vault_connector( - session, user=user, vault_id=vault_id - ) + connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id) return await get_manifest(session, connector=connector, vault_id=vault_id) @@ -633,9 +627,7 @@ async def obsidian_stats( ``files_synced`` excludes tombstones so it matches ``/manifest``; ``last_sync_at`` includes them so deletes advance the freshness signal. """ - connector = await _resolve_vault_connector( - session, user=user, vault_id=vault_id - ) + connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id) is_active = Document.document_metadata["deleted_at"].as_string().is_(None) diff --git a/surfsense_backend/app/schemas/obsidian_plugin.py b/surfsense_backend/app/schemas/obsidian_plugin.py index fac44bc3d..745886ef6 100644 --- a/surfsense_backend/app/schemas/obsidian_plugin.py +++ b/surfsense_backend/app/schemas/obsidian_plugin.py @@ -24,10 +24,14 @@ class _PluginBase(BaseModel): class NotePayload(_PluginBase): """One Obsidian note as pushed by the plugin (the source of truth).""" - vault_id: str = Field(..., description="Stable plugin-generated UUID for this vault") + vault_id: str = Field( + ..., description="Stable plugin-generated UUID for this vault" + ) path: str = Field(..., description="Vault-relative path, e.g. 'notes/foo.md'") name: str = Field(..., description="File stem (no extension)") - extension: str = Field(default="md", description="File extension without leading dot") + extension: str = Field( + default="md", description="File extension without leading dot" + ) content: str = Field(default="", description="Raw markdown body (post-frontmatter)") frontmatter: dict[str, Any] = Field(default_factory=dict) @@ -38,7 +42,9 @@ class NotePayload(_PluginBase): embeds: list[str] = Field(default_factory=list) aliases: list[str] = Field(default_factory=list) - content_hash: str = Field(..., description="Plugin-computed SHA-256 of the raw content") + content_hash: str = Field( + ..., description="Plugin-computed SHA-256 of the raw content" + ) size: int | None = Field( default=None, ge=0, diff --git a/surfsense_backend/app/services/obsidian_plugin_indexer.py b/surfsense_backend/app/services/obsidian_plugin_indexer.py index ea62f16d8..5afdbf886 100644 --- a/surfsense_backend/app/services/obsidian_plugin_indexer.py +++ b/surfsense_backend/app/services/obsidian_plugin_indexer.py @@ -126,9 +126,7 @@ def _build_document_string(payload: NotePayload, vault_name: str) -> str: existing search relevance heuristics keep working unchanged. """ tags_line = ", ".join(payload.tags) if payload.tags else "None" - links_line = ( - ", ".join(payload.resolved_links) if payload.resolved_links else "None" - ) + links_line = ", ".join(payload.resolved_links) if payload.resolved_links else "None" return ( "\n" f"Title: {payload.name}\n" @@ -235,9 +233,7 @@ async def upsert_note( if not prepared: if existing is not None: return existing - raise RuntimeError( - f"Indexing pipeline rejected obsidian note {payload.path}" - ) + raise RuntimeError(f"Indexing pipeline rejected obsidian note {payload.path}") document = prepared[0] diff --git a/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py b/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py index 449e1473d..1dd7e2a23 100644 --- a/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py +++ b/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py @@ -111,9 +111,7 @@ async def race_user_and_space(async_engine): # connectors test creates documents, so we wipe them too. The # CASCADE on user_id catches anything we missed. await cleanup.execute( - text( - 'DELETE FROM search_source_connectors WHERE user_id = :uid' - ), + text("DELETE FROM search_source_connectors WHERE user_id = :uid"), {"uid": user_id}, ) await cleanup.execute( @@ -156,9 +154,7 @@ class TestConnectRace: ) await obsidian_connect(payload, user=fresh_user, session=s) - results = await asyncio.gather( - _call("a"), _call("b"), return_exceptions=True - ) + results = await asyncio.gather(_call("a"), _call("b"), return_exceptions=True) for r in results: assert not isinstance(r, Exception), f"Connect raised: {r!r}" @@ -430,9 +426,7 @@ class TestWireContractSmoke: assert {it.status for it in rename_resp.items} == {"ok", "missing"} # snake_case fields are deliberate — the plugin decoder maps them # to camelCase explicitly. - assert all( - it.old_path and it.new_path for it in rename_resp.items - ) + assert all(it.old_path and it.new_path for it in rename_resp.items) # 4. /notes DELETE async def _delete(*args, **kwargs) -> bool: diff --git a/surfsense_backend/tests/unit/test_error_contract.py b/surfsense_backend/tests/unit/test_error_contract.py index 81ec08b2d..ec8021290 100644 --- a/surfsense_backend/tests/unit/test_error_contract.py +++ b/surfsense_backend/tests/unit/test_error_contract.py @@ -202,9 +202,7 @@ class TestHTTPExceptionHandler: # Intentional 503s (e.g. feature flag off) must surface the developer # message so the frontend can render actionable copy. body = _assert_envelope(client.get("/http-503"), 503) - assert ( - body["error"]["message"] == "Page purchases are temporarily unavailable." - ) + assert body["error"]["message"] == "Page purchases are temporarily unavailable." assert body["error"]["message"] != GENERIC_5XX_MESSAGE def test_502_preserves_detail(self, client): diff --git a/surfsense_web/app/api/v1/[...path]/route.ts b/surfsense_web/app/api/v1/[...path]/route.ts index 82c8e2a5d..418bf1a33 100644 --- a/surfsense_web/app/api/v1/[...path]/route.ts +++ b/surfsense_web/app/api/v1/[...path]/route.ts @@ -33,10 +33,7 @@ function toClientHeaders(headers: Headers) { return nextHeaders; } -async function proxy( - request: NextRequest, - context: { params: Promise<{ path?: string[] }> } -) { +async function proxy(request: NextRequest, context: { params: Promise<{ path?: string[] }> }) { const params = await context.params; const path = params.path?.join("/") || ""; const upstreamUrl = new URL(`${getBackendBaseUrl()}/api/v1/${path}`); @@ -62,4 +59,12 @@ async function proxy( }); } -export { proxy as GET, proxy as POST, proxy as PUT, proxy as PATCH, proxy as DELETE, proxy as OPTIONS, proxy as HEAD }; +export { + proxy as GET, + proxy as POST, + proxy as PUT, + proxy as PATCH, + proxy as DELETE, + proxy as OPTIONS, + proxy as HEAD, +}; diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/ApiKeyContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/ApiKeyContent.tsx index 3600d30db..c34d9c0ca 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/ApiKeyContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/ApiKeyContent.tsx @@ -3,7 +3,7 @@ import { Check, Copy, Info } from "lucide-react"; import { useTranslations } from "next-intl"; import { useCallback, useRef, useState } from "react"; -import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { Alert, AlertDescription } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; import { useApiKey } from "@/hooks/use-api-key"; diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx index 3175268d2..63ca9f5df 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx @@ -200,8 +200,8 @@ export function DesktopContent() { Launch on Startup - Automatically start SurfSense when you sign in to your computer so global - shortcuts and folder sync are always available. + Automatically start SurfSense when you sign in to your computer so global shortcuts and + folder sync are always available. @@ -232,8 +232,7 @@ export function DesktopContent() { Start minimized to tray

- Skip the main window on boot — SurfSense lives in the system tray until you need - it. + Skip the main window on boot — SurfSense lives in the system tray until you need it.

new Date(b.created_at).getTime() - new Date(a.created_at).getTime() - ); + ].sort((a, b) => new Date(b.created_at).getTime() - new Date(a.created_at).getTime()); }, [pagesQuery.data, tokensQuery.data]); if (isLoading) { diff --git a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx index 689684c51..ecbb09fae 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx @@ -4,17 +4,16 @@ import { Check, Copy, Info } from "lucide-react"; import { type FC, useCallback, useRef, useState } from "react"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; +import { EnumConnectorName } from "@/contracts/enums/connector"; import { useApiKey } from "@/hooks/use-api-key"; import { copyToClipboard as copyToClipboardUtil } from "@/lib/utils"; -import { EnumConnectorName } from "@/contracts/enums/connector"; import { getConnectorBenefits } from "../connector-benefits"; import type { ConnectFormProps } from "../index"; const PLUGIN_RELEASES_URL = "https://github.com/MODSetter/SurfSense/releases?q=obsidian&expanded=true"; -const BACKEND_URL = - process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL ?? "https://surfsense.com"; +const BACKEND_URL = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL ?? "https://surfsense.com"; /** * Obsidian connect form for the plugin-only architecture. @@ -32,9 +31,7 @@ const BACKEND_URL = export const ObsidianConnectForm: FC = ({ onBack }) => { const { apiKey, isLoading, copied, copyToClipboard } = useApiKey(); const [copiedUrl, setCopiedUrl] = useState(false); - const urlCopyTimerRef = useRef | undefined>( - undefined - ); + const urlCopyTimerRef = useRef | undefined>(undefined); const copyServerUrl = useCallback(async () => { const ok = await copyToClipboardUtil(BACKEND_URL); @@ -59,9 +56,8 @@ export const ObsidianConnectForm: FC = ({ onBack }) => { Plugin-based sync - SurfSense now syncs Obsidian via an official plugin that runs inside - Obsidian itself. Works on desktop and mobile, in cloud and self-hosted - deployments. + SurfSense now syncs Obsidian via an official plugin that runs inside Obsidian itself. + Works on desktop and mobile, in cloud and self-hosted deployments. @@ -76,10 +72,9 @@ export const ObsidianConnectForm: FC = ({ onBack }) => {

Install the plugin

- Grab the latest SurfSense plugin release. Once it's in the community - store, you'll also be able to install it from{" "} - Settings → Community plugins{" "} - inside Obsidian. + Grab the latest SurfSense plugin release. Once it's in the community store, you'll + also be able to install it from{" "} + Settings → Community plugins inside Obsidian.

= ({ onBack }) => { rel="noopener noreferrer" className="inline-flex" > - @@ -104,9 +104,9 @@ export const ObsidianConnectForm: FC = ({ onBack }) => {

Copy your API key

- Paste this into the plugin's API token{" "} - setting. The token expires after 24 hours. Long-lived personal access - tokens are coming in a future release. + Paste this into the plugin's API token setting. + The token expires after 24 hours. Long-lived personal access tokens are coming in a + future release.

{isLoading ? ( @@ -151,9 +151,9 @@ export const ObsidianConnectForm: FC = ({ onBack }) => {

Point the plugin at this server

- For SurfSense Cloud, use the default surfsense.com. - If you are self-hosting, set the plugin's{" "} - Server URL to your frontend domain. + For SurfSense Cloud, use the default{" "} + surfsense.com. If you are self-hosting, set the + plugin's Server URL to your frontend domain.

@@ -168,10 +168,9 @@ export const ObsidianConnectForm: FC = ({ onBack }) => {

Pick this search space

- In the plugin's Search space{" "} - setting, choose the search space you want this vault to sync into. - The connector will appear here automatically once the plugin makes - its first sync. + In the plugin's Search space setting, choose the + search space you want this vault to sync into. The connector will appear here + automatically once the plugin makes its first sync.

@@ -183,11 +182,9 @@ export const ObsidianConnectForm: FC = ({ onBack }) => { What you get with Obsidian integration:
    - {getConnectorBenefits(EnumConnectorName.OBSIDIAN_CONNECTOR)?.map( - (benefit) => ( -
  • {benefit}
  • - ) - )} + {getConnectorBenefits(EnumConnectorName.OBSIDIAN_CONNECTOR)?.map((benefit) => ( +
  • {benefit}
  • + ))}
)} diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx index a9b98b76c..52b18fa09 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx @@ -117,9 +117,7 @@ const PluginStats: FC<{ config: Record }> = ({ config }) => { label: "Files synced", value: placeholder ?? - (typeof stats?.files_synced === "number" - ? stats.files_synced.toLocaleString() - : "—"), + (typeof stats?.files_synced === "number" ? stats.files_synced.toLocaleString() : "—"), }, ]; }, [config.vault_name, stats, statsError]); @@ -139,10 +137,7 @@ const PluginStats: FC<{ config: Record }> = ({ config }) => {

Vault status

{tileRows.map((stat) => ( -
+
{stat.label}
@@ -160,8 +155,8 @@ const UnknownConnectorState: FC = () => ( Unrecognized config - This connector has neither plugin metadata nor a legacy marker. It may predate migration — - you can safely delete it and re-install the SurfSense Obsidian plugin to resume syncing. + This connector has neither plugin metadata nor a legacy marker. It may predate migration — you + can safely delete it and re-install the SurfSense Obsidian plugin to resume syncing. ); diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index c897489ff..154ff247a 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -349,12 +349,7 @@ export const AUTO_INDEX_CONNECTOR_TYPES = new Set(Object.keys(AUTO_INDEX // `lib/posthog/events.ts` or per-connector tracking code. // ============================================================================ -export type ConnectorTelemetryGroup = - | "oauth" - | "composio" - | "crawler" - | "other" - | "unknown"; +export type ConnectorTelemetryGroup = "oauth" | "composio" | "crawler" | "other" | "unknown"; export interface ConnectorTelemetryMeta { connector_type: string; @@ -363,45 +358,44 @@ export interface ConnectorTelemetryMeta { is_oauth: boolean; } -const CONNECTOR_TELEMETRY_REGISTRY: ReadonlyMap = - (() => { - const map = new Map(); +const CONNECTOR_TELEMETRY_REGISTRY: ReadonlyMap = (() => { + const map = new Map(); - for (const c of OAUTH_CONNECTORS) { - map.set(c.connectorType, { - connector_type: c.connectorType, - connector_title: c.title, - connector_group: "oauth", - is_oauth: true, - }); - } - for (const c of COMPOSIO_CONNECTORS) { - map.set(c.connectorType, { - connector_type: c.connectorType, - connector_title: c.title, - connector_group: "composio", - is_oauth: true, - }); - } - for (const c of CRAWLERS) { - map.set(c.connectorType, { - connector_type: c.connectorType, - connector_title: c.title, - connector_group: "crawler", - is_oauth: false, - }); - } - for (const c of OTHER_CONNECTORS) { - map.set(c.connectorType, { - connector_type: c.connectorType, - connector_title: c.title, - connector_group: "other", - is_oauth: false, - }); - } + for (const c of OAUTH_CONNECTORS) { + map.set(c.connectorType, { + connector_type: c.connectorType, + connector_title: c.title, + connector_group: "oauth", + is_oauth: true, + }); + } + for (const c of COMPOSIO_CONNECTORS) { + map.set(c.connectorType, { + connector_type: c.connectorType, + connector_title: c.title, + connector_group: "composio", + is_oauth: true, + }); + } + for (const c of CRAWLERS) { + map.set(c.connectorType, { + connector_type: c.connectorType, + connector_title: c.title, + connector_group: "crawler", + is_oauth: false, + }); + } + for (const c of OTHER_CONNECTORS) { + map.set(c.connectorType, { + connector_type: c.connectorType, + connector_title: c.title, + connector_group: "other", + is_oauth: false, + }); + } - return map; - })(); + return map; +})(); /** * Returns telemetry metadata for a connector_type, or a minimal "unknown" diff --git a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts index e00a69939..317973eba 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts @@ -350,11 +350,7 @@ export const useConnectorDialog = () => { // Set connecting state immediately to disable button and show spinner setConnectingId(connector.id); - trackConnectorSetupStarted( - Number(searchSpaceId), - connector.connectorType, - "oauth_click" - ); + trackConnectorSetupStarted(Number(searchSpaceId), connector.connectorType, "oauth_click"); try { // Check if authEndpoint already has query parameters @@ -478,11 +474,7 @@ export const useConnectorDialog = () => { (connectorType: string) => { if (!searchSpaceId) return; - trackConnectorSetupStarted( - Number(searchSpaceId), - connectorType, - "non_oauth_click" - ); + trackConnectorSetupStarted(Number(searchSpaceId), connectorType, "non_oauth_click"); setConnectingConnectorType(connectorType); }, diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index b389a8489..deac1fd00 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -210,8 +210,7 @@ export function FreeChatPage() { trackAnonymousChatMessageSent({ modelSlug, messageLength: userQuery.trim().length, - hasUploadedDoc: - anonMode.isAnonymous && anonMode.uploadedDoc !== null ? true : false, + hasUploadedDoc: anonMode.isAnonymous && anonMode.uploadedDoc !== null ? true : false, surface: "free_chat_page", }); diff --git a/surfsense_web/components/homepage/features-bento-grid.tsx b/surfsense_web/components/homepage/features-bento-grid.tsx index 835ccd2c2..7406223de 100644 --- a/surfsense_web/components/homepage/features-bento-grid.tsx +++ b/surfsense_web/components/homepage/features-bento-grid.tsx @@ -426,15 +426,50 @@ const AiSortIllustration = () => ( AI File Sorting illustration showing automatic folder organization {/* Scattered documents on the left */} - - - + + + {/* AI sparkle / magic in the center */} - - + + @@ -442,51 +477,208 @@ const AiSortIllustration = () => ( {/* Animated sorting arrows */} - + - + - + - + {/* Organized folder tree on the right */} {/* Root folder */} - - - + + + {/* Subfolder 1 */} - - - - - + + + + + {/* Subfolder 2 */} - - - - - + + + + + {/* Subfolder 3 */} - - - - - + + + + + {/* Sparkle accents */} @@ -495,10 +687,22 @@ const AiSortIllustration = () => ( - + - + diff --git a/surfsense_web/components/sources/DocumentUploadTab.tsx b/surfsense_web/components/sources/DocumentUploadTab.tsx index 65fa117f7..5a324fea9 100644 --- a/surfsense_web/components/sources/DocumentUploadTab.tsx +++ b/surfsense_web/components/sources/DocumentUploadTab.tsx @@ -546,35 +546,35 @@ export function DocumentUploadTab({ ) ) : ( -
{ - if (!isElectron) fileInputRef.current?.click(); - }} - onKeyDown={(e) => { - if (e.key === "Enter" || e.key === " ") { - e.preventDefault(); +
{ if (!isElectron) fileInputRef.current?.click(); - } - }} - > - -
-

- {isElectron ? t("select_files_or_folder") : t("tap_select_files_or_folder")} -

-

{t("file_size_limit")}

-
-
e.stopPropagation()} - onKeyDown={(e) => e.stopPropagation()} + }} + onKeyDown={(e) => { + if (e.key === "Enter" || e.key === " ") { + e.preventDefault(); + if (!isElectron) fileInputRef.current?.click(); + } + }} > - {renderBrowseButton({ fullWidth: true })} -
-
+ +
+

+ {isElectron ? t("select_files_or_folder") : t("tap_select_files_or_folder")} +

+

{t("file_size_limit")}

+
+
e.stopPropagation()} + onKeyDown={(e) => e.stopPropagation()} + > + {renderBrowseButton({ fullWidth: true })} +
+
)}
From 940889c291fd8ca6fa445efa5bf73fdde1949d0c Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 08:42:38 +0200 Subject: [PATCH 072/299] fix open redirect, error leaking, unused imports, state validation --- .../app/agents/new_chat/tools/discord/_auth.py | 4 ---- .../app/agents/new_chat/tools/luma/_auth.py | 4 ---- .../app/agents/new_chat/tools/teams/_auth.py | 6 ------ surfsense_backend/app/routes/__init__.py | 2 +- surfsense_backend/app/routes/mcp_oauth_route.py | 13 ++++++++----- .../app/routes/oauth_connector_base.py | 8 ++++---- 6 files changed, 13 insertions(+), 24 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py b/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py index b369c10f1..1f51e3660 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py @@ -1,7 +1,5 @@ """Shared auth helper for Discord agent tools (REST API, not gateway bot).""" -import logging - from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -9,8 +7,6 @@ from app.config import config from app.db import SearchSourceConnector, SearchSourceConnectorType from app.utils.oauth_security import TokenEncryption -logger = logging.getLogger(__name__) - DISCORD_API = "https://discord.com/api/v10" diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py b/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py index ef2fa8540..1d88161d6 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py @@ -1,14 +1,10 @@ """Shared auth helper for Luma agent tools.""" -import logging - from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from app.db import SearchSourceConnector, SearchSourceConnectorType -logger = logging.getLogger(__name__) - LUMA_API = "https://public-api.luma.com/v1" diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py b/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py index 989fce7c6..f24f5502e 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py @@ -1,15 +1,9 @@ """Shared auth helper for Teams agent tools (Microsoft Graph REST API).""" -import logging - from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.config import config from app.db import SearchSourceConnector, SearchSourceConnectorType -from app.utils.oauth_security import TokenEncryption - -logger = logging.getLogger(__name__) GRAPH_API = "https://graph.microsoft.com/v1.0" diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 925c207a6..40ca7a7e8 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -96,7 +96,7 @@ router.include_router(logs_router) router.include_router(circleback_webhook_router) # Circleback meeting webhooks router.include_router(surfsense_docs_router) # Surfsense documentation for citations router.include_router(notifications_router) # Notifications with Zero sync -router.include_router(mcp_oauth_router) # MCP OAuth 2.1 for Linear, Jira, ClickUp +router.include_router(mcp_oauth_router) # MCP OAuth 2.1 for Linear, Jira, ClickUp, Slack, Airtable router.include_router(composio_router) # Composio OAuth and toolkit management router.include_router(public_chat_router) # Public chat sharing and cloning router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py index e47dc0a62..0870d52fe 100644 --- a/surfsense_backend/app/routes/mcp_oauth_route.py +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -182,7 +182,7 @@ async def connect_mcp_service( except Exception as e: logger.error("Failed to initiate %s MCP OAuth: %s", service, e, exc_info=True) raise HTTPException( - status_code=500, detail=f"Failed to initiate {service} MCP OAuth: {e!s}", + status_code=500, detail=f"Failed to initiate {service} MCP OAuth.", ) from e @@ -221,6 +221,9 @@ async def mcp_oauth_callback( space_id = data["space_id"] svc_key = data.get("service", service) + if svc_key != service: + raise HTTPException(status_code=400, detail="State/path service mismatch") + from app.services.mcp_oauth.registry import get_service svc = get_service(svc_key) @@ -315,7 +318,7 @@ async def mcp_oauth_callback( svc.name, db_connector.id, user_id, ) reauth_return_url = data.get("return_url") - if reauth_return_url and reauth_return_url.startswith("/"): + if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"): return RedirectResponse( url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}" ) @@ -347,7 +350,7 @@ async def mcp_oauth_callback( except IntegrityError as e: await session.rollback() raise HTTPException( - status_code=409, detail=f"Database integrity error: {e!s}", + status_code=409, detail="A connector for this service already exists.", ) from e _invalidate_cache(space_id) @@ -368,7 +371,7 @@ async def mcp_oauth_callback( ) raise HTTPException( status_code=500, - detail=f"Failed to complete {service} MCP OAuth: {e!s}", + detail=f"Failed to complete {service} MCP OAuth.", ) from e @@ -495,7 +498,7 @@ async def reauth_mcp_service( ) raise HTTPException( status_code=500, - detail=f"Failed to initiate {service} MCP re-auth: {e!s}", + detail=f"Failed to initiate {service} MCP re-auth.", ) from e diff --git a/surfsense_backend/app/routes/oauth_connector_base.py b/surfsense_backend/app/routes/oauth_connector_base.py index 0483d2540..0638e8f34 100644 --- a/surfsense_backend/app/routes/oauth_connector_base.py +++ b/surfsense_backend/app/routes/oauth_connector_base.py @@ -430,7 +430,7 @@ class OAuthConnectorRoute: state_mgr = oauth._get_state_manager() extra: dict[str, Any] = {"connector_id": connector_id} - if return_url and return_url.startswith("/"): + if return_url and return_url.startswith("/") and not return_url.startswith("//"): extra["return_url"] = return_url auth_params: dict[str, str] = { @@ -498,7 +498,7 @@ class OAuthConnectorRoute: data = state_mgr.validate_state(state) except Exception as e: raise HTTPException( - status_code=400, detail=f"Invalid state parameter: {e!s}" + status_code=400, detail="Invalid or expired state parameter." ) from e user_id = UUID(data["user_id"]) @@ -552,7 +552,7 @@ class OAuthConnectorRoute: db_connector.id, user_id, ) - if reauth_return_url and reauth_return_url.startswith("/"): + if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"): return RedirectResponse( url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}" ) @@ -603,7 +603,7 @@ class OAuthConnectorRoute: except IntegrityError as e: await session.rollback() raise HTTPException( - status_code=409, detail=f"Database integrity error: {e!s}" + status_code=409, detail="A connector for this service already exists." ) from e logger.info( From ea3508cb25db5369dc01c5443fc318830089673f Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 08:57:28 +0200 Subject: [PATCH 073/299] use native connector types for MCP OAuth, restore original UI --- .../app/agents/new_chat/tools/mcp_tool.py | 5 +- .../app/routes/mcp_oauth_route.py | 42 ++++++++--------- .../app/services/mcp_oauth/registry.py | 6 +++ .../constants/connector-constants.ts | 47 ++----------------- .../tabs/all-connectors-tab.tsx | 23 +-------- 5 files changed, 34 insertions(+), 89 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index cf3e51166..47ee16f7d 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -530,11 +530,12 @@ async def load_mcp_tools( return list(cached_tools) try: + # Find all connectors with MCP server config: generic MCP_CONNECTOR type + # and service-specific types (LINEAR_CONNECTOR, etc.) created via MCP OAuth. result = await session.execute( select(SearchSourceConnector).filter( - SearchSourceConnector.connector_type - == SearchSourceConnectorType.MCP_CONNECTOR, SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.config.has_key("server_config"), # noqa: W601 ), ) diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py index 0870d52fe..f7164eab3 100644 --- a/surfsense_backend/app/routes/mcp_oauth_route.py +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -56,9 +56,7 @@ def _get_token_encryption() -> TokenEncryption: def _build_redirect_uri(service: str) -> str: - base = config.BACKEND_URL - if not base: - raise HTTPException(status_code=500, detail="BACKEND_URL not configured.") + base = config.BACKEND_URL or "http://localhost:8000" return f"{base.rstrip('/')}/api/v1/auth/mcp/{service}/connector/callback" @@ -288,6 +286,7 @@ async def mcp_oauth_callback( } # ---- Re-auth path ---- + db_connector_type = SearchSourceConnectorType(svc.connector_type) reauth_connector_id = data.get("connector_id") if reauth_connector_id: result = await session.execute( @@ -295,8 +294,7 @@ async def mcp_oauth_callback( SearchSourceConnector.id == reauth_connector_id, SearchSourceConnector.user_id == user_id, SearchSourceConnector.search_space_id == space_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.MCP_CONNECTOR, + SearchSourceConnector.connector_type == db_connector_type, ) ) db_connector = result.scalars().first() @@ -329,15 +327,15 @@ async def mcp_oauth_callback( # ---- New connector path ---- connector_name = await generate_unique_connector_name( session, - SearchSourceConnectorType.MCP_CONNECTOR, + db_connector_type, space_id, user_id, - f"{svc.name} MCP", + svc.name, ) new_connector = SearchSourceConnector( name=connector_name, - connector_type=SearchSourceConnectorType.MCP_CONNECTOR, + connector_type=db_connector_type, is_indexable=False, config=connector_config, search_space_id=space_id, @@ -388,26 +386,26 @@ async def reauth_mcp_service( user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session), ): - result = await session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == connector_id, - SearchSourceConnector.user_id == user.id, - SearchSourceConnector.search_space_id == space_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.MCP_CONNECTOR, - ) - ) - if not result.scalars().first(): - raise HTTPException( - status_code=404, detail="MCP connector not found or access denied", - ) - from app.services.mcp_oauth.registry import get_service svc = get_service(service) if not svc: raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}") + db_connector_type = SearchSourceConnectorType(svc.connector_type) + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.user_id == user.id, + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.connector_type == db_connector_type, + ) + ) + if not result.scalars().first(): + raise HTTPException( + status_code=404, detail="Connector not found or access denied", + ) + try: from app.services.mcp_oauth.discovery import ( discover_oauth_metadata, diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py index 3f9a03fbc..e6a9d20a5 100644 --- a/surfsense_backend/app/services/mcp_oauth/registry.py +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -15,6 +15,7 @@ from dataclasses import dataclass, field class MCPServiceConfig: name: str mcp_url: str + connector_type: str supports_dcr: bool = True oauth_discovery_origin: str | None = None client_id_env: str | None = None @@ -26,18 +27,22 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { "linear": MCPServiceConfig( name="Linear", mcp_url="https://mcp.linear.app/mcp", + connector_type="LINEAR_CONNECTOR", ), "jira": MCPServiceConfig( name="Jira", mcp_url="https://mcp.atlassian.com/v1/mcp", + connector_type="JIRA_CONNECTOR", ), "clickup": MCPServiceConfig( name="ClickUp", mcp_url="https://mcp.clickup.com/mcp", + connector_type="CLICKUP_CONNECTOR", ), "slack": MCPServiceConfig( name="Slack", mcp_url="https://mcp.slack.com/mcp", + connector_type="SLACK_CONNECTOR", supports_dcr=False, client_id_env="SLACK_CLIENT_ID", client_secret_env="SLACK_CLIENT_SECRET", @@ -45,6 +50,7 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { "airtable": MCPServiceConfig( name="Airtable", mcp_url="https://mcp.airtable.com/mcp", + connector_type="AIRTABLE_CONNECTOR", oauth_discovery_origin="https://airtable.com", ), } diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index 39e827d1a..08ffde9ae 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -31,7 +31,7 @@ export const OAUTH_CONNECTORS = [ title: "Airtable", description: "Search your Airtable bases", connectorType: EnumConnectorName.AIRTABLE_CONNECTOR, - authEndpoint: "/api/v1/auth/airtable/connector/add/", + authEndpoint: "/api/v1/auth/mcp/airtable/connector/add/", }, { id: "notion-connector", @@ -45,14 +45,14 @@ export const OAUTH_CONNECTORS = [ title: "Linear", description: "Search issues & projects", connectorType: EnumConnectorName.LINEAR_CONNECTOR, - authEndpoint: "/api/v1/auth/linear/connector/add/", + authEndpoint: "/api/v1/auth/mcp/linear/connector/add/", }, { id: "slack-connector", title: "Slack", description: "Search Slack messages", connectorType: EnumConnectorName.SLACK_CONNECTOR, - authEndpoint: "/api/v1/auth/slack/connector/add/", + authEndpoint: "/api/v1/auth/mcp/slack/connector/add/", }, { id: "teams-connector", @@ -87,7 +87,7 @@ export const OAUTH_CONNECTORS = [ title: "Jira", description: "Search Jira issues", connectorType: EnumConnectorName.JIRA_CONNECTOR, - authEndpoint: "/api/v1/auth/jira/connector/add/", + authEndpoint: "/api/v1/auth/mcp/jira/connector/add/", }, { id: "confluence-connector", @@ -101,47 +101,8 @@ export const OAUTH_CONNECTORS = [ title: "ClickUp", description: "Search ClickUp tasks", connectorType: EnumConnectorName.CLICKUP_CONNECTOR, - authEndpoint: "/api/v1/auth/clickup/connector/add/", - }, -] as const; - -// MCP OAuth Connectors (one-click connect via official MCP servers) -export const MCP_OAUTH_CONNECTORS = [ - { - id: "linear-mcp-connector", - title: "Linear (MCP)", - description: "Interact with Linear issues via MCP", - connectorType: EnumConnectorName.MCP_CONNECTOR, - authEndpoint: "/api/v1/auth/mcp/linear/connector/add/", - }, - { - id: "jira-mcp-connector", - title: "Jira (MCP)", - description: "Interact with Jira issues via MCP", - connectorType: EnumConnectorName.MCP_CONNECTOR, - authEndpoint: "/api/v1/auth/mcp/jira/connector/add/", - }, - { - id: "clickup-mcp-connector", - title: "ClickUp (MCP)", - description: "Interact with ClickUp tasks via MCP", - connectorType: EnumConnectorName.MCP_CONNECTOR, authEndpoint: "/api/v1/auth/mcp/clickup/connector/add/", }, - { - id: "slack-mcp-connector", - title: "Slack (MCP)", - description: "Interact with Slack channels via MCP", - connectorType: EnumConnectorName.MCP_CONNECTOR, - authEndpoint: "/api/v1/auth/mcp/slack/connector/add/", - }, - { - id: "airtable-mcp-connector", - title: "Airtable (MCP)", - description: "Interact with Airtable bases via MCP", - connectorType: EnumConnectorName.MCP_CONNECTOR, - authEndpoint: "/api/v1/auth/mcp/airtable/connector/add/", - }, ] as const; // Content Sources (tools that extract and import content from external sources) diff --git a/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx b/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx index d4f5e2fc1..814959ec4 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx @@ -10,14 +10,12 @@ import { ConnectorCard } from "../components/connector-card"; import { COMPOSIO_CONNECTORS, CRAWLERS, - MCP_OAUTH_CONNECTORS, OAUTH_CONNECTORS, OTHER_CONNECTORS, } from "../constants/connector-constants"; import { getDocumentCountForConnector } from "../utils/connector-document-mapping"; type OAuthConnector = (typeof OAUTH_CONNECTORS)[number]; -type MCPOAuthConnector = (typeof MCP_OAUTH_CONNECTORS)[number]; type ComposioConnector = (typeof COMPOSIO_CONNECTORS)[number]; type OtherConnector = (typeof OTHER_CONNECTORS)[number]; type CrawlerConnector = (typeof CRAWLERS)[number]; @@ -130,10 +128,6 @@ export const AllConnectorsTab: FC = ({ (c) => c.connectorType === EnumConnectorName.AIRTABLE_CONNECTOR ); - const filteredMCPOAuth = MCP_OAUTH_CONNECTORS.filter( - (c) => matchesSearch(c.title, c.description), - ); - const moreIntegrationsComposio = filteredComposio.filter( (c) => !DOCUMENT_FILE_CONNECTOR_TYPES.has(c.connectorType) && @@ -285,7 +279,6 @@ export const AllConnectorsTab: FC = ({ nativeGoogleDriveConnectors.length > 0 || composioGoogleDriveConnectors.length > 0 || fileStorageConnectors.length > 0; - const hasMCPOAuth = filteredMCPOAuth.length > 0; const hasMoreIntegrations = otherDocumentYouTubeConnectors.length > 0 || otherDocumentNotionConnectors.length > 0 || @@ -295,7 +288,7 @@ export const AllConnectorsTab: FC = ({ moreIntegrationsOther.length > 0 || moreIntegrationsCrawlers.length > 0; - const hasAnyResults = hasDocumentFileConnectors || hasMCPOAuth || hasMoreIntegrations; + const hasAnyResults = hasDocumentFileConnectors || hasMoreIntegrations; if (!hasAnyResults && searchQuery) { return ( @@ -325,20 +318,6 @@ export const AllConnectorsTab: FC = ({
)} - {/* Live MCP Integrations */} - {hasMCPOAuth && ( -
-
-

- Live MCP Integrations -

-
-
- {filteredMCPOAuth.map((connector) => renderOAuthCard(connector as OAuthConnector | ComposioConnector))} -
-
- )} - {/* More Integrations */} {hasMoreIntegrations && (
From c277b6c1219bd4794d7f89da72034e0161e2326e Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 09:01:35 +0200 Subject: [PATCH 074/299] skip indexing config dialog for non-indexable connectors --- .../assistant-ui/connector-popup/hooks/use-connector-dialog.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts index caa85ba2d..4a07693ce 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts @@ -314,6 +314,9 @@ export const useConnectorDialog = () => { oauthConnector.title, oauthConnector.connectorType ); + } else if (!newConnector.is_indexable) { + toast.success(`${oauthConnector.title} connected successfully!`); + await refetchAllConnectors(); } else { toast.dismiss("auto-index"); const config = validateIndexingConfigState({ From 2f4052aa71cfea2ea1d77ba9815eca4634b491ca Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 09:15:06 +0200 Subject: [PATCH 075/299] use pre-configured credentials for Airtable MCP OAuth --- surfsense_backend/app/services/mcp_oauth/registry.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py index e6a9d20a5..769f2c88a 100644 --- a/surfsense_backend/app/services/mcp_oauth/registry.py +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -52,6 +52,9 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { mcp_url="https://mcp.airtable.com/mcp", connector_type="AIRTABLE_CONNECTOR", oauth_discovery_origin="https://airtable.com", + supports_dcr=False, + client_id_env="AIRTABLE_CLIENT_ID", + client_secret_env="AIRTABLE_CLIENT_SECRET", ), } From 0cc2475f6b766f990ff49cca1903c3305c035543 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 09:21:19 +0200 Subject: [PATCH 076/299] add required OAuth scopes for Airtable MCP --- surfsense_backend/app/services/mcp_oauth/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py index 769f2c88a..173fcf49d 100644 --- a/surfsense_backend/app/services/mcp_oauth/registry.py +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -55,6 +55,7 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { supports_dcr=False, client_id_env="AIRTABLE_CLIENT_ID", client_secret_env="AIRTABLE_CLIENT_SECRET", + scopes=["data.records:read", "data.records:write", "schema.bases:read", "schema.bases:write"], ), } From 225236e6f1d4a5de2a11280321cb213d4d22471b Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 09:35:15 +0200 Subject: [PATCH 077/299] add required OAuth scopes for Slack MCP --- surfsense_backend/app/services/mcp_oauth/registry.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py index 173fcf49d..ea7832f70 100644 --- a/surfsense_backend/app/services/mcp_oauth/registry.py +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -46,6 +46,14 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { supports_dcr=False, client_id_env="SLACK_CLIENT_ID", client_secret_env="SLACK_CLIENT_SECRET", + scopes=[ + "search:read.public", "search:read.private", "search:read.mpim", + "search:read.im", "search:read.files", "search:read.users", + "chat:write", "channels:history", "groups:history", + "mpim:history", "im:history", + "canvases:read", "canvases:write", + "users:read", "users:read.email", + ], ), "airtable": MCPServiceConfig( name="Airtable", From 3638d72b298e2cebab7cce4d46f80b7bce787d08 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 09:41:19 +0200 Subject: [PATCH 078/299] restore full Slack MCP scopes for all MCP tools --- surfsense_backend/app/services/mcp_oauth/registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py index ea7832f70..4d87ceb40 100644 --- a/surfsense_backend/app/services/mcp_oauth/registry.py +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -49,8 +49,8 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { scopes=[ "search:read.public", "search:read.private", "search:read.mpim", "search:read.im", "search:read.files", "search:read.users", - "chat:write", "channels:history", "groups:history", - "mpim:history", "im:history", + "chat:write", + "channels:history", "groups:history", "mpim:history", "im:history", "canvases:read", "canvases:write", "users:read", "users:read.email", ], From 820326e3ee53386cc5c6605e00d4602cb57c7b16 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 09:54:16 +0200 Subject: [PATCH 079/299] use user_scope param for Slack OAuth --- surfsense_backend/app/routes/mcp_oauth_route.py | 4 ++-- surfsense_backend/app/services/mcp_oauth/registry.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py index f7164eab3..98ca2be0f 100644 --- a/surfsense_backend/app/routes/mcp_oauth_route.py +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -165,7 +165,7 @@ async def connect_mcp_service( "state": state, } if svc.scopes: - auth_params["scope"] = " ".join(svc.scopes) + auth_params[svc.scope_param] = " ".join(svc.scopes) auth_url = f"{auth_endpoint}?{urlencode(auth_params)}" @@ -478,7 +478,7 @@ async def reauth_mcp_service( "state": state, } if svc.scopes: - auth_params["scope"] = " ".join(svc.scopes) + auth_params[svc.scope_param] = " ".join(svc.scopes) auth_url = f"{auth_endpoint}?{urlencode(auth_params)}" diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py index 4d87ceb40..62eb2077f 100644 --- a/surfsense_backend/app/services/mcp_oauth/registry.py +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -21,6 +21,7 @@ class MCPServiceConfig: client_id_env: str | None = None client_secret_env: str | None = None scopes: list[str] = field(default_factory=list) + scope_param: str = "scope" MCP_SERVICES: dict[str, MCPServiceConfig] = { @@ -46,6 +47,7 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { supports_dcr=False, client_id_env="SLACK_CLIENT_ID", client_secret_env="SLACK_CLIENT_SECRET", + scope_param="user_scope", scopes=[ "search:read.public", "search:read.private", "search:read.mpim", "search:read.im", "search:read.files", "search:read.users", From 970f62278b3677541526e6eeba78bf27bb15cbe0 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 09:56:18 +0200 Subject: [PATCH 080/299] revert scope_param, use standard scope for Slack v2_user endpoint --- surfsense_backend/app/routes/mcp_oauth_route.py | 4 ++-- surfsense_backend/app/services/mcp_oauth/registry.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py index 98ca2be0f..f7164eab3 100644 --- a/surfsense_backend/app/routes/mcp_oauth_route.py +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -165,7 +165,7 @@ async def connect_mcp_service( "state": state, } if svc.scopes: - auth_params[svc.scope_param] = " ".join(svc.scopes) + auth_params["scope"] = " ".join(svc.scopes) auth_url = f"{auth_endpoint}?{urlencode(auth_params)}" @@ -478,7 +478,7 @@ async def reauth_mcp_service( "state": state, } if svc.scopes: - auth_params[svc.scope_param] = " ".join(svc.scopes) + auth_params["scope"] = " ".join(svc.scopes) auth_url = f"{auth_endpoint}?{urlencode(auth_params)}" diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py index 62eb2077f..4d87ceb40 100644 --- a/surfsense_backend/app/services/mcp_oauth/registry.py +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -21,7 +21,6 @@ class MCPServiceConfig: client_id_env: str | None = None client_secret_env: str | None = None scopes: list[str] = field(default_factory=list) - scope_param: str = "scope" MCP_SERVICES: dict[str, MCPServiceConfig] = { @@ -47,7 +46,6 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { supports_dcr=False, client_id_env="SLACK_CLIENT_ID", client_secret_env="SLACK_CLIENT_SECRET", - scope_param="user_scope", scopes=[ "search:read.public", "search:read.private", "search:read.mpim", "search:read.im", "search:read.files", "search:read.users", From dde1948a5c8782d96e9a478518940439f1114373 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 10:06:12 +0200 Subject: [PATCH 081/299] fix Slack MCP OAuth: v2 endpoint, user_scope param, nested token extraction --- .../app/routes/mcp_oauth_route.py | 30 ++++++++++++------- .../app/services/mcp_oauth/registry.py | 6 ++++ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py index f7164eab3..efe928fd1 100644 --- a/surfsense_backend/app/routes/mcp_oauth_route.py +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -107,8 +107,8 @@ async def connect_mcp_service( metadata = await discover_oauth_metadata( svc.mcp_url, origin_override=svc.oauth_discovery_origin, ) - auth_endpoint = metadata.get("authorization_endpoint") - token_endpoint = metadata.get("token_endpoint") + auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint") + token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint") registration_endpoint = metadata.get("registration_endpoint") if not auth_endpoint or not token_endpoint: @@ -165,7 +165,7 @@ async def connect_mcp_service( "state": state, } if svc.scopes: - auth_params["scope"] = " ".join(svc.scopes) + auth_params[svc.scope_param] = " ".join(svc.scopes) auth_url = f"{auth_endpoint}?{urlencode(auth_params)}" @@ -253,17 +253,27 @@ async def mcp_oauth_callback( ) access_token = token_json.get("access_token") + refresh_token = token_json.get("refresh_token") + expires_in = token_json.get("expires_in") + scope = token_json.get("scope") + + if not access_token and "authed_user" in token_json: + authed = token_json["authed_user"] + access_token = authed.get("access_token") + refresh_token = refresh_token or authed.get("refresh_token") + scope = scope or authed.get("scope") + expires_in = expires_in or authed.get("expires_in") + if not access_token: raise HTTPException( status_code=400, detail=f"No access token received from {svc.name}.", ) - refresh_token = token_json.get("refresh_token") expires_at = None - if token_json.get("expires_in"): + if expires_in: expires_at = datetime.now(UTC) + timedelta( - seconds=int(token_json["expires_in"]) + seconds=int(expires_in) ) connector_config = { @@ -280,7 +290,7 @@ async def mcp_oauth_callback( "access_token": enc.encrypt_token(access_token), "refresh_token": enc.encrypt_token(refresh_token) if refresh_token else None, "expires_at": expires_at.isoformat() if expires_at else None, - "scope": token_json.get("scope"), + "scope": scope, }, "_token_encrypted": True, } @@ -415,8 +425,8 @@ async def reauth_mcp_service( metadata = await discover_oauth_metadata( svc.mcp_url, origin_override=svc.oauth_discovery_origin, ) - auth_endpoint = metadata.get("authorization_endpoint") - token_endpoint = metadata.get("token_endpoint") + auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint") + token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint") registration_endpoint = metadata.get("registration_endpoint") if not auth_endpoint or not token_endpoint: @@ -478,7 +488,7 @@ async def reauth_mcp_service( "state": state, } if svc.scopes: - auth_params["scope"] = " ".join(svc.scopes) + auth_params[svc.scope_param] = " ".join(svc.scopes) auth_url = f"{auth_endpoint}?{urlencode(auth_params)}" diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py index 4d87ceb40..df6c6bb18 100644 --- a/surfsense_backend/app/services/mcp_oauth/registry.py +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -21,6 +21,9 @@ class MCPServiceConfig: client_id_env: str | None = None client_secret_env: str | None = None scopes: list[str] = field(default_factory=list) + scope_param: str = "scope" + auth_endpoint_override: str | None = None + token_endpoint_override: str | None = None MCP_SERVICES: dict[str, MCPServiceConfig] = { @@ -46,6 +49,9 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { supports_dcr=False, client_id_env="SLACK_CLIENT_ID", client_secret_env="SLACK_CLIENT_SECRET", + scope_param="user_scope", + auth_endpoint_override="https://slack.com/oauth/v2/authorize", + token_endpoint_override="https://slack.com/api/oauth.v2.access", scopes=[ "search:read.public", "search:read.private", "search:read.mpim", "search:read.im", "search:read.files", "search:read.users", From dfa40b88018e09f1e4f743d1cedd8e1bb4744441 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 10:50:43 +0200 Subject: [PATCH 082/299] fix MCP OAuth for all 5 services, add MCP connector edit view --- .../app/routes/mcp_oauth_route.py | 4 +-- .../app/services/mcp_oauth/registry.py | 10 +++---- .../components/mcp-service-config.tsx | 30 +++++++++++++++++++ .../views/connector-edit-view.tsx | 25 ++++++++++------ 4 files changed, 53 insertions(+), 16 deletions(-) create mode 100644 surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py index efe928fd1..b7c605089 100644 --- a/surfsense_backend/app/routes/mcp_oauth_route.py +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -128,7 +128,7 @@ async def connect_mcp_service( status_code=502, detail=f"DCR for {svc.name} did not return a client_id.", ) - elif not svc.supports_dcr and svc.client_id_env: + elif svc.client_id_env: client_id = getattr(config, svc.client_id_env, None) client_secret = getattr(config, svc.client_secret_env or "", None) or "" if not client_id: @@ -446,7 +446,7 @@ async def reauth_mcp_service( status_code=502, detail=f"DCR for {svc.name} did not return a client_id.", ) - elif not svc.supports_dcr and svc.client_id_env: + elif svc.client_id_env: client_id = getattr(config, svc.client_id_env, None) client_secret = getattr(config, svc.client_secret_env or "", None) or "" if not client_id: diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py index df6c6bb18..cd1a0ae8c 100644 --- a/surfsense_backend/app/services/mcp_oauth/registry.py +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -1,9 +1,9 @@ -"""Registry of MCP services with OAuth 2.1 support. +"""Registry of MCP services with OAuth support. Each entry maps a URL-safe service key to its MCP server endpoint and -authentication strategy. Services with ``supports_dcr=True`` will use -RFC 7591 Dynamic Client Registration; the rest require pre-configured -credentials via environment variables. +authentication configuration. Services with ``supports_dcr=True`` use +RFC 7591 Dynamic Client Registration (the MCP server issues its own +credentials); the rest use pre-configured credentials via env vars. """ from __future__ import annotations @@ -65,8 +65,8 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { name="Airtable", mcp_url="https://mcp.airtable.com/mcp", connector_type="AIRTABLE_CONNECTOR", - oauth_discovery_origin="https://airtable.com", supports_dcr=False, + oauth_discovery_origin="https://airtable.com", client_id_env="AIRTABLE_CLIENT_ID", client_secret_env="AIRTABLE_CLIENT_SECRET", scopes=["data.records:read", "data.records:write", "schema.bases:read", "schema.bases:write"], diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx new file mode 100644 index 000000000..4f43694ad --- /dev/null +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx @@ -0,0 +1,30 @@ +"use client"; + +import { CheckCircle2 } from "lucide-react"; +import type { FC } from "react"; +import type { ConnectorConfigProps } from "../index"; + +export const MCPServiceConfig: FC = ({ connector }) => { + const serviceName = connector.config?.mcp_service as string | undefined; + + return ( +
+
+
+ +
+
+

Connected via MCP

+

+ Your agent can search, read, and take actions in{" "} + {serviceName + ? serviceName.charAt(0).toUpperCase() + serviceName.slice(1) + : "this service"}{" "} + in real time. No background indexing needed. +

+
+
+ +
+ ); +}; diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx index e19600ab2..3c92320da 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx @@ -17,7 +17,7 @@ import { PeriodicSyncConfig } from "../../components/periodic-sync-config"; import { SummaryConfig } from "../../components/summary-config"; import { VisionLLMConfig } from "../../components/vision-llm-config"; import { getConnectorDisplayName } from "../../tabs/all-connectors-tab"; -import { getConnectorConfigComponent } from "../index"; +import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index"; const REAUTH_ENDPOINTS: Partial> = { [EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth", @@ -118,11 +118,16 @@ export const ConnectorEditView: FC = ({ } }, [searchSpaceId, searchSpaceIdAtom, reauthEndpoint, connector.id]); - // Get connector-specific config component - const ConnectorConfigComponent = useMemo( - () => getConnectorConfigComponent(connector.connector_type), - [connector.connector_type] - ); + const isMCPBacked = Boolean(connector.config?.server_config); + + // Get connector-specific config component (MCP-backed connectors use a generic view) + const ConnectorConfigComponent = useMemo(() => { + if (isMCPBacked) { + const { MCPServiceConfig } = require("../components/mcp-service-config"); + return MCPServiceConfig as FC; + } + return getConnectorConfigComponent(connector.connector_type); + }, [connector.connector_type, isMCPBacked]); const [isScrolled, setIsScrolled] = useState(false); const [hasMoreContent, setHasMoreContent] = useState(false); const [showDisconnectConfirm, setShowDisconnectConfirm] = useState(false); @@ -223,7 +228,9 @@ export const ConnectorEditView: FC = ({ {getConnectorDisplayName(connector.name)}

- Manage your connector settings and sync configuration + {isMCPBacked + ? "Connected — your agent can interact with this service in real time" + : "Manage your connector settings and sync configuration"}

@@ -421,7 +428,7 @@ export const ConnectorEditView: FC = ({ Re-authenticate - ) : ( + ) : !isMCPBacked ? ( - )} + ) : null} ); From a4bc621c2acae3a1305da77c3ff8046d7ab40c68 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 11:22:04 +0200 Subject: [PATCH 083/299] uniform connector UX across all connector types --- .../components/connector-card.tsx | 16 ++++---- .../components/discord-config.tsx | 17 ++++----- .../components/mcp-service-config.tsx | 14 +++---- .../components/teams-config.tsx | 6 +-- .../views/connector-edit-view.tsx | 14 ++++--- .../constants/connector-constants.ts | 37 ++++++++++++++----- .../tabs/active-connectors-tab.tsx | 14 +++++-- .../views/connector-accounts-list-view.tsx | 25 +++++-------- 8 files changed, 82 insertions(+), 61 deletions(-) diff --git a/surfsense_web/components/assistant-ui/connector-popup/components/connector-card.tsx b/surfsense_web/components/assistant-ui/connector-popup/components/connector-card.tsx index d24057b1c..e0df73e66 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/components/connector-card.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/components/connector-card.tsx @@ -8,6 +8,7 @@ import { Spinner } from "@/components/ui/spinner"; import { EnumConnectorName } from "@/contracts/enums/connector"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import { cn } from "@/lib/utils"; +import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants"; import { useConnectorStatus } from "../hooks/use-connector-status"; import { ConnectorStatusBadge } from "./connector-status-badge"; @@ -55,6 +56,7 @@ export const ConnectorCard: FC = ({ onManage, }) => { const isMCP = connectorType === EnumConnectorName.MCP_CONNECTOR; + const isLive = !!connectorType && LIVE_CONNECTOR_TYPES.has(connectorType); // Get connector status const { getConnectorStatus, isConnectorEnabled, getConnectorStatusMessage, shouldShowWarnings } = useConnectorStatus(); @@ -123,14 +125,14 @@ export const ConnectorCard: FC = ({ ) : ( <> - {formatDocumentCount(documentCount)} + {!isLive && {formatDocumentCount(documentCount)}} + {!isLive && accountCount !== undefined && accountCount > 0 && ( + + )} {accountCount !== undefined && accountCount > 0 && ( - <> - - - {accountCount} {accountCount === 1 ? "Account" : "Accounts"} - - + + {accountCount} {accountCount === 1 ? "Account" : "Accounts"} + )} )} diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/discord-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/discord-config.tsx index f782a6f4d..c8714ba40 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/discord-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/discord-config.tsx @@ -53,8 +53,7 @@ export const DiscordConfig: FC = ({ connector }) => { return () => document.removeEventListener("visibilitychange", handleVisibilityChange); }, [connector?.id, fetchChannels]); - // Separate channels by indexing capability - const readyToIndex = channels.filter((ch) => ch.can_index); + const accessible = channels.filter((ch) => ch.can_index); const needsPermissions = channels.filter((ch) => !ch.can_index); // Format last fetched time @@ -80,7 +79,7 @@ export const DiscordConfig: FC = ({ connector }) => {

- The bot needs "Read Message History" permission to index channels. Ask a + The bot needs "Read Message History" permission to access channels. Ask a server admin to grant this permission for channels shown below.

@@ -127,18 +126,18 @@ export const DiscordConfig: FC = ({ connector }) => { ) : (
- {/* Ready to index */} - {readyToIndex.length > 0 && ( + {/* Accessible channels */} + {accessible.length > 0 && (
0 && "border-b border-border")}>
- Ready to index + Accessible - {readyToIndex.length} {readyToIndex.length === 1 ? "channel" : "channels"} + {accessible.length} {accessible.length === 1 ? "channel" : "channels"}
- {readyToIndex.map((channel) => ( + {accessible.map((channel) => ( ))}
@@ -150,7 +149,7 @@ export const DiscordConfig: FC = ({ connector }) => {
- Grant permissions to index + Needs permissions {needsPermissions.length}{" "} {needsPermissions.length === 1 ? "channel" : "channels"} diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx index 4f43694ad..71d0e31a8 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx @@ -6,25 +6,23 @@ import type { ConnectorConfigProps } from "../index"; export const MCPServiceConfig: FC = ({ connector }) => { const serviceName = connector.config?.mcp_service as string | undefined; + const displayName = serviceName + ? serviceName.charAt(0).toUpperCase() + serviceName.slice(1) + : "this service"; return (
-
+
-

Connected via MCP

+

Connected

- Your agent can search, read, and take actions in{" "} - {serviceName - ? serviceName.charAt(0).toUpperCase() + serviceName.slice(1) - : "this service"}{" "} - in real time. No background indexing needed. + Your agent can search, read, and take actions in {displayName}.

-
); }; diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx index ac08a6c03..e96ddfd29 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx @@ -18,9 +18,9 @@ export const TeamsConfig: FC = () => {

Microsoft Teams Access

- SurfSense will index messages from Teams channels that you have access to. The app can - only read messages from teams and channels where you are a member. Make sure you're a - member of the teams you want to index before connecting. + Your agent can search and read messages from Teams channels you have access to, + and send messages on your behalf. Make sure you're a member of the teams + you want to interact with.

diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx index 3c92320da..aa3c8d193 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx @@ -16,6 +16,7 @@ import { DateRangeSelector } from "../../components/date-range-selector"; import { PeriodicSyncConfig } from "../../components/periodic-sync-config"; import { SummaryConfig } from "../../components/summary-config"; import { VisionLLMConfig } from "../../components/vision-llm-config"; +import { LIVE_CONNECTOR_TYPES } from "../../constants/connector-constants"; import { getConnectorDisplayName } from "../../tabs/all-connectors-tab"; import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index"; @@ -119,6 +120,7 @@ export const ConnectorEditView: FC = ({ }, [searchSpaceId, searchSpaceIdAtom, reauthEndpoint, connector.id]); const isMCPBacked = Boolean(connector.config?.server_config); + const isLive = isMCPBacked || LIVE_CONNECTOR_TYPES.has(connector.connector_type); // Get connector-specific config component (MCP-backed connectors use a generic view) const ConnectorConfigComponent = useMemo(() => { @@ -228,8 +230,8 @@ export const ConnectorEditView: FC = ({ {getConnectorDisplayName(connector.name)}

- {isMCPBacked - ? "Connected — your agent can interact with this service in real time" + {isLive + ? "Manage your connected account" : "Manage your connector settings and sync configuration"}

@@ -381,10 +383,12 @@ export const ConnectorEditView: FC = ({ {/* Fixed Footer - Action buttons */}
- {showDisconnectConfirm ? ( -
+ {showDisconnectConfirm ? ( +
- Are you sure? + {isLive + ? "Your agent will lose access to this service." + : "This will remove all indexed data."}
@@ -234,15 +231,13 @@ export const ConnectorAccountsListView: FC = ({ Syncing

- ) : ( -

- {isIndexableConnector(connector.connector_type) - ? connector.last_indexed_at - ? `Last indexed: ${formatRelativeDate(connector.last_indexed_at)}` - : "Never indexed" - : "Active"} + ) : !isLiveConnector(connector.connector_type) ? ( +

+ {connector.last_indexed_at + ? `Last indexed: ${formatRelativeDate(connector.last_indexed_at)}` + : "Never indexed"}

- )} + ) : null}
{isAuthExpired ? ( )} - {isPluginManagedReadOnly ? null : isAuthExpired && reauthEndpoint ? ( + {isAuthExpired && reauthEndpoint ? ( diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index daed8747d..a341581b4 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -1541,10 +1541,13 @@ function AnonymousDocumentsSidebar({ type="button" onClick={handleAnonUploadClick} disabled={isUploading} - className="flex w-full items-center justify-center gap-2 rounded-lg border-2 border-dashed border-primary/30 px-4 py-6 text-sm text-primary transition-colors hover:border-primary/60 hover:bg-primary/5 cursor-pointer disabled:opacity-50 disabled:pointer-events-none" + className="relative flex w-full items-center justify-center rounded-lg border-2 border-dashed border-primary/30 px-4 py-6 text-sm text-primary transition-colors hover:border-primary/60 hover:bg-primary/5 cursor-pointer disabled:opacity-50 disabled:pointer-events-none" > - - {isUploading ? "Uploading..." : "Upload a document"} + + + Upload a document + + {isUploading && }

Text, code, CSV, and HTML files only. Create an account for PDFs, images, and 30+ diff --git a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx index 026f3afc3..1ee5cd165 100644 --- a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx +++ b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx @@ -1,6 +1,6 @@ "use client"; -import { Download, FileQuestionMark, FileText, Loader2, PenLine, RefreshCw } from "lucide-react"; +import { Download, FileQuestionMark, FileText, PenLine, RefreshCw } from "lucide-react"; import { useRouter } from "next/navigation"; import { useCallback, useEffect, useRef, useState } from "react"; import { toast } from "sonner"; @@ -8,6 +8,7 @@ import { PlateEditor } from "@/components/editor/plate-editor"; import { MarkdownViewer } from "@/components/markdown-viewer"; import { Alert, AlertDescription } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; +import { Spinner } from "@/components/ui/spinner"; import { authenticatedFetch, getBearerToken, redirectToLogin } from "@/lib/auth-utils"; const LARGE_DOCUMENT_THRESHOLD = 2 * 1024 * 1024; // 2MB @@ -278,7 +279,7 @@ export function DocumentTabContent({ documentId, searchSpaceId, title }: Documen diff --git a/surfsense_web/components/sources/DocumentUploadTab.tsx b/surfsense_web/components/sources/DocumentUploadTab.tsx index 5a324fea9..3d2b2d7db 100644 --- a/surfsense_web/components/sources/DocumentUploadTab.tsx +++ b/surfsense_web/components/sources/DocumentUploadTab.tsx @@ -763,22 +763,16 @@ export function DocumentUploadTab({

)} From a1d03da896efe3d50e172c5f73ac40e07eb81e04 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 20:08:19 +0200 Subject: [PATCH 097/299] fix: encrypt tokens at rest, invalidate cache on refresh, clean up logging --- .../app/agents/new_chat/tools/mcp_tool.py | 144 ++++++++++++------ .../app/agents/new_chat/tools/registry.py | 24 +-- .../app/routes/mcp_oauth_route.py | 10 +- 3 files changed, 103 insertions(+), 75 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 62ef56dd7..25b1b3e74 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -22,7 +22,7 @@ from typing import Any from langchain_core.tools import StructuredTool from mcp import ClientSession from mcp.client.streamable_http import streamablehttp_client -from pydantic import BaseModel, create_model +from pydantic import BaseModel, Field, create_model from sqlalchemy import cast, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.asyncio import AsyncSession @@ -66,18 +66,14 @@ def _create_dynamic_input_model_from_schema( param_description = param_schema.get("description", "") is_required = param_name in required_fields - from typing import Any as AnyType - - from pydantic import Field - if is_required: field_definitions[param_name] = ( - AnyType, + Any, Field(..., description=param_description), ) else: field_definitions[param_name] = ( - AnyType | None, + Any | None, Field(None, description=param_description), ) @@ -103,13 +99,13 @@ async def _create_mcp_tool_from_definition_stdio( tool_description = tool_def.get("description", "No description provided") input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) - logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}") + logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema) input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) async def mcp_tool_call(**kwargs) -> str: """Execute the MCP tool call via the client with retry support.""" - logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}") + logger.debug("MCP tool '%s' called", tool_name) # HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph hitl_result = request_approval( @@ -133,13 +129,11 @@ async def _create_mcp_tool_from_definition_stdio( result = await mcp_client.call_tool(tool_name, call_kwargs) return str(result) except RuntimeError as e: - error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}" - logger.error(error_msg) - return f"Error: {error_msg}" + logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e) + return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}" except Exception as e: - error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}" - logger.exception(error_msg) - return f"Error: {error_msg}" + logger.exception("MCP tool '%s' execution failed: %s", tool_name, e) + return f"Error: MCP tool '{tool_name}' execution failed: {e!s}" tool = StructuredTool( name=tool_name, @@ -154,7 +148,7 @@ async def _create_mcp_tool_from_definition_stdio( }, ) - logger.info(f"Created MCP tool (stdio): '{tool_name}'") + logger.debug("Created MCP tool (stdio): '%s'", tool_name) return tool @@ -191,13 +185,13 @@ async def _create_mcp_tool_from_definition_http( if tool_name_prefix: tool_description = f"[Account: {connector_name}] {tool_description}" - logger.info(f"MCP HTTP tool '{exposed_name}' input schema: {input_schema}") + logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema) input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema) async def mcp_http_tool_call(**kwargs) -> str: """Execute the MCP tool call via HTTP transport.""" - logger.info(f"MCP HTTP tool '{exposed_name}' called with params: {kwargs}") + logger.debug("MCP HTTP tool '%s' called", exposed_name) if is_readonly: call_kwargs = {k: v for k, v in kwargs.items() if v is not None} @@ -238,15 +232,12 @@ async def _create_mcp_tool_from_definition_http( result.append(str(content)) result_str = "\n".join(result) if result else "" - logger.info( - f"MCP HTTP tool '{exposed_name}' succeeded: {result_str[:200]}" - ) + logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str)) return result_str except Exception as e: - error_msg = f"MCP HTTP tool '{exposed_name}' execution failed: {e!s}" - logger.exception(error_msg) - return f"Error: {error_msg}" + logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, e) + return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {e!s}" tool = StructuredTool( name=exposed_name, @@ -264,7 +255,7 @@ async def _create_mcp_tool_from_definition_http( }, ) - logger.info(f"Created MCP tool (HTTP): '{exposed_name}'") + logger.debug("Created MCP tool (HTTP): '%s'", exposed_name) return tool @@ -280,21 +271,24 @@ async def _load_stdio_mcp_tools( command = server_config.get("command") if not command or not isinstance(command, str): logger.warning( - f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid command field, skipping" + "MCP connector %d (name: '%s') missing or invalid command field, skipping", + connector_id, connector_name, ) return tools args = server_config.get("args", []) if not isinstance(args, list): logger.warning( - f"MCP connector {connector_id} (name: '{connector_name}') has invalid args field (must be list), skipping" + "MCP connector %d (name: '%s') has invalid args field (must be list), skipping", + connector_id, connector_name, ) return tools env = server_config.get("env", {}) if not isinstance(env, dict): logger.warning( - f"MCP connector {connector_id} (name: '{connector_name}') has invalid env field (must be dict), skipping" + "MCP connector %d (name: '%s') has invalid env field (must be dict), skipping", + connector_id, connector_name, ) return tools @@ -304,8 +298,8 @@ async def _load_stdio_mcp_tools( tool_definitions = await mcp_client.list_tools() logger.info( - f"Discovered {len(tool_definitions)} tools from stdio MCP server " - f"'{command}' (connector {connector_id})" + "Discovered %d tools from stdio MCP server '%s' (connector %d)", + len(tool_definitions), command, connector_id, ) for tool_def in tool_definitions: @@ -320,8 +314,8 @@ async def _load_stdio_mcp_tools( tools.append(tool) except Exception as e: logger.exception( - f"Failed to create tool '{tool_def.get('name')}' " - f"from connector {connector_id}: {e!s}" + "Failed to create tool '%s' from connector %d: %s", + tool_def.get("name"), connector_id, e, ) return tools @@ -351,14 +345,16 @@ async def _load_http_mcp_tools( url = server_config.get("url") if not url or not isinstance(url, str): logger.warning( - f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid url field, skipping" + "MCP connector %d (name: '%s') missing or invalid url field, skipping", + connector_id, connector_name, ) return tools headers = server_config.get("headers", {}) if not isinstance(headers, dict): logger.warning( - f"MCP connector {connector_id} (name: '{connector_name}') has invalid headers field (must be dict), skipping" + "MCP connector %d (name: '%s') has invalid headers field (must be dict), skipping", + connector_id, connector_name, ) return tools @@ -415,13 +411,14 @@ async def _load_http_mcp_tools( tools.append(tool) except Exception as e: logger.exception( - f"Failed to create HTTP tool '{tool_def.get('name')}' " - f"from connector {connector_id}: {e!s}" + "Failed to create HTTP tool '%s' from connector %d: %s", + tool_def.get("name"), connector_id, e, ) except Exception as e: logger.exception( - f"Failed to connect to HTTP MCP server at '{url}' (connector {connector_id}): {e!s}" + "Failed to connect to HTTP MCP server at '%s' (connector %d): %s", + url, connector_id, e, ) return tools @@ -430,6 +427,42 @@ async def _load_http_mcp_tools( _TOKEN_REFRESH_BUFFER_SECONDS = 300 # refresh 5 min before expiry +def _inject_oauth_headers( + cfg: dict[str, Any], + server_config: dict[str, Any], +) -> dict[str, Any]: + """Decrypt the MCP OAuth access token and inject it into server_config headers. + + The DB never stores plaintext tokens in ``server_config.headers``. This + function decrypts ``mcp_oauth.access_token`` at runtime and returns a + *copy* of ``server_config`` with the Authorization header set. + """ + mcp_oauth = cfg.get("mcp_oauth", {}) + encrypted_token = mcp_oauth.get("access_token") + if not encrypted_token: + return server_config + + try: + from app.config import config as app_config + from app.utils.oauth_security import TokenEncryption + + enc = TokenEncryption(app_config.SECRET_KEY) + access_token = enc.decrypt_token(encrypted_token) + + result = dict(server_config) + result["headers"] = { + **server_config.get("headers", {}), + "Authorization": f"Bearer {access_token}", + } + return result + except Exception: + logger.warning( + "Failed to decrypt MCP OAuth token for runtime injection", + exc_info=True, + ) + return server_config + + async def _maybe_refresh_mcp_oauth_token( session: AsyncSession, connector: "SearchSourceConnector", @@ -510,17 +543,11 @@ async def _maybe_refresh_mcp_oauth_token( new_expires_at.isoformat() if new_expires_at else None ) - updated_server_config = dict(server_config) - updated_server_config["headers"] = { - **server_config.get("headers", {}), - "Authorization": f"Bearer {new_access}", - } - from sqlalchemy.orm.attributes import flag_modified connector.config = { **cfg, - "server_config": updated_server_config, + "server_config": server_config, "mcp_oauth": updated_oauth, } flag_modified(connector, "config") @@ -528,7 +555,17 @@ async def _maybe_refresh_mcp_oauth_token( await session.refresh(connector) logger.info("Refreshed MCP OAuth token for connector %s", connector.id) - return updated_server_config + + # Invalidate cache so next call picks up the new token. + invalidate_mcp_tools_cache(connector.search_space_id) + + # Return server_config with the fresh token injected for immediate use. + refreshed_config = dict(server_config) + refreshed_config["headers"] = { + **server_config.get("headers", {}), + "Authorization": f"Bearer {new_access}", + } + return refreshed_config except Exception: logger.warning( @@ -622,15 +659,21 @@ async def load_mcp_tools( if not server_config or not isinstance(server_config, dict): logger.warning( - f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping" + "MCP connector %d (name: '%s') has invalid or missing server_config, skipping", + connector.id, connector.name, ) continue - # Refresh OAuth token for MCP OAuth connectors before connecting + # For MCP OAuth connectors: refresh if needed, then decrypt the + # access token and inject it into headers at runtime. The DB + # intentionally does NOT store plaintext tokens in server_config. if cfg.get("mcp_oauth"): server_config = await _maybe_refresh_mcp_oauth_token( session, connector, cfg, server_config, ) + # Re-read cfg after potential refresh (connector was reloaded from DB). + cfg = connector.config or {} + server_config = _inject_oauth_headers(cfg, server_config) ct = ( connector.connector_type.value @@ -677,7 +720,8 @@ async def load_mcp_tools( except Exception as e: logger.exception( - f"Failed to load tools from MCP connector {connector.id}: {e!s}" + "Failed to load tools from MCP connector %d: %s", + connector.id, e, ) _mcp_tools_cache[search_space_id] = (now, tools) @@ -686,9 +730,9 @@ async def load_mcp_tools( oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0]) del _mcp_tools_cache[oldest_key] - logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}") + logger.info("Loaded %d MCP tools for search space %d", len(tools), search_space_id) return tools except Exception as e: - logger.exception(f"Failed to load MCP tools: {e!s}") + logger.exception("Failed to load MCP tools: %s", e) return [] diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index 5616d4f9a..85c89b114 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -78,11 +78,7 @@ from .google_drive import ( create_create_google_drive_file_tool, create_delete_google_drive_file_tool, ) -# NOTE: Native Jira CRUD tools (create/update/delete_jira_issue) have been -# replaced by MCP equivalents (createJiraIssue, editJiraIssue). The native -# tools used the REST API which is incompatible with MCP-scoped OAuth tokens. from .connected_accounts import create_get_connected_accounts_tool -# NOTE: Native Linear delete tool disabled — see comment in BUILTIN_TOOLS. from .luma import ( create_create_luma_event_tool, create_list_luma_events_tool, @@ -279,12 +275,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ], ), # ========================================================================= - # LINEAR TOOLS — create/update handled by MCP save_issue. Delete/archive - # is NOT available: the official Linear MCP server does not expose a - # delete tool, and the native tool's GraphQL API call fails with - # MCP-scoped tokens (401). Re-enable when Linear adds MCP delete support. - # ========================================================================= - # ========================================================================= # NOTION TOOLS - create, update, delete pages # Auto-disabled when no Notion connector is configured (see chat_deepagent.py) # ========================================================================= @@ -518,11 +508,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ required_connector="GOOGLE_GMAIL_CONNECTOR", ), # ========================================================================= - # JIRA TOOLS — Now fully handled by MCP (createJiraIssue, editJiraIssue, - # searchJiraIssuesUsingJql, etc.). Native tools removed because the - # MCP-scoped OAuth token cannot call the Jira REST API. - # ========================================================================= - # ========================================================================= # CONFLUENCE TOOLS - create, update, delete pages # Auto-disabled when no Confluence connector is configured (see chat_deepagent.py) # ========================================================================= @@ -843,14 +828,15 @@ async def build_tools_async( ) tools.extend(mcp_tools) logging.info( - f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}", + "Registered %d MCP tools: %s", + len(mcp_tools), [t.name for t in mcp_tools], ) except Exception as e: - # Log error but don't fail - just continue without MCP tools - logging.exception(f"Failed to load MCP tools: {e!s}") + logging.exception("Failed to load MCP tools: %s", e) logging.info( - f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}", + "Total tools for agent: %d — %s", + len(tools), [t.name for t in tools], ) return tools diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py index 79d9dba93..07371873e 100644 --- a/surfsense_backend/app/routes/mcp_oauth_route.py +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -79,16 +79,15 @@ async def _fetch_account_metadata( "https://api.airtable.com/v0/meta/whoami", headers={"Authorization": f"Bearer {access_token}"}, ) - if resp.status_code != 200: - logger.warning( - "Airtable whoami API response: status=%s body=%s", - resp.status_code, resp.text[:300], - ) if resp.status_code == 200: whoami = resp.json() meta["user_id"] = whoami.get("id", "") meta["user_email"] = whoami.get("email", "") meta["display_name"] = whoami.get("email", "Airtable") + else: + logger.warning( + "Airtable whoami API returned %d (non-blocking)", resp.status_code, + ) except Exception: logger.warning( @@ -346,7 +345,6 @@ async def mcp_oauth_callback( "server_config": { "transport": "streamable-http", "url": mcp_url, - "headers": {"Authorization": f"Bearer {access_token}"}, }, "mcp_service": svc_key, "mcp_oauth": { From 01153b0d7e72b7634a8919e5afeca278ae6d53c7 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 20:24:45 +0200 Subject: [PATCH 098/299] fix: cache TokenEncryption, clear stale router caches on re-init --- .../app/agents/new_chat/tools/mcp_tool.py | 29 +++++++++++++------ .../app/services/llm_router_service.py | 6 ++++ 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 25b1b3e74..950109afd 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -14,10 +14,15 @@ clicking "Always Allow", which adds the tool name to the connector's ``config.trusted_tools`` allow-list. """ +from __future__ import annotations + import logging import time from collections import defaultdict -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from app.utils.oauth_security import TokenEncryption from langchain_core.tools import StructuredTool from mcp import ClientSession @@ -426,6 +431,18 @@ async def _load_http_mcp_tools( _TOKEN_REFRESH_BUFFER_SECONDS = 300 # refresh 5 min before expiry +_token_enc: TokenEncryption | None = None + + +def _get_token_enc() -> TokenEncryption: + global _token_enc + if _token_enc is None: + from app.config import config as app_config + from app.utils.oauth_security import TokenEncryption + + _token_enc = TokenEncryption(app_config.SECRET_KEY) + return _token_enc + def _inject_oauth_headers( cfg: dict[str, Any], @@ -443,11 +460,7 @@ def _inject_oauth_headers( return server_config try: - from app.config import config as app_config - from app.utils.oauth_security import TokenEncryption - - enc = TokenEncryption(app_config.SECRET_KEY) - access_token = enc.decrypt_token(encrypted_token) + access_token = _get_token_enc().decrypt_token(encrypted_token) result = dict(server_config) result["headers"] = { @@ -500,11 +513,9 @@ async def _maybe_refresh_mcp_oauth_token( return server_config try: - from app.config import config as app_config from app.services.mcp_oauth.discovery import refresh_access_token - from app.utils.oauth_security import TokenEncryption - enc = TokenEncryption(app_config.SECRET_KEY) + enc = _get_token_enc() decrypted_refresh = enc.decrypt_token(refresh_token) decrypted_secret = ( enc.decrypt_token(mcp_oauth["client_secret"]) diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index c9eeff01b..4bce79a43 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -290,6 +290,12 @@ class LLMRouterService: instance._router = Router(**router_kwargs) instance._initialized = True + + global _cached_context_profile, _cached_context_profile_computed + _cached_context_profile = None + _cached_context_profile_computed = False + _router_instance_cache.clear() + logger.info( "LLM Router initialized with %d deployments, " "strategy: %s, context_window_fallbacks: %s, fallbacks: %s", From 0eae96bffbe6a445fda7b3d70af1b0ecda826d57 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 20:54:42 +0200 Subject: [PATCH 099/299] fix: harden MCP OAuth and connector edge cases --- .../app/agents/new_chat/tools/hitl.py | 4 ++-- .../app/agents/new_chat/tools/mcp_tool.py | 18 ++++++++++++------ .../app/routes/mcp_oauth_route.py | 7 ++++++- .../app/utils/connector_naming.py | 8 +++++--- 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/new_chat/tools/hitl.py index 64ace547c..89f02abf6 100644 --- a/surfsense_backend/app/agents/new_chat/tools/hitl.py +++ b/surfsense_backend/app/agents/new_chat/tools/hitl.py @@ -130,8 +130,8 @@ def request_approval( try: decision_type, edited_params = _parse_decision(approval) except ValueError: - logger.warning("No approval decision received for %s", tool_name) - return HITLResult(rejected=False, decision_type="error", params=params) + logger.warning("No approval decision received for %s — rejecting for safety", tool_name) + return HITLResult(rejected=True, decision_type="error", params=params) logger.info("User decision for %s: %s", tool_name, decision_type) diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 950109afd..8f8e5007f 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -447,7 +447,7 @@ def _get_token_enc() -> TokenEncryption: def _inject_oauth_headers( cfg: dict[str, Any], server_config: dict[str, Any], -) -> dict[str, Any]: +) -> dict[str, Any] | None: """Decrypt the MCP OAuth access token and inject it into server_config headers. The DB never stores plaintext tokens in ``server_config.headers``. This @@ -469,11 +469,11 @@ def _inject_oauth_headers( } return result except Exception: - logger.warning( - "Failed to decrypt MCP OAuth token for runtime injection", + logger.error( + "Failed to decrypt MCP OAuth token — connector will be skipped", exc_info=True, ) - return server_config + return None async def _maybe_refresh_mcp_oauth_token( @@ -666,7 +666,6 @@ async def load_mcp_tools( try: cfg = connector.config or {} server_config = cfg.get("server_config", {}) - trusted_tools = cfg.get("trusted_tools", []) if not server_config or not isinstance(server_config, dict): logger.warning( @@ -685,6 +684,14 @@ async def load_mcp_tools( # Re-read cfg after potential refresh (connector was reloaded from DB). cfg = connector.config or {} server_config = _inject_oauth_headers(cfg, server_config) + if server_config is None: + logger.warning( + "Skipping MCP connector %d — OAuth token decryption failed", + connector.id, + ) + continue + + trusted_tools = cfg.get("trusted_tools", []) ct = ( connector.connector_type.value @@ -692,7 +699,6 @@ async def load_mcp_tools( else str(connector.connector_type) ) - # Resolve the allowlist from the service registry (if any). svc_cfg = get_service_by_connector_type(ct) allowed_tools = svc_cfg.allowed_tools if svc_cfg else [] readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset() diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py index 07371873e..e14be83d0 100644 --- a/surfsense_backend/app/routes/mcp_oauth_route.py +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -361,7 +361,12 @@ async def mcp_oauth_callback( account_meta = await _fetch_account_metadata(svc_key, access_token, token_json) if account_meta: - connector_config.update(account_meta) + _SAFE_META_KEYS = {"display_name", "team_id", "team_name", "user_id", "user_email", + "workspace_id", "workspace_name", "organization_name", + "organization_url_key", "cloud_id", "site_name", "base_url"} + for k, v in account_meta.items(): + if k in _SAFE_META_KEYS: + connector_config[k] = v logger.info( "Stored account metadata for %s: display_name=%s", svc_key, account_meta.get("display_name", ""), diff --git a/surfsense_backend/app/utils/connector_naming.py b/surfsense_backend/app/utils/connector_naming.py index 610be4a22..889bf1464 100644 --- a/surfsense_backend/app/utils/connector_naming.py +++ b/surfsense_backend/app/utils/connector_naming.py @@ -39,7 +39,7 @@ BASE_NAME_FOR_TYPE = { def get_base_name_for_type(connector_type: SearchSourceConnectorType) -> str: """Get a friendly display name for a connector type.""" return BASE_NAME_FOR_TYPE.get( - connector_type, connector_type.replace("_", " ").title() + connector_type, connector_type.value.replace("_", " ").title() ) @@ -231,9 +231,11 @@ async def generate_unique_connector_name( base = get_base_name_for_type(connector_type) if identifier: - return f"{base} - {identifier}" + name = f"{base} - {identifier}" + return await ensure_unique_connector_name( + session, name, search_space_id, user_id, + ) - # Fallback: use counter for uniqueness count = await count_connectors_of_type( session, connector_type, search_space_id, user_id ) From 9977f9b6413235ae2c8cd052e18a2edd67a94339 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 21:43:51 +0200 Subject: [PATCH 100/299] remove dead indexing tasks and fix silent schedule breakage for live connectors --- surfsense_backend/app/celery_app.py | 8 - .../routes/search_source_connectors_routes.py | 418 ------------------ .../app/tasks/celery_tasks/connector_tasks.py | 347 --------------- .../celery_tasks/schedule_checker_task.py | 24 + .../app/tasks/connector_indexers/__init__.py | 56 +-- 5 files changed, 29 insertions(+), 824 deletions(-) diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index c44391528..e3a520c48 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -135,20 +135,12 @@ celery_app.conf.update( # never block fast user-facing tasks (file uploads, podcasts, etc.) task_routes={ # Connector indexing tasks → connectors queue - "index_slack_messages": {"queue": CONNECTORS_QUEUE}, "index_notion_pages": {"queue": CONNECTORS_QUEUE}, "index_github_repos": {"queue": CONNECTORS_QUEUE}, - "index_linear_issues": {"queue": CONNECTORS_QUEUE}, - "index_jira_issues": {"queue": CONNECTORS_QUEUE}, "index_confluence_pages": {"queue": CONNECTORS_QUEUE}, - "index_clickup_tasks": {"queue": CONNECTORS_QUEUE}, "index_google_calendar_events": {"queue": CONNECTORS_QUEUE}, - "index_airtable_records": {"queue": CONNECTORS_QUEUE}, "index_google_gmail_messages": {"queue": CONNECTORS_QUEUE}, "index_google_drive_files": {"queue": CONNECTORS_QUEUE}, - "index_discord_messages": {"queue": CONNECTORS_QUEUE}, - "index_teams_messages": {"queue": CONNECTORS_QUEUE}, - "index_luma_events": {"queue": CONNECTORS_QUEUE}, "index_elasticsearch_documents": {"queue": CONNECTORS_QUEUE}, "index_crawled_urls": {"queue": CONNECTORS_QUEUE}, "index_bookstack_pages": {"queue": CONNECTORS_QUEUE}, diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 7ce3ca9a3..0c06318ee 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -1219,57 +1219,6 @@ async def _update_connector_timestamp_by_id(session: AsyncSession, connector_id: await session.rollback() -async def run_slack_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Create a new session and run the Slack indexing task. - This prevents session leaks by creating a dedicated session for the background task. - """ - async with async_session_maker() as session: - await run_slack_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - -async def run_slack_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run Slack indexing. - - Args: - session: Database session - connector_id: ID of the Slack connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers import index_slack_messages - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_slack_messages, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - _AUTH_ERROR_PATTERNS = ( "failed to refresh linear oauth", "failed to refresh your notion connection", @@ -1808,215 +1757,6 @@ async def run_github_indexing( ) -# Add new helper functions for Linear indexing -async def run_linear_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Wrapper to run Linear indexing with its own database session.""" - logger.info( - f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}" - ) - async with async_session_maker() as session: - await run_linear_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - logger.info(f"Background task finished: Indexing Linear connector {connector_id}") - - -async def run_linear_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run Linear indexing. - - Args: - session: Database session - connector_id: ID of the Linear connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers import index_linear_issues - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_linear_issues, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - -# Add new helper functions for discord indexing -async def run_discord_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Create a new session and run the Discord indexing task. - This prevents session leaks by creating a dedicated session for the background task. - """ - async with async_session_maker() as session: - await run_discord_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - -async def run_discord_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run Discord indexing. - - Args: - session: Database session - connector_id: ID of the Discord connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers import index_discord_messages - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_discord_messages, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - -async def run_teams_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Create a new session and run the Microsoft Teams indexing task. - This prevents session leaks by creating a dedicated session for the background task. - """ - async with async_session_maker() as session: - await run_teams_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - -async def run_teams_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run Microsoft Teams indexing. - - Args: - session: Database session - connector_id: ID of the Teams connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers.teams_indexer import index_teams_messages - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_teams_messages, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - -# Add new helper functions for Jira indexing -async def run_jira_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Wrapper to run Jira indexing with its own database session.""" - logger.info( - f"Background task started: Indexing Jira connector {connector_id} into space {search_space_id} from {start_date} to {end_date}" - ) - async with async_session_maker() as session: - await run_jira_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - logger.info(f"Background task finished: Indexing Jira connector {connector_id}") - - -async def run_jira_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run Jira indexing. - - Args: - session: Database session - connector_id: ID of the Jira connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers import index_jira_issues - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_jira_issues, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - # Add new helper functions for Confluence indexing async def run_confluence_indexing_with_new_session( connector_id: int, @@ -2072,112 +1812,6 @@ async def run_confluence_indexing( ) -# Add new helper functions for ClickUp indexing -async def run_clickup_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Wrapper to run ClickUp indexing with its own database session.""" - logger.info( - f"Background task started: Indexing ClickUp connector {connector_id} into space {search_space_id} from {start_date} to {end_date}" - ) - async with async_session_maker() as session: - await run_clickup_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - logger.info(f"Background task finished: Indexing ClickUp connector {connector_id}") - - -async def run_clickup_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run ClickUp indexing. - - Args: - session: Database session - connector_id: ID of the ClickUp connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers import index_clickup_tasks - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_clickup_tasks, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - -# Add new helper functions for Airtable indexing -async def run_airtable_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Wrapper to run Airtable indexing with its own database session.""" - logger.info( - f"Background task started: Indexing Airtable connector {connector_id} into space {search_space_id} from {start_date} to {end_date}" - ) - async with async_session_maker() as session: - await run_airtable_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - logger.info(f"Background task finished: Indexing Airtable connector {connector_id}") - - -async def run_airtable_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run Airtable indexing. - - Args: - session: Database session - connector_id: ID of the Airtable connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers import index_airtable_records - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_airtable_records, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - # Add new helper functions for Google Calendar indexing async def run_google_calendar_indexing_with_new_session( connector_id: int, @@ -2716,58 +2350,6 @@ async def run_dropbox_indexing( logger.error(f"Failed to update notification: {notif_error!s}") -# Add new helper functions for luma indexing -async def run_luma_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Create a new session and run the Luma indexing task. - This prevents session leaks by creating a dedicated session for the background task. - """ - async with async_session_maker() as session: - await run_luma_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - -async def run_luma_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run Luma indexing. - - Args: - session: Database session - connector_id: ID of the Luma connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers import index_luma_events - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_luma_events, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - async def run_elasticsearch_indexing_with_new_session( connector_id: int, search_space_id: int, diff --git a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py index 57475c9fd..141d5ffca 100644 --- a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py @@ -39,52 +39,6 @@ def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> N ) -@celery_app.task(name="index_slack_messages", bind=True) -def index_slack_messages_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index Slack messages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_slack_messages( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - except Exception as e: - _handle_greenlet_error(e, "index_slack_messages", connector_id) - raise - finally: - loop.close() - - -async def _index_slack_messages( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index Slack messages with new session.""" - from app.routes.search_source_connectors_routes import ( - run_slack_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_slack_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - @celery_app.task(name="index_notion_pages", bind=True) def index_notion_pages_task( self, @@ -174,92 +128,6 @@ async def _index_github_repos( ) -@celery_app.task(name="index_linear_issues", bind=True) -def index_linear_issues_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index Linear issues.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_linear_issues( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - finally: - loop.close() - - -async def _index_linear_issues( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index Linear issues with new session.""" - from app.routes.search_source_connectors_routes import ( - run_linear_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_linear_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - -@celery_app.task(name="index_jira_issues", bind=True) -def index_jira_issues_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index Jira issues.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_jira_issues( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - finally: - loop.close() - - -async def _index_jira_issues( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index Jira issues with new session.""" - from app.routes.search_source_connectors_routes import ( - run_jira_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_jira_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - @celery_app.task(name="index_confluence_pages", bind=True) def index_confluence_pages_task( self, @@ -303,49 +171,6 @@ async def _index_confluence_pages( ) -@celery_app.task(name="index_clickup_tasks", bind=True) -def index_clickup_tasks_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index ClickUp tasks.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_clickup_tasks( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - finally: - loop.close() - - -async def _index_clickup_tasks( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index ClickUp tasks with new session.""" - from app.routes.search_source_connectors_routes import ( - run_clickup_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_clickup_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - @celery_app.task(name="index_google_calendar_events", bind=True) def index_google_calendar_events_task( self, @@ -392,49 +217,6 @@ async def _index_google_calendar_events( ) -@celery_app.task(name="index_airtable_records", bind=True) -def index_airtable_records_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index Airtable records.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_airtable_records( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - finally: - loop.close() - - -async def _index_airtable_records( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index Airtable records with new session.""" - from app.routes.search_source_connectors_routes import ( - run_airtable_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_airtable_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - @celery_app.task(name="index_google_gmail_messages", bind=True) def index_google_gmail_messages_task( self, @@ -622,135 +404,6 @@ async def _index_dropbox_files( ) -@celery_app.task(name="index_discord_messages", bind=True) -def index_discord_messages_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index Discord messages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_discord_messages( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - finally: - loop.close() - - -async def _index_discord_messages( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index Discord messages with new session.""" - from app.routes.search_source_connectors_routes import ( - run_discord_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_discord_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - -@celery_app.task(name="index_teams_messages", bind=True) -def index_teams_messages_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index Microsoft Teams messages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_teams_messages( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - finally: - loop.close() - - -async def _index_teams_messages( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index Microsoft Teams messages with new session.""" - from app.routes.search_source_connectors_routes import ( - run_teams_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_teams_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - -@celery_app.task(name="index_luma_events", bind=True) -def index_luma_events_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index Luma events.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_luma_events( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - finally: - loop.close() - - -async def _index_luma_events( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index Luma events with new session.""" - from app.routes.search_source_connectors_routes import ( - run_luma_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_luma_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - @celery_app.task(name="index_elasticsearch_documents", bind=True) def index_elasticsearch_documents_task( self, diff --git a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py index 3aee5a4ca..89010192f 100644 --- a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py @@ -77,8 +77,32 @@ async def _check_and_trigger_schedules(): SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task, } + _LIVE_CONNECTOR_TYPES = { + SearchSourceConnectorType.SLACK_CONNECTOR, + SearchSourceConnectorType.TEAMS_CONNECTOR, + SearchSourceConnectorType.LINEAR_CONNECTOR, + SearchSourceConnectorType.JIRA_CONNECTOR, + SearchSourceConnectorType.CLICKUP_CONNECTOR, + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.AIRTABLE_CONNECTOR, + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.DISCORD_CONNECTOR, + SearchSourceConnectorType.LUMA_CONNECTOR, + } + # Trigger indexing for each due connector for connector in due_connectors: + if connector.connector_type in _LIVE_CONNECTOR_TYPES: + connector.periodic_indexing_enabled = False + connector.next_scheduled_at = None + await session.commit() + logger.info( + "Disabled obsolete periodic indexing for live connector %s (%s)", + connector.id, + connector.connector_type.value, + ) + continue + # Primary guard: Redis lock indicates a task is currently running. if is_connector_indexing_locked(connector.id): logger.info( diff --git a/surfsense_backend/app/tasks/connector_indexers/__init__.py b/surfsense_backend/app/tasks/connector_indexers/__init__.py index 1b032d54a..2b0ad7fa0 100644 --- a/surfsense_backend/app/tasks/connector_indexers/__init__.py +++ b/surfsense_backend/app/tasks/connector_indexers/__init__.py @@ -1,77 +1,31 @@ """ Connector indexers module for background tasks. -This module provides a collection of connector indexers for different platforms -and services. Each indexer is responsible for handling the indexing of content -from a specific connector type. - -Available indexers: -- Slack: Index messages from Slack channels -- Notion: Index pages from Notion workspaces -- GitHub: Index repositories and files from GitHub -- Linear: Index issues from Linear workspaces -- Jira: Index issues from Jira projects -- Confluence: Index pages from Confluence spaces -- BookStack: Index pages from BookStack wiki instances -- Discord: Index messages from Discord servers -- ClickUp: Index tasks from ClickUp workspaces -- Google Gmail: Index messages from Google Gmail -- Google Calendar: Index events from Google Calendar -- Luma: Index events from Luma -- Webcrawler: Index crawled URLs -- Elasticsearch: Index documents from Elasticsearch instances +Each indexer handles content indexing from a specific connector type. +Live connectors (Slack, Linear, Jira, ClickUp, Airtable, Discord, Teams, +Luma) now use real-time agent tools instead of background indexing. """ -# Communication platforms -# Calendar and scheduling -from .airtable_indexer import index_airtable_records from .bookstack_indexer import index_bookstack_pages - -# Note: composio_indexer is imported directly in connector_tasks.py to avoid circular imports -from .clickup_indexer import index_clickup_tasks from .confluence_indexer import index_confluence_pages -from .discord_indexer import index_discord_messages - -# Development platforms from .elasticsearch_indexer import index_elasticsearch_documents from .github_indexer import index_github_repos from .google_calendar_indexer import index_google_calendar_events from .google_drive_indexer import index_google_drive_files from .google_gmail_indexer import index_google_gmail_messages -from .jira_indexer import index_jira_issues - -# Issue tracking and project management -from .linear_indexer import index_linear_issues - -# Documentation and knowledge management -from .luma_indexer import index_luma_events from .notion_indexer import index_notion_pages from .obsidian_indexer import index_obsidian_vault -from .slack_indexer import index_slack_messages from .webcrawler_indexer import index_crawled_urls -__all__ = [ # noqa: RUF022 - "index_airtable_records", +__all__ = [ "index_bookstack_pages", - # "index_composio_connector", # Imported directly in connector_tasks.py to avoid circular imports - "index_clickup_tasks", "index_confluence_pages", - "index_discord_messages", - # Development platforms "index_elasticsearch_documents", "index_github_repos", - # Calendar and scheduling "index_google_calendar_events", "index_google_drive_files", - "index_luma_events", - "index_jira_issues", - # Issue tracking and project management - "index_linear_issues", - # Documentation and knowledge management + "index_google_gmail_messages", "index_notion_pages", "index_obsidian_vault", "index_crawled_urls", - # Communication platforms - "index_slack_messages", - "index_google_gmail_messages", ] From b6c506abeff5157fc36f8635e7ef69a5c7209c6e Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 22 Apr 2026 22:07:55 +0200 Subject: [PATCH 101/299] fix: treat all Gmail/Calendar as live connectors, hide indexing UI --- .../new_chat/tools/gmail/search_emails.py | 31 ++++++++++--- .../tools/google_calendar/search_events.py | 36 +--------------- .../routes/search_source_connectors_routes.py | 16 ++----- .../app/services/composio_service.py | 2 +- .../app/services/mcp_oauth/registry.py | 17 ++++++++ .../celery_tasks/schedule_checker_task.py | 43 ++++++++----------- .../views/connector-edit-view.tsx | 14 +++--- .../views/indexing-configuration-view.tsx | 38 ++++++++++------ .../constants/connector-constants.ts | 36 ++++++---------- .../hooks/use-connector-dialog.ts | 8 +++- 10 files changed, 115 insertions(+), 126 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py index bfc328389..de43f03d0 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py @@ -15,10 +15,30 @@ _GMAIL_TYPES = [ SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, ] +_token_encryption_cache: object | None = None + + +def _get_token_encryption(): + global _token_encryption_cache + if _token_encryption_cache is None: + from app.config import config + from app.utils.oauth_security import TokenEncryption + + if not config.SECRET_KEY: + raise RuntimeError("SECRET_KEY not configured for token decryption.") + _token_encryption_cache = TokenEncryption(config.SECRET_KEY) + return _token_encryption_cache + def _build_credentials(connector: SearchSourceConnector): - """Build Google OAuth Credentials from a Gmail connector's config.""" - if connector.connector_type == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: + """Build Google OAuth Credentials from a connector's stored config. + + Handles both native OAuth connectors (with encrypted tokens) and + Composio-backed connectors. Shared by Gmail and Calendar tools. + """ + from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES + + if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: from app.utils.google_credentials import build_composio_credentials cca_id = connector.config.get("composio_connected_account_id") @@ -28,12 +48,9 @@ def _build_credentials(connector: SearchSourceConnector): from google.oauth2.credentials import Credentials - from app.config import config - from app.utils.oauth_security import TokenEncryption - cfg = dict(connector.config) - if cfg.get("_token_encrypted") and config.SECRET_KEY: - enc = TokenEncryption(config.SECRET_KEY) + if cfg.get("_token_encrypted"): + enc = _get_token_encryption() for key in ("token", "refresh_token", "client_secret"): if cfg.get(key): cfg[key] = enc.decrypt_token(cfg[key]) diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py index ad66775ef..a622b0efa 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py @@ -1,11 +1,11 @@ import logging -from datetime import datetime from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.agents.new_chat.tools.gmail.search_emails import _build_credentials from app.db import SearchSourceConnector, SearchSourceConnectorType logger = logging.getLogger(__name__) @@ -16,40 +16,6 @@ _CALENDAR_TYPES = [ ] -def _build_credentials(connector: SearchSourceConnector): - """Build Google OAuth Credentials from a Calendar connector's config.""" - if connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: - from app.utils.google_credentials import build_composio_credentials - - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - raise ValueError("Composio connected account ID not found.") - return build_composio_credentials(cca_id) - - from google.oauth2.credentials import Credentials - - from app.config import config - from app.utils.oauth_security import TokenEncryption - - cfg = dict(connector.config) - if cfg.get("_token_encrypted") and config.SECRET_KEY: - enc = TokenEncryption(config.SECRET_KEY) - for key in ("token", "refresh_token", "client_secret"): - if cfg.get(key): - cfg[key] = enc.decrypt_token(cfg[key]) - - exp = (cfg.get("expiry") or "").replace("Z", "") - return Credentials( - token=cfg.get("token"), - refresh_token=cfg.get("refresh_token"), - token_uri=cfg.get("token_uri"), - client_id=cfg.get("client_id"), - client_secret=cfg.get("client_secret"), - scopes=cfg.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, - ) - - def create_search_calendar_events_tool( db_session: AsyncSession | None = None, search_space_id: int | None = None, diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 0c06318ee..989894003 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -777,19 +777,9 @@ async def index_connector_content( # For non-calendar connectors, cap at today indexing_to = end_date if end_date else today_str - _LIVE_CONNECTOR_TYPES = { - SearchSourceConnectorType.SLACK_CONNECTOR, - SearchSourceConnectorType.TEAMS_CONNECTOR, - SearchSourceConnectorType.LINEAR_CONNECTOR, - SearchSourceConnectorType.JIRA_CONNECTOR, - SearchSourceConnectorType.CLICKUP_CONNECTOR, - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.AIRTABLE_CONNECTOR, - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.DISCORD_CONNECTOR, - SearchSourceConnectorType.LUMA_CONNECTOR, - } - if connector.connector_type in _LIVE_CONNECTOR_TYPES: + from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES + + if connector.connector_type in LIVE_CONNECTOR_TYPES: return { "message": ( f"{connector.connector_type.value} uses real-time agent tools; " diff --git a/surfsense_backend/app/services/composio_service.py b/surfsense_backend/app/services/composio_service.py index 13fe37832..a8abe4aa8 100644 --- a/surfsense_backend/app/services/composio_service.py +++ b/surfsense_backend/app/services/composio_service.py @@ -26,7 +26,7 @@ COMPOSIO_TOOLKIT_NAMES = { } # Toolkits that support indexing (Phase 1: Google services only) -INDEXABLE_TOOLKITS = {"googledrive", "gmail", "googlecalendar"} +INDEXABLE_TOOLKITS = {"googledrive"} # Mapping of toolkit IDs to connector types TOOLKIT_TO_CONNECTOR_TYPE = { diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py index 47a654465..49bc74d3d 100644 --- a/surfsense_backend/app/services/mcp_oauth/registry.py +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -16,6 +16,8 @@ from __future__ import annotations from dataclasses import dataclass, field +from app.db import SearchSourceConnectorType + @dataclass(frozen=True) class MCPServiceConfig: @@ -134,6 +136,21 @@ _CONNECTOR_TYPE_TO_SERVICE: dict[str, MCPServiceConfig] = { svc.connector_type: svc for svc in MCP_SERVICES.values() } +LIVE_CONNECTOR_TYPES: frozenset[SearchSourceConnectorType] = frozenset({ + SearchSourceConnectorType.SLACK_CONNECTOR, + SearchSourceConnectorType.TEAMS_CONNECTOR, + SearchSourceConnectorType.LINEAR_CONNECTOR, + SearchSourceConnectorType.JIRA_CONNECTOR, + SearchSourceConnectorType.CLICKUP_CONNECTOR, + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.AIRTABLE_CONNECTOR, + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, + SearchSourceConnectorType.DISCORD_CONNECTOR, + SearchSourceConnectorType.LUMA_CONNECTOR, +}) + def get_service(key: str) -> MCPServiceConfig | None: return MCP_SERVICES.get(key) diff --git a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py index 89010192f..373f04b48 100644 --- a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py @@ -59,9 +59,7 @@ async def _check_and_trigger_schedules(): index_crawled_urls_task, index_elasticsearch_documents_task, index_github_repos_task, - index_google_calendar_events_task, index_google_drive_files_task, - index_google_gmail_messages_task, index_notion_pages_task, ) @@ -73,34 +71,29 @@ async def _check_and_trigger_schedules(): SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task, SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task, SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: index_google_gmail_messages_task, - SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task, } - _LIVE_CONNECTOR_TYPES = { - SearchSourceConnectorType.SLACK_CONNECTOR, - SearchSourceConnectorType.TEAMS_CONNECTOR, - SearchSourceConnectorType.LINEAR_CONNECTOR, - SearchSourceConnectorType.JIRA_CONNECTOR, - SearchSourceConnectorType.CLICKUP_CONNECTOR, - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.AIRTABLE_CONNECTOR, - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.DISCORD_CONNECTOR, - SearchSourceConnectorType.LUMA_CONNECTOR, - } + from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES + + # Disable obsolete periodic indexing for live connectors in one batch. + live_disabled = [] + for connector in due_connectors: + if connector.connector_type in LIVE_CONNECTOR_TYPES: + connector.periodic_indexing_enabled = False + connector.next_scheduled_at = None + live_disabled.append(connector) + if live_disabled: + await session.commit() + for c in live_disabled: + logger.info( + "Disabled obsolete periodic indexing for live connector %s (%s)", + c.id, + c.connector_type.value, + ) # Trigger indexing for each due connector for connector in due_connectors: - if connector.connector_type in _LIVE_CONNECTOR_TYPES: - connector.periodic_indexing_enabled = False - connector.next_scheduled_at = None - await session.commit() - logger.info( - "Disabled obsolete periodic indexing for live connector %s (%s)", - connector.id, - connector.connector_type.value, - ) + if connector in live_disabled: continue # Primary guard: Redis lock indicates a task is currently running. diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx index aa3c8d193..a69cf968f 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx @@ -236,8 +236,8 @@ export const ConnectorEditView: FC = ({

- {/* Quick Index Button - hidden when auth is expired */} - {connector.is_indexable && onQuickIndex && !isAuthExpired && ( + {/* Quick Index Button - hidden for live connectors and when auth is expired */} + {connector.is_indexable && !isLive && onQuickIndex && !isAuthExpired && ( - ) : !isMCPBacked ? ( + ) : !isLive ? (

- Configure when to start syncing your data + {isLive + ? "Your account is ready to use" + : "Configure when to start syncing your data"}

@@ -157,8 +161,8 @@ export const IndexingConfigurationView: FC = ({ )} - {/* Summary and sync settings - only shown for indexable connectors */} - {connector?.is_indexable && ( + {/* Summary and sync settings - hidden for live connectors */} + {connector?.is_indexable && !isLive && ( <> {/* AI Summary toggle */} @@ -209,8 +213,8 @@ export const IndexingConfigurationView: FC = ({ )} - {/* Info box - only shown for indexable connectors */} - {connector?.is_indexable && ( + {/* Info box - hidden for live connectors */} + {connector?.is_indexable && !isLive && (
@@ -238,14 +242,20 @@ export const IndexingConfigurationView: FC = ({ {/* Fixed Footer - Action buttons */}
- + {isLive ? ( + + ) : ( + + )}
); diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index 1f324d53e..05f866d0f 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -13,7 +13,9 @@ export const LIVE_CONNECTOR_TYPES = new Set([ EnumConnectorName.DISCORD_CONNECTOR, EnumConnectorName.TEAMS_CONNECTOR, EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR, + EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, EnumConnectorName.GOOGLE_GMAIL_CONNECTOR, + EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR, EnumConnectorName.LUMA_CONNECTOR, ]); @@ -30,7 +32,7 @@ export const OAUTH_CONNECTORS = [ { id: "google-gmail-connector", title: "Gmail", - description: "Search and read your emails", + description: "Search, read, draft, and send emails", connectorType: EnumConnectorName.GOOGLE_GMAIL_CONNECTOR, authEndpoint: "/api/v1/auth/google/gmail/connector/add/", selfHostedOnly: true, @@ -46,7 +48,7 @@ export const OAUTH_CONNECTORS = [ { id: "airtable-connector", title: "Airtable", - description: "Search, read, and manage records", + description: "Browse bases, tables, and records", connectorType: EnumConnectorName.AIRTABLE_CONNECTOR, authEndpoint: "/api/v1/auth/mcp/airtable/connector/add/", }, @@ -67,7 +69,7 @@ export const OAUTH_CONNECTORS = [ { id: "slack-connector", title: "Slack", - description: "Search, read, and send messages", + description: "Search and read channels and threads", connectorType: EnumConnectorName.SLACK_CONNECTOR, authEndpoint: "/api/v1/auth/mcp/slack/connector/add/", }, @@ -116,7 +118,7 @@ export const OAUTH_CONNECTORS = [ { id: "clickup-connector", title: "ClickUp", - description: "Search, read, and manage tasks", + description: "Search and read tasks", connectorType: EnumConnectorName.CLICKUP_CONNECTOR, authEndpoint: "/api/v1/auth/mcp/clickup/connector/add/", }, @@ -155,7 +157,7 @@ export const OTHER_CONNECTORS = [ { id: "luma-connector", title: "Luma", - description: "Search and manage events", + description: "Browse, read, and create events", connectorType: EnumConnectorName.LUMA_CONNECTOR, }, { @@ -214,14 +216,14 @@ export const COMPOSIO_CONNECTORS = [ { id: "composio-gmail", title: "Gmail", - description: "Search through your emails via Composio", + description: "Search, read, draft, and send emails via Composio", connectorType: EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR, authEndpoint: "/api/v1/auth/composio/connector/add/?toolkit_id=gmail", }, { id: "composio-googlecalendar", title: "Google Calendar", - description: "Search through your events via Composio", + description: "Search and manage your events via Composio", connectorType: EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, authEndpoint: "/api/v1/auth/composio/connector/add/?toolkit_id=googlecalendar", }, @@ -238,14 +240,14 @@ export const COMPOSIO_TOOLKITS = [ { id: "gmail", name: "Gmail", - description: "Search through your emails", - isIndexable: true, + description: "Search, read, draft, and send emails", + isIndexable: false, }, { id: "googlecalendar", name: "Google Calendar", - description: "Search through your events", - isIndexable: true, + description: "Search and manage your events", + isIndexable: false, }, { id: "slack", @@ -275,18 +277,6 @@ export interface AutoIndexConfig { } export const AUTO_INDEX_DEFAULTS: Record = { - [EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: { - daysBack: 30, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 30 days of emails.", - }, - [EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: { - daysBack: 90, - daysForward: 90, - frequencyMinutes: 1440, - syncDescription: "Syncing 90 days of past and upcoming events.", - }, [EnumConnectorName.NOTION_CONNECTOR]: { daysBack: 365, daysForward: 0, diff --git a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts index 9f968e2a7..a8d395e5c 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts @@ -38,6 +38,7 @@ import { AUTO_INDEX_CONNECTOR_TYPES, AUTO_INDEX_DEFAULTS, COMPOSIO_CONNECTORS, + LIVE_CONNECTOR_TYPES, OAUTH_CONNECTORS, OTHER_CONNECTORS, } from "../constants/connector-constants"; @@ -317,7 +318,12 @@ export const useConnectorDialog = () => { newConnector.id ); - if ( + const isLiveConnector = LIVE_CONNECTOR_TYPES.has(oauthConnector.connectorType); + + if (isLiveConnector) { + toast.success(`${oauthConnector.title} connected successfully!`); + await refetchAllConnectors(); + } else if ( newConnector.is_indexable && AUTO_INDEX_CONNECTOR_TYPES.has(oauthConnector.connectorType) ) { From 16f47578d787e59dcdbc7f4aebc7597be16fc7b4 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Thu, 23 Apr 2026 08:03:32 +0200 Subject: [PATCH 102/299] Enhance MCP tool trust functionality to support OAuth-backed connectors and improve error handling in the UI. Refactor API calls to use baseApiService for consistency. --- .../routes/search_source_connectors_routes.py | 17 +++++++++---- .../tool-ui/generic-hitl-approval.tsx | 5 ++-- .../lib/apis/connectors-api.service.ts | 24 ++++--------------- 3 files changed, 20 insertions(+), 26 deletions(-) diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 989894003..b8142c192 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -3105,13 +3105,18 @@ async def trust_mcp_tool( """Add a tool to the MCP connector's trusted (always-allow) list. Once trusted, the tool executes without HITL approval on subsequent calls. + Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors + (LINEAR_CONNECTOR, JIRA_CONNECTOR, etc.) by checking for ``server_config``. """ try: + from sqlalchemy import cast + from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB + result = await session.execute( select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.MCP_CONNECTOR, + SearchSourceConnector.user_id == user.id, + cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601 ) ) connector = result.scalars().first() @@ -3156,13 +3161,17 @@ async def untrust_mcp_tool( """Remove a tool from the MCP connector's trusted list. The tool will require HITL approval again on subsequent calls. + Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors. """ try: + from sqlalchemy import cast + from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB + result = await session.execute( select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.MCP_CONNECTOR, + SearchSourceConnector.user_id == user.id, + cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601 ) ) connector = result.scalars().first() diff --git a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx index 809b76c38..d21f249ee 100644 --- a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx +++ b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx @@ -3,6 +3,7 @@ import type { ToolCallMessagePartComponent } from "@assistant-ui/react"; import { CornerDownLeftIcon, Pen } from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; +import { toast } from "sonner"; import { TextShimmerLoader } from "@/components/prompt-kit/loader"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; @@ -116,8 +117,8 @@ function GenericApprovalCard({ if (phase !== "pending" || !isMCPTool) return; setProcessing(); onDecision({ type: "approve" }); - connectorsApiService.trustMCPTool(mcpConnectorId, toolName).catch((err) => { - console.error("Failed to trust MCP tool:", err); + connectorsApiService.trustMCPTool(mcpConnectorId, toolName).catch(() => { + toast.error("Failed to save 'Always Allow' preference. The tool will still require approval next time."); }); }, [phase, setProcessing, onDecision, isMCPTool, mcpConnectorId, toolName]); diff --git a/surfsense_web/lib/apis/connectors-api.service.ts b/surfsense_web/lib/apis/connectors-api.service.ts index 3eaa767c5..f4137c787 100644 --- a/surfsense_web/lib/apis/connectors-api.service.ts +++ b/surfsense_web/lib/apis/connectors-api.service.ts @@ -414,16 +414,8 @@ class ConnectorsApiService { * Subsequent calls to this tool will skip HITL approval. */ trustMCPTool = async (connectorId: number, toolName: string): Promise => { - const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - const token = - typeof window !== "undefined" ? document.cookie.match(/fapiToken=([^;]+)/)?.[1] : undefined; - await fetch(`${backendUrl}/api/v1/connectors/mcp/${connectorId}/trust-tool`, { - method: "POST", - headers: { - "Content-Type": "application/json", - ...(token ? { Authorization: `Bearer ${token}` } : {}), - }, - body: JSON.stringify({ tool_name: toolName }), + await baseApiService.post(`/api/v1/connectors/mcp/${connectorId}/trust-tool`, undefined, { + body: { tool_name: toolName }, }); }; @@ -431,16 +423,8 @@ class ConnectorsApiService { * Remove a tool from the MCP connector's "Always Allow" list. */ untrustMCPTool = async (connectorId: number, toolName: string): Promise => { - const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - const token = - typeof window !== "undefined" ? document.cookie.match(/fapiToken=([^;]+)/)?.[1] : undefined; - await fetch(`${backendUrl}/api/v1/connectors/mcp/${connectorId}/untrust-tool`, { - method: "POST", - headers: { - "Content-Type": "application/json", - ...(token ? { Authorization: `Bearer ${token}` } : {}), - }, - body: JSON.stringify({ tool_name: toolName }), + await baseApiService.post(`/api/v1/connectors/mcp/${connectorId}/untrust-tool`, undefined, { + body: { tool_name: toolName }, }); }; } From e3172dc282dbd31115b0ac5696dba74bf5bfbad7 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Thu, 23 Apr 2026 08:27:11 +0200 Subject: [PATCH 103/299] fix: reactive 401 recovery for live MCP connectors and unified reauth endpoints --- .../app/agents/new_chat/tools/mcp_tool.py | 480 +++++++++++++----- .../views/connector-edit-view.tsx | 18 +- .../constants/connector-constants.ts | 39 ++ .../views/connector-accounts-list-view.tsx | 37 +- 4 files changed, 396 insertions(+), 178 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 8f8e5007f..ddd65c7a7 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -194,6 +194,31 @@ async def _create_mcp_tool_from_definition_http( input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema) + async def _do_mcp_call( + call_headers: dict[str, str], + call_kwargs: dict[str, Any], + ) -> str: + """Execute a single MCP HTTP call with the given headers.""" + async with ( + streamablehttp_client(url, headers=call_headers) as (read, write, _), + ClientSession(read, write) as session, + ): + await session.initialize() + response = await session.call_tool( + original_tool_name, arguments=call_kwargs, + ) + + result = [] + for content in response.content: + if hasattr(content, "text"): + result.append(content.text) + elif hasattr(content, "data"): + result.append(str(content.data)) + else: + result.append(str(content)) + + return "\n".join(result) if result else "" + async def mcp_http_tool_call(**kwargs) -> str: """Execute the MCP tool call via HTTP transport.""" logger.debug("MCP HTTP tool '%s' called", exposed_name) @@ -218,31 +243,46 @@ async def _create_mcp_tool_from_definition_http( call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None} try: - async with ( - streamablehttp_client(url, headers=headers) as (read, write, _), - ClientSession(read, write) as session, - ): - await session.initialize() - response = await session.call_tool( - original_tool_name, arguments=call_kwargs, + result_str = await _do_mcp_call(headers, call_kwargs) + logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str)) + return result_str + + except Exception as first_err: + if not _is_auth_error(first_err) or connector_id is None: + logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, first_err) + return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {first_err!s}" + + logger.warning( + "MCP HTTP tool '%s' got 401 — attempting token refresh for connector %s", + exposed_name, connector_id, + ) + fresh_headers = await _force_refresh_and_get_headers(connector_id) + if fresh_headers is None: + await _mark_connector_auth_expired(connector_id) + return ( + f"Error: MCP tool '{exposed_name}' authentication expired. " + "Please re-authenticate the connector in your settings." ) - result = [] - for content in response.content: - if hasattr(content, "text"): - result.append(content.text) - elif hasattr(content, "data"): - result.append(str(content.data)) - else: - result.append(str(content)) - - result_str = "\n".join(result) if result else "" - logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str)) + try: + result_str = await _do_mcp_call(fresh_headers, call_kwargs) + logger.info( + "MCP HTTP tool '%s' succeeded after 401 recovery", + exposed_name, + ) return result_str - - except Exception as e: - logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, e) - return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {e!s}" + except Exception as retry_err: + logger.exception( + "MCP HTTP tool '%s' still failing after token refresh: %s", + exposed_name, retry_err, + ) + if _is_auth_error(retry_err): + await _mark_connector_auth_expired(connector_id) + return ( + f"Error: MCP tool '{exposed_name}' authentication expired. " + "Please re-authenticate the connector in your settings." + ) + return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {retry_err!s}" tool = StructuredTool( name=exposed_name, @@ -365,66 +405,98 @@ async def _load_http_mcp_tools( allowed_set = set(allowed_tools) if allowed_tools else None - try: + async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]: + """Connect, initialize, and list tools from the MCP server.""" async with ( - streamablehttp_client(url, headers=headers) as (read, write, _), + streamablehttp_client(url, headers=disc_headers) as (read, write, _), ClientSession(read, write) as session, ): await session.initialize() - response = await session.list_tools() - tool_definitions = [] - for tool in response.tools: - tool_definitions.append( - { - "name": tool.name, - "description": tool.description or "", - "input_schema": tool.inputSchema - if hasattr(tool, "inputSchema") - else {}, - } - ) + return [ + { + "name": tool.name, + "description": tool.description or "", + "input_schema": tool.inputSchema + if hasattr(tool, "inputSchema") + else {}, + } + for tool in response.tools + ] - total_discovered = len(tool_definitions) + try: + tool_definitions = await _discover(headers) + except Exception as first_err: + if not _is_auth_error(first_err) or connector_id is None: + logger.exception( + "Failed to connect to HTTP MCP server at '%s' (connector %d): %s", + url, connector_id, first_err, + ) + return tools - if allowed_set: - tool_definitions = [ - td for td in tool_definitions if td["name"] in allowed_set - ] - logger.info( - "HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter", - url, connector_id, len(tool_definitions), total_discovered, - ) - else: - logger.info( - "Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all", - total_discovered, url, connector_id, - ) - - for tool_def in tool_definitions: - try: - tool = await _create_mcp_tool_from_definition_http( - tool_def, - url, - headers, - connector_name=connector_name, - connector_id=connector_id, - trusted_tools=trusted_tools, - readonly_tools=readonly_tools, - tool_name_prefix=tool_name_prefix, - ) - tools.append(tool) - except Exception as e: - logger.exception( - "Failed to create HTTP tool '%s' from connector %d: %s", - tool_def.get("name"), connector_id, e, - ) - - except Exception as e: - logger.exception( - "Failed to connect to HTTP MCP server at '%s' (connector %d): %s", - url, connector_id, e, + logger.warning( + "HTTP MCP discovery for connector %d got 401 — attempting token refresh", + connector_id, ) + fresh_headers = await _force_refresh_and_get_headers(connector_id) + if fresh_headers is None: + await _mark_connector_auth_expired(connector_id) + logger.error( + "HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired", + connector_id, + ) + return tools + + try: + tool_definitions = await _discover(fresh_headers) + headers = fresh_headers + logger.info( + "HTTP MCP discovery for connector %d succeeded after 401 recovery", + connector_id, + ) + except Exception as retry_err: + logger.exception( + "HTTP MCP discovery for connector %d still failing after refresh: %s", + connector_id, retry_err, + ) + if _is_auth_error(retry_err): + await _mark_connector_auth_expired(connector_id) + return tools + + total_discovered = len(tool_definitions) + + if allowed_set: + tool_definitions = [ + td for td in tool_definitions if td["name"] in allowed_set + ] + logger.info( + "HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter", + url, connector_id, len(tool_definitions), total_discovered, + ) + else: + logger.info( + "Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all", + total_discovered, url, connector_id, + ) + + for tool_def in tool_definitions: + try: + tool = await _create_mcp_tool_from_definition_http( + tool_def, + url, + headers, + connector_name=connector_name, + connector_id=connector_id, + trusted_tools=trusted_tools, + readonly_tools=readonly_tools, + tool_name_prefix=tool_name_prefix, + ) + tools.append(tool) + except Exception as e: + logger.exception( + "Failed to create HTTP tool '%s' from connector %d: %s", + tool_def.get("name"), connector_id, e, + ) return tools @@ -476,6 +548,91 @@ def _inject_oauth_headers( return None +async def _refresh_connector_token( + session: AsyncSession, + connector: "SearchSourceConnector", +) -> str | None: + """Refresh the OAuth token for an MCP connector and persist the result. + + This is the shared core used by both proactive (pre-expiry) and reactive + (401 recovery) refresh paths. It handles: + - Decrypting the current refresh token / client secret + - Calling the token endpoint + - Encrypting and persisting the new tokens + - Clearing ``auth_expired`` if it was set + - Invalidating the MCP tools cache + + Returns the **plaintext** new access token on success, or ``None`` on + failure (no refresh token, IdP error, etc.). + """ + from datetime import UTC, datetime, timedelta + + from sqlalchemy.orm.attributes import flag_modified + + from app.services.mcp_oauth.discovery import refresh_access_token + + cfg = connector.config or {} + mcp_oauth = cfg.get("mcp_oauth", {}) + + refresh_token = mcp_oauth.get("refresh_token") + if not refresh_token: + logger.warning( + "MCP connector %s: no refresh_token available", + connector.id, + ) + return None + + enc = _get_token_enc() + decrypted_refresh = enc.decrypt_token(refresh_token) + decrypted_secret = ( + enc.decrypt_token(mcp_oauth["client_secret"]) + if mcp_oauth.get("client_secret") + else "" + ) + + token_json = await refresh_access_token( + token_endpoint=mcp_oauth["token_endpoint"], + refresh_token=decrypted_refresh, + client_id=mcp_oauth["client_id"], + client_secret=decrypted_secret, + ) + + new_access = token_json.get("access_token") + if not new_access: + logger.warning( + "MCP connector %s: token refresh returned no access_token", + connector.id, + ) + return None + + new_expires_at = None + if token_json.get("expires_in"): + new_expires_at = datetime.now(UTC) + timedelta( + seconds=int(token_json["expires_in"]) + ) + + updated_oauth = dict(mcp_oauth) + updated_oauth["access_token"] = enc.encrypt_token(new_access) + if token_json.get("refresh_token"): + updated_oauth["refresh_token"] = enc.encrypt_token( + token_json["refresh_token"] + ) + updated_oauth["expires_at"] = ( + new_expires_at.isoformat() if new_expires_at else None + ) + + updated_cfg = {**cfg, "mcp_oauth": updated_oauth} + updated_cfg.pop("auth_expired", None) + connector.config = updated_cfg + flag_modified(connector, "config") + await session.commit() + await session.refresh(connector) + + invalidate_mcp_tools_cache(connector.search_space_id) + + return new_access + + async def _maybe_refresh_mcp_oauth_token( session: AsyncSession, connector: "SearchSourceConnector", @@ -504,73 +661,13 @@ async def _maybe_refresh_mcp_oauth_token( except (ValueError, TypeError): return server_config - refresh_token = mcp_oauth.get("refresh_token") - if not refresh_token: - logger.warning( - "MCP connector %s token expired but no refresh_token available", - connector.id, - ) - return server_config - try: - from app.services.mcp_oauth.discovery import refresh_access_token - - enc = _get_token_enc() - decrypted_refresh = enc.decrypt_token(refresh_token) - decrypted_secret = ( - enc.decrypt_token(mcp_oauth["client_secret"]) - if mcp_oauth.get("client_secret") - else "" - ) - - token_json = await refresh_access_token( - token_endpoint=mcp_oauth["token_endpoint"], - refresh_token=decrypted_refresh, - client_id=mcp_oauth["client_id"], - client_secret=decrypted_secret, - ) - - new_access = token_json.get("access_token") + new_access = await _refresh_connector_token(session, connector) if not new_access: - logger.warning( - "MCP connector %s token refresh returned no access_token", - connector.id, - ) return server_config - new_expires_at = None - if token_json.get("expires_in"): - new_expires_at = datetime.now(UTC) + timedelta( - seconds=int(token_json["expires_in"]) - ) + logger.info("Proactively refreshed MCP OAuth token for connector %s", connector.id) - updated_oauth = dict(mcp_oauth) - updated_oauth["access_token"] = enc.encrypt_token(new_access) - if token_json.get("refresh_token"): - updated_oauth["refresh_token"] = enc.encrypt_token( - token_json["refresh_token"] - ) - updated_oauth["expires_at"] = ( - new_expires_at.isoformat() if new_expires_at else None - ) - - from sqlalchemy.orm.attributes import flag_modified - - connector.config = { - **cfg, - "server_config": server_config, - "mcp_oauth": updated_oauth, - } - flag_modified(connector, "config") - await session.commit() - await session.refresh(connector) - - logger.info("Refreshed MCP OAuth token for connector %s", connector.id) - - # Invalidate cache so next call picks up the new token. - invalidate_mcp_tools_cache(connector.search_space_id) - - # Return server_config with the fresh token injected for immediate use. refreshed_config = dict(server_config) refreshed_config["headers"] = { **server_config.get("headers", {}), @@ -587,6 +684,117 @@ async def _maybe_refresh_mcp_oauth_token( return server_config +# --------------------------------------------------------------------------- +# Reactive 401 handling helpers +# --------------------------------------------------------------------------- + + +def _is_auth_error(exc: Exception) -> bool: + """Check if an exception indicates an HTTP 401 authentication failure.""" + try: + import httpx + + if isinstance(exc, httpx.HTTPStatusError): + return exc.response.status_code == 401 + except ImportError: + pass + err_str = str(exc).lower() + return "401" in err_str or "unauthorized" in err_str + + +async def _force_refresh_and_get_headers( + connector_id: int, +) -> dict[str, str] | None: + """Force-refresh OAuth token for a connector and return fresh HTTP headers. + + Opens a **new** DB session so this can be called from inside tool closures + that don't have access to the original session. + + Returns ``None`` when the connector is not OAuth-backed, has no + refresh token, or the refresh itself fails. + """ + from app.db import async_session_maker + + try: + async with async_session_maker() as session: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + ) + ) + connector = result.scalars().first() + if not connector: + return None + + cfg = connector.config or {} + if not cfg.get("mcp_oauth"): + return None + + server_config = cfg.get("server_config", {}) + + new_access = await _refresh_connector_token(session, connector) + if not new_access: + return None + + logger.info( + "Force-refreshed MCP OAuth token for connector %s (401 recovery)", + connector_id, + ) + return { + **server_config.get("headers", {}), + "Authorization": f"Bearer {new_access}", + } + + except Exception: + logger.warning( + "Failed to force-refresh MCP OAuth token for connector %s", + connector_id, + exc_info=True, + ) + return None + + +async def _mark_connector_auth_expired(connector_id: int) -> None: + """Set ``config.auth_expired = True`` so the frontend shows re-auth UI.""" + from app.db import async_session_maker + + try: + async with async_session_maker() as session: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + ) + ) + connector = result.scalars().first() + if not connector: + return + + cfg = dict(connector.config or {}) + if cfg.get("auth_expired"): + return + + cfg["auth_expired"] = True + connector.config = cfg + + from sqlalchemy.orm.attributes import flag_modified + + flag_modified(connector, "config") + await session.commit() + + logger.info( + "Marked MCP connector %s as auth_expired after unrecoverable 401", + connector_id, + ) + invalidate_mcp_tools_cache(connector.search_space_id) + + except Exception: + logger.warning( + "Failed to mark connector %s as auth_expired", + connector_id, + exc_info=True, + ) + + def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None: """Invalidate cached MCP tools. diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx index a69cf968f..16e7bd0d5 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx @@ -7,7 +7,6 @@ import { toast } from "sonner"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; -import { EnumConnectorName } from "@/contracts/enums/connector"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { authenticatedFetch } from "@/lib/auth-utils"; @@ -16,23 +15,10 @@ import { DateRangeSelector } from "../../components/date-range-selector"; import { PeriodicSyncConfig } from "../../components/periodic-sync-config"; import { SummaryConfig } from "../../components/summary-config"; import { VisionLLMConfig } from "../../components/vision-llm-config"; -import { LIVE_CONNECTOR_TYPES } from "../../constants/connector-constants"; +import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../../constants/connector-constants"; import { getConnectorDisplayName } from "../../tabs/all-connectors-tab"; import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index"; -const REAUTH_ENDPOINTS: Partial> = { - [EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth", - [EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth", - [EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth", - [EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth", - [EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth", - [EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", - [EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", - [EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", - [EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth", - [EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth", -}; - interface ConnectorEditViewProps { connector: SearchSourceConnector; startDate: Date | undefined; @@ -86,7 +72,7 @@ export const ConnectorEditView: FC = ({ }) => { const searchSpaceIdAtom = useAtomValue(activeSearchSpaceIdAtom); const isAuthExpired = connector.config?.auth_expired === true; - const reauthEndpoint = REAUTH_ENDPOINTS[connector.connector_type]; + const reauthEndpoint = getReauthEndpoint(connector); const [reauthing, setReauthing] = useState(false); const handleReauth = useCallback(async () => { diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index 05f866d0f..621b71411 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -1,4 +1,5 @@ import { EnumConnectorName } from "@/contracts/enums/connector"; +import type { SearchSourceConnector } from "@/contracts/types/connector.types"; /** * Connectors that operate in real time (no background indexing). @@ -367,5 +368,43 @@ export function getConnectorTelemetryMeta(connectorType: string): ConnectorTelem }; } +// ============================================================================= +// REAUTH ENDPOINTS +// ============================================================================= + +/** + * Legacy (non-MCP) OAuth reauth endpoints, keyed by connector type. + * These are used for connectors that were NOT created via MCP OAuth. + */ +export const LEGACY_REAUTH_ENDPOINTS: Partial> = { + [EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth", + [EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth", + [EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth", + [EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth", + [EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", + [EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", + [EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", + [EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth", + [EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth", + [EnumConnectorName.CONFLUENCE_CONNECTOR]: "/api/v1/auth/confluence/connector/reauth", + [EnumConnectorName.TEAMS_CONNECTOR]: "/api/v1/auth/teams/connector/reauth", + [EnumConnectorName.DISCORD_CONNECTOR]: "/api/v1/auth/discord/connector/reauth", +}; + +/** + * Resolve the reauth endpoint for a connector. + * + * MCP OAuth connectors (those with ``config.mcp_service``) dynamically build + * the URL from the service key. Legacy OAuth connectors fall back to the + * static ``LEGACY_REAUTH_ENDPOINTS`` map. + */ +export function getReauthEndpoint(connector: SearchSourceConnector): string | undefined { + const mcpService = connector.config?.mcp_service as string | undefined; + if (mcpService) { + return `/api/v1/auth/mcp/${mcpService}/connector/reauth`; + } + return LEGACY_REAUTH_ENDPOINTS[connector.connector_type]; +} + // Re-export IndexingConfigState from schemas for backward compatibility export type { IndexingConfigState } from "./connector-popup.schemas"; diff --git a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx index b48b14ed2..a1ae96a40 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx @@ -13,25 +13,10 @@ import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { authenticatedFetch } from "@/lib/auth-utils"; import { formatRelativeDate } from "@/lib/format-date"; import { cn } from "@/lib/utils"; -import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants"; +import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../constants/connector-constants"; import { useConnectorStatus } from "../hooks/use-connector-status"; import { getConnectorDisplayName } from "../tabs/all-connectors-tab"; -const REAUTH_ENDPOINTS: Partial> = { - [EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth", - [EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth", - [EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth", - [EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth", - [EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth", - [EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", - [EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", - [EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", - [EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth", - [EnumConnectorName.JIRA_CONNECTOR]: "/api/v1/auth/jira/connector/reauth", - [EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth", - [EnumConnectorName.CONFLUENCE_CONNECTOR]: "/api/v1/auth/confluence/connector/reauth", -}; - interface ConnectorAccountsListViewProps { connectorType: string; connectorTitle: string; @@ -68,16 +53,15 @@ export const ConnectorAccountsListView: FC = ({ const isEnabled = isConnectorEnabled(connectorType); const statusMessage = getConnectorStatusMessage(connectorType); - const reauthEndpoint = REAUTH_ENDPOINTS[connectorType]; - const handleReauth = useCallback( - async (connectorId: number) => { - if (!searchSpaceId || !reauthEndpoint) return; - setReauthingId(connectorId); + async (connector: SearchSourceConnector) => { + const endpoint = getReauthEndpoint(connector); + if (!searchSpaceId || !endpoint) return; + setReauthingId(connector.id); try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - const url = new URL(`${backendUrl}${reauthEndpoint}`); - url.searchParams.set("connector_id", String(connectorId)); + const url = new URL(`${backendUrl}${endpoint}`); + url.searchParams.set("connector_id", String(connector.id)); url.searchParams.set("space_id", String(searchSpaceId)); url.searchParams.set("return_url", window.location.pathname); const response = await authenticatedFetch(url.toString()); @@ -99,7 +83,7 @@ export const ConnectorAccountsListView: FC = ({ setReauthingId(null); } }, - [searchSpaceId, reauthEndpoint] + [searchSpaceId] ); // Filter connectors to only show those of this type @@ -200,7 +184,8 @@ export const ConnectorAccountsListView: FC = ({
{typeConnectors.map((connector) => { const isIndexing = indexingConnectorIds.has(connector.id); - const isAuthExpired = !!reauthEndpoint && connector.config?.auth_expired === true; + const connectorReauthEndpoint = getReauthEndpoint(connector); + const isAuthExpired = !!connectorReauthEndpoint && connector.config?.auth_expired === true; return (
= ({
) : (
- {typeConnectors.map((connector) => { - const isIndexing = indexingConnectorIds.has(connector.id); - const connectorReauthEndpoint = getReauthEndpoint(connector); - const isAuthExpired = !!connectorReauthEndpoint && connector.config?.auth_expired === true; + {typeConnectors.map((connector) => { + const isIndexing = indexingConnectorIds.has(connector.id); + const connectorReauthEndpoint = getReauthEndpoint(connector); + const isAuthExpired = !!connectorReauthEndpoint && connector.config?.auth_expired === true; + const isLive = LIVE_CONNECTOR_TYPES.has(connector.connector_type) || Boolean(connector.config?.server_config); return (
= ({ Syncing

- ) : !isLiveConnector(connector.connector_type) ? ( + ) : !isLive ? (

{connector.last_indexed_at ? `Last indexed: ${formatRelativeDate(connector.last_indexed_at)}` @@ -224,28 +225,73 @@ export const ConnectorAccountsListView: FC = ({

) : null}
- {isAuthExpired ? ( - + {isAuthExpired ? ( + + ) : isLive && onDisconnect ? ( + confirmDisconnectId === connector.id ? ( +
+ + +
) : ( - )} + ) + ) : ( + + )}
); })} From 9bb117ffa7542870296fd5584f3a8fce3a7b4abd Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Thu, 23 Apr 2026 08:51:31 +0200 Subject: [PATCH 106/299] feat: skip edit view for live connectors, disconnect directly from accounts list --- .../assistant-ui/connector-popup.tsx | 43 ++++++++++--------- .../views/connector-edit-view.tsx | 6 +-- .../hooks/use-connector-dialog.ts | 20 +++++++++ 3 files changed, 45 insertions(+), 24 deletions(-) diff --git a/surfsense_web/components/assistant-ui/connector-popup.tsx b/surfsense_web/components/assistant-ui/connector-popup.tsx index 84361e25b..66333a9ef 100644 --- a/surfsense_web/components/assistant-ui/connector-popup.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup.tsx @@ -123,8 +123,9 @@ export const ConnectorIndicator = forwardRef ) : viewingMCPList ? ( - + handleDisconnectFromList(connector, () => refreshConnectors())} + onAddAccount={handleAddNewMCPFromList} + addButtonText="Add New MCP Server" + /> ) : viewingAccountsType ? ( - { + handleDisconnectFromList(connector, () => refreshConnectors())} + onAddAccount={() => { // Check both OAUTH_CONNECTORS and COMPOSIO_CONNECTORS const oauthConnector = OAUTH_CONNECTORS.find( diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx index 16e7bd0d5..44461c351 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx @@ -17,6 +17,7 @@ import { SummaryConfig } from "../../components/summary-config"; import { VisionLLMConfig } from "../../components/vision-llm-config"; import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../../constants/connector-constants"; import { getConnectorDisplayName } from "../../tabs/all-connectors-tab"; +import { MCPServiceConfig } from "../components/mcp-service-config"; import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index"; interface ConnectorEditViewProps { @@ -110,10 +111,7 @@ export const ConnectorEditView: FC = ({ // Get connector-specific config component (MCP-backed connectors use a generic view) const ConnectorConfigComponent = useMemo(() => { - if (isMCPBacked) { - const { MCPServiceConfig } = require("../components/mcp-service-config"); - return MCPServiceConfig as FC; - } + if (isMCPBacked) return MCPServiceConfig; return getConnectorConfigComponent(connector.connector_type); }, [connector.connector_type, isMCPBacked]); const [isScrolled, setIsScrolled] = useState(false); diff --git a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts index a8d395e5c..a9223fee5 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts @@ -1311,6 +1311,25 @@ export const useConnectorDialog = () => { [editingConnector, searchSpaceId, deleteConnector, cameFromMCPList, setIsOpen] ); + const handleDisconnectFromList = useCallback( + async (connector: SearchSourceConnector, refreshConnectors: () => void) => { + if (!searchSpaceId) return; + try { + await deleteConnector({ id: connector.id }); + trackConnectorDeleted(Number(searchSpaceId), connector.connector_type, connector.id); + toast.success(`${connector.name} disconnected successfully`); + refreshConnectors(); + queryClient.invalidateQueries({ + queryKey: cacheKeys.logs.summary(Number(searchSpaceId)), + }); + } catch (error) { + console.error("Error disconnecting connector:", error); + toast.error("Failed to disconnect connector"); + } + }, + [searchSpaceId, deleteConnector] + ); + // Handle quick index (index with selected date range, or backend defaults if none selected) const handleQuickIndexConnector = useCallback( async ( @@ -1484,6 +1503,7 @@ export const useConnectorDialog = () => { handleStartEdit, handleSaveConnector, handleDisconnectConnector, + handleDisconnectFromList, handleBackFromEdit, handleBackFromConnect, handleBackFromYouTube, From 2eb0ff9e5e108760b5d846590450802ff6325022 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Thu, 23 Apr 2026 08:57:56 +0200 Subject: [PATCH 107/299] feat: add reauthentication endpoints for Linear and JIRA connectors --- .../connector-popup/constants/connector-constants.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index 621b71411..2ee811c19 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -377,6 +377,8 @@ export function getConnectorTelemetryMeta(connectorType: string): ConnectorTelem * These are used for connectors that were NOT created via MCP OAuth. */ export const LEGACY_REAUTH_ENDPOINTS: Partial> = { + [EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth", + [EnumConnectorName.JIRA_CONNECTOR]: "/api/v1/auth/jira/connector/reauth", [EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth", [EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth", [EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth", From cf7c14cf44b9887d88b28c6718221c2da4a1fb9d Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Thu, 23 Apr 2026 09:27:03 +0200 Subject: [PATCH 108/299] fix: mark connector auth_expired on token decryption failure --- surfsense_backend/app/agents/new_chat/tools/mcp_tool.py | 1 + 1 file changed, 1 insertion(+) diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 7909657e0..b0dcd72b6 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -895,6 +895,7 @@ async def load_mcp_tools( "Skipping MCP connector %d — OAuth token decryption failed", connector.id, ) + await _mark_connector_auth_expired(connector.id) continue trusted_tools = cfg.get("trusted_tools", []) From 1712f454f82f35d228d8765341617346673c2a48 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Thu, 23 Apr 2026 09:45:25 +0200 Subject: [PATCH 109/299] fix: add spinner loading state to MCP test connection button --- .../connect-forms/components/mcp-connect-form.tsx | 11 +++++++++-- .../connector-configs/components/mcp-config.tsx | 11 +++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx index 58d365128..fc9812240 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx @@ -1,6 +1,6 @@ "use client"; -import { CheckCircle2, ChevronDown, ChevronUp, Server, XCircle } from "lucide-react"; +import { CheckCircle2, ChevronDown, ChevronUp, Loader2, Server, XCircle } from "lucide-react"; import { type FC, useRef, useState } from "react"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; @@ -212,7 +212,14 @@ export const MCPConnectForm: FC = ({ onSubmit, isSubmitting }) variant="secondary" className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80" > - {isTesting ? "Testing Connection" : "Test Connection"} + {isTesting ? ( + <> + + Testing Connection... + + ) : ( + "Test Connection" + )}
diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx index ca997a9ba..d6f60e824 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx @@ -1,6 +1,6 @@ "use client"; -import { CheckCircle2, ChevronDown, ChevronUp, Server, XCircle } from "lucide-react"; +import { CheckCircle2, ChevronDown, ChevronUp, Loader2, Server, XCircle } from "lucide-react"; import type { FC } from "react"; import { useCallback, useEffect, useRef, useState } from "react"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; @@ -217,7 +217,14 @@ export const MCPConfig: FC = ({ connector, onConfigChange, onNam variant="secondary" className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80" > - {isTesting ? "Testing Connection" : "Test Connection"} + {isTesting ? ( + <> + + Testing Connection... + + ) : ( + "Test Connection" + )}
From 45b72de481f5d3a1645de6545b39d23d4b776a7f Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Thu, 23 Apr 2026 11:30:58 +0200 Subject: [PATCH 110/299] fix: robust generic MCP tool routing, retry, and empty-schema handling --- .../app/agents/new_chat/chat_deepagent.py | 16 +++ .../app/agents/new_chat/system_prompt.py | 35 +++++ .../app/agents/new_chat/tools/mcp_client.py | 58 +++++--- .../app/agents/new_chat/tools/mcp_tool.py | 124 ++++++++++++++---- 4 files changed, 191 insertions(+), 42 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 4b204ffa9..89aa13620 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -314,6 +314,20 @@ async def create_surfsense_deep_agent( _t0 = time.perf_counter() _enabled_tool_names = {t.name for t in tools} _user_disabled_tool_names = set(disabled_tools) if disabled_tools else set() + + # Collect generic MCP connector info so the system prompt can route queries + # to their tools instead of falling back to "not in knowledge base". + _mcp_connector_tools: dict[str, list[str]] = {} + for t in tools: + meta = getattr(t, "metadata", None) or {} + if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"): + _mcp_connector_tools.setdefault( + meta["mcp_connector_name"], [], + ).append(t.name) + + if _mcp_connector_tools: + _perf_log.info("MCP connector tool routing: %s", _mcp_connector_tools) + if agent_config is not None: system_prompt = build_configurable_system_prompt( custom_system_instructions=agent_config.system_instructions, @@ -322,12 +336,14 @@ async def create_surfsense_deep_agent( thread_visibility=thread_visibility, enabled_tool_names=_enabled_tool_names, disabled_tool_names=_user_disabled_tool_names, + mcp_connector_tools=_mcp_connector_tools, ) else: system_prompt = build_surfsense_system_prompt( thread_visibility=thread_visibility, enabled_tool_names=_enabled_tool_names, disabled_tool_names=_user_disabled_tool_names, + mcp_connector_tools=_mcp_connector_tools, ) _perf_log.info( "[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0 diff --git a/surfsense_backend/app/agents/new_chat/system_prompt.py b/surfsense_backend/app/agents/new_chat/system_prompt.py index 3182735d9..e77132182 100644 --- a/surfsense_backend/app/agents/new_chat/system_prompt.py +++ b/surfsense_backend/app/agents/new_chat/system_prompt.py @@ -815,11 +815,36 @@ Your goal is to provide helpful, informative answers in a clean, readable format """ +def _build_mcp_routing_block( + mcp_connector_tools: dict[str, list[str]] | None, +) -> str: + """Build an additional tool routing block for generic MCP connectors. + + When users add MCP servers (e.g. GitLab, GitHub), the LLM needs to know + those tools exist and should be called directly — not searched in the + knowledge base. + """ + if not mcp_connector_tools: + return "" + + lines = [ + "\n", + "You also have direct tools from these user-connected MCP servers.", + "Their data is NEVER in the knowledge base — call their tools directly.", + "", + ] + for server_name, tool_names in mcp_connector_tools.items(): + lines.append(f"- {server_name} → {', '.join(tool_names)}") + lines.append("\n") + return "\n".join(lines) + + def build_surfsense_system_prompt( today: datetime | None = None, thread_visibility: ChatVisibility | None = None, enabled_tool_names: set[str] | None = None, disabled_tool_names: set[str] | None = None, + mcp_connector_tools: dict[str, list[str]] | None = None, ) -> str: """ Build the SurfSense system prompt with default settings. @@ -834,6 +859,9 @@ def build_surfsense_system_prompt( thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None. enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included. disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user. + mcp_connector_tools: Mapping of MCP server display name → list of tool names + for generic MCP connectors. Injected into the system prompt so the LLM + knows to call these tools directly. Returns: Complete system prompt string @@ -841,6 +869,7 @@ def build_surfsense_system_prompt( visibility = thread_visibility or ChatVisibility.PRIVATE system_instructions = _get_system_instructions(visibility, today) + system_instructions += _build_mcp_routing_block(mcp_connector_tools) tools_instructions = _get_tools_instructions( visibility, enabled_tool_names, disabled_tool_names ) @@ -856,6 +885,7 @@ def build_configurable_system_prompt( thread_visibility: ChatVisibility | None = None, enabled_tool_names: set[str] | None = None, disabled_tool_names: set[str] | None = None, + mcp_connector_tools: dict[str, list[str]] | None = None, ) -> str: """ Build a configurable SurfSense system prompt based on NewLLMConfig settings. @@ -877,6 +907,9 @@ def build_configurable_system_prompt( thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None. enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included. disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user. + mcp_connector_tools: Mapping of MCP server display name → list of tool names + for generic MCP connectors. Injected into the system prompt so the LLM + knows to call these tools directly. Returns: Complete system prompt string @@ -894,6 +927,8 @@ def build_configurable_system_prompt( else: system_instructions = "" + system_instructions += _build_mcp_routing_block(mcp_connector_tools) + # Tools instructions: only include enabled tools, note disabled ones tools_instructions = _get_tools_instructions( thread_visibility, enabled_tool_names, disabled_tool_names diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py index 44c48344c..b46ddbcc5 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py @@ -45,6 +45,18 @@ class MCPClient: async def connect(self, max_retries: int = MAX_RETRIES): """Connect to the MCP server and manage its lifecycle. + Retries only apply to the **connection** phase (spawning the process, + initialising the session). Once the session is yielded to the caller, + any exception raised by the caller propagates normally -- the context + manager will NOT retry after ``yield``. + + Previous implementation wrapped both connection AND yield inside the + retry loop. Because ``@asynccontextmanager`` only allows a single + ``yield``, a failure after yield caused the generator to attempt a + second yield on retry, triggering + ``RuntimeError("generator didn't stop after athrow()")`` and orphaning + the stdio subprocess. + Args: max_retries: Maximum number of connection retry attempts @@ -57,26 +69,22 @@ class MCPClient: """ last_error = None delay = RETRY_DELAY + connected = False for attempt in range(max_retries): try: - # Merge env vars with current environment server_env = os.environ.copy() server_env.update(self.env) - # Create server parameters with env server_params = StdioServerParameters( command=self.command, args=self.args, env=server_env ) - # Spawn server process and create session - # Note: Cannot combine these context managers because ClientSession - # needs the read/write streams from stdio_client async with stdio_client(server=server_params) as (read, write): # noqa: SIM117 async with ClientSession(read, write) as session: - # Initialize the connection await session.initialize() self.session = session + connected = True if attempt > 0: logger.info( @@ -91,10 +99,16 @@ class MCPClient: self.command, " ".join(self.args), ) - yield session - return # Success, exit retry loop + try: + yield session + finally: + self.session = None + return except Exception as e: + self.session = None + if connected: + raise last_error = e if attempt < max_retries - 1: logger.warning( @@ -105,7 +119,7 @@ class MCPClient: delay, ) await asyncio.sleep(delay) - delay *= RETRY_BACKOFF # Exponential backoff + delay *= RETRY_BACKOFF else: logger.error( "Failed to connect to MCP server after %d attempts: %s", @@ -113,10 +127,7 @@ class MCPClient: e, exc_info=True, ) - finally: - self.session = None - # All retries exhausted error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts" if last_error: error_msg += f": {last_error}" @@ -161,12 +172,18 @@ class MCPClient: logger.error("Failed to list tools from MCP server: %s", e, exc_info=True) raise - async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: + async def call_tool( + self, + tool_name: str, + arguments: dict[str, Any], + timeout: float = 60.0, + ) -> Any: """Call a tool on the MCP server. Args: tool_name: Name of the tool to call arguments: Arguments to pass to the tool + timeout: Maximum seconds to wait for the tool to respond Returns: Tool execution result @@ -185,10 +202,11 @@ class MCPClient: "Calling MCP tool '%s' with arguments: %s", tool_name, arguments ) - # Call tools/call RPC method - response = await self.session.call_tool(tool_name, arguments=arguments) + response = await asyncio.wait_for( + self.session.call_tool(tool_name, arguments=arguments), + timeout=timeout, + ) - # Extract content from response result = [] for content in response.content: if hasattr(content, "text"): @@ -202,15 +220,17 @@ class MCPClient: logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200]) return result_str + except asyncio.TimeoutError: + logger.error( + "MCP tool '%s' timed out after %.0fs", tool_name, timeout + ) + return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s" except RuntimeError as e: - # Handle validation errors from MCP server responses - # Some MCP servers (like server-memory) return extra fields not in their schema if "Invalid structured content" in str(e): logger.warning( "MCP server returned data not matching its schema, but continuing: %s", e, ) - # Try to extract result from error message or return a success message return "Operation completed (server returned unexpected format)" raise except (ValueError, TypeError, AttributeError, KeyError) as e: diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index b0dcd72b6..dfee24516 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: from langchain_core.tools import StructuredTool from mcp import ClientSession from mcp.client.streamable_http import streamablehttp_client -from pydantic import BaseModel, Field, create_model +from pydantic import BaseModel, ConfigDict, Field, create_model from sqlalchemy import cast, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.asyncio import AsyncSession @@ -43,6 +43,8 @@ logger = logging.getLogger(__name__) _MCP_CACHE_TTL_SECONDS = 300 # 5 minutes _MCP_CACHE_MAX_SIZE = 50 _MCP_DISCOVERY_TIMEOUT_SECONDS = 30 +_TOOL_CALL_MAX_RETRIES = 3 +_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt _mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {} @@ -64,7 +66,18 @@ def _create_dynamic_input_model_from_schema( tool_name: str, input_schema: dict[str, Any], ) -> type[BaseModel]: - """Create a Pydantic model from MCP tool's JSON schema.""" + """Create a Pydantic model from MCP tool's JSON schema. + + Models always allow extra fields (``extra="allow"``) so that parameters + missing from a broken or incomplete JSON schema (e.g. ``zod-to-json-schema`` + producing an empty ``$schema``-only object) can still be forwarded to the + MCP server. + + When the schema declares **no** properties, a synthetic ``input_data`` + field of type ``dict`` is injected so the LLM has a visible parameter to + populate. The caller should unpack ``input_data`` before forwarding to + the MCP server (see ``_unpack_synthetic_input_data``). + """ properties = input_schema.get("properties", {}) required_fields = input_schema.get("required", []) @@ -84,8 +97,35 @@ def _create_dynamic_input_model_from_schema( Field(None, description=param_description), ) + if not properties: + field_definitions["input_data"] = ( + dict[str, Any] | None, + Field( + None, + description=( + "Arguments to pass to this tool as a JSON object. " + "Infer sensible key names from the tool name and description " + "(e.g. {\"search\": \"my query\"} for a search tool)." + ), + ), + ) + model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input" - return create_model(model_name, **field_definitions) + model = create_model(model_name, __config__=ConfigDict(extra="allow"), **field_definitions) + return model + + +def _unpack_synthetic_input_data(kwargs: dict[str, Any]) -> dict[str, Any]: + """Unpack the synthetic ``input_data`` field into top-level kwargs. + + When the MCP tool schema is empty, ``_create_dynamic_input_model_from_schema`` + adds a catch-all ``input_data: dict`` field. This helper merges that dict + back into the top-level kwargs so the MCP server receives flat arguments. + """ + input_data = kwargs.pop("input_data", None) + if isinstance(input_data, dict): + kwargs.update(input_data) + return kwargs async def _create_mcp_tool_from_definition_stdio( @@ -103,7 +143,12 @@ async def _create_mcp_tool_from_definition_stdio( ``GraphInterrupt`` propagates cleanly to LangGraph. """ tool_name = tool_def.get("name", "unnamed_tool") - tool_description = tool_def.get("description", "No description provided") + raw_description = tool_def.get("description", "No description provided") + tool_description = ( + f"[MCP server: {connector_name}] {raw_description}" + if connector_name + else raw_description + ) input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema) @@ -121,7 +166,7 @@ async def _create_mcp_tool_from_definition_stdio( params=kwargs, context={ "mcp_server": connector_name, - "tool_description": tool_description, + "tool_description": raw_description, "mcp_transport": "stdio", "mcp_connector_id": connector_id, }, @@ -129,18 +174,32 @@ async def _create_mcp_tool_from_definition_stdio( ) if hitl_result.rejected: return "Tool call rejected by user." - call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None} + call_kwargs = _unpack_synthetic_input_data( + {k: v for k, v in hitl_result.params.items() if v is not None} + ) - try: - async with mcp_client.connect(): - result = await mcp_client.call_tool(tool_name, call_kwargs) - return str(result) - except RuntimeError as e: - logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e) - return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}" - except Exception as e: - logger.exception("MCP tool '%s' execution failed: %s", tool_name, e) - return f"Error: MCP tool '{tool_name}' execution failed: {e!s}" + last_error: Exception | None = None + for attempt in range(_TOOL_CALL_MAX_RETRIES): + try: + async with mcp_client.connect(): + result = await mcp_client.call_tool(tool_name, call_kwargs) + return str(result) + except Exception as e: + last_error = e + if attempt < _TOOL_CALL_MAX_RETRIES - 1: + delay = _TOOL_CALL_RETRY_DELAY * (2 ** attempt) + logger.warning( + "MCP tool '%s' failed (attempt %d/%d): %s. Retrying in %.1fs...", + tool_name, attempt + 1, _TOOL_CALL_MAX_RETRIES, e, delay, + ) + await asyncio.sleep(delay) + else: + logger.error( + "MCP tool '%s' failed after %d attempts: %s", + tool_name, _TOOL_CALL_MAX_RETRIES, e, exc_info=True, + ) + + return f"Error: MCP tool '{tool_name}' failed after {_TOOL_CALL_MAX_RETRIES} attempts: {last_error!s}" tool = StructuredTool( name=tool_name, @@ -150,6 +209,8 @@ async def _create_mcp_tool_from_definition_stdio( metadata={ "mcp_input_schema": input_schema, "mcp_transport": "stdio", + "mcp_connector_name": connector_name or None, + "mcp_is_generic": True, "hitl": True, "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), }, @@ -169,6 +230,7 @@ async def _create_mcp_tool_from_definition_http( trusted_tools: list[str] | None = None, readonly_tools: frozenset[str] | None = None, tool_name_prefix: str | None = None, + is_generic_mcp: bool = False, ) -> StructuredTool: """Create a LangChain tool from an MCP tool definition (HTTP transport). @@ -180,7 +242,7 @@ async def _create_mcp_tool_from_definition_http( but the actual MCP ``call_tool`` still uses the original name. """ original_tool_name = tool_def.get("name", "unnamed_tool") - tool_description = tool_def.get("description", "No description provided") + raw_description = tool_def.get("description", "No description provided") input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) is_readonly = readonly_tools is not None and original_tool_name in readonly_tools @@ -190,7 +252,11 @@ async def _create_mcp_tool_from_definition_http( else original_tool_name ) if tool_name_prefix: - tool_description = f"[Account: {connector_name}] {tool_description}" + tool_description = f"[Account: {connector_name}] {raw_description}" + elif is_generic_mcp and connector_name: + tool_description = f"[MCP server: {connector_name}] {raw_description}" + else: + tool_description = raw_description logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema) @@ -199,6 +265,7 @@ async def _create_mcp_tool_from_definition_http( async def _do_mcp_call( call_headers: dict[str, str], call_kwargs: dict[str, Any], + timeout: float = 60.0, ) -> str: """Execute a single MCP HTTP call with the given headers.""" async with ( @@ -206,8 +273,9 @@ async def _create_mcp_tool_from_definition_http( ClientSession(read, write) as session, ): await session.initialize() - response = await session.call_tool( - original_tool_name, arguments=call_kwargs, + response = await asyncio.wait_for( + session.call_tool(original_tool_name, arguments=call_kwargs), + timeout=timeout, ) result = [] @@ -226,7 +294,9 @@ async def _create_mcp_tool_from_definition_http( logger.debug("MCP HTTP tool '%s' called", exposed_name) if is_readonly: - call_kwargs = {k: v for k, v in kwargs.items() if v is not None} + call_kwargs = _unpack_synthetic_input_data( + {k: v for k, v in kwargs.items() if v is not None} + ) else: hitl_result = request_approval( action_type="mcp_tool_call", @@ -234,7 +304,7 @@ async def _create_mcp_tool_from_definition_http( params=kwargs, context={ "mcp_server": connector_name, - "tool_description": tool_description, + "tool_description": raw_description, "mcp_transport": "http", "mcp_connector_id": connector_id, }, @@ -242,7 +312,9 @@ async def _create_mcp_tool_from_definition_http( ) if hitl_result.rejected: return "Tool call rejected by user." - call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None} + call_kwargs = _unpack_synthetic_input_data( + {k: v for k, v in hitl_result.params.items() if v is not None} + ) try: result_str = await _do_mcp_call(headers, call_kwargs) @@ -295,6 +367,8 @@ async def _create_mcp_tool_from_definition_http( "mcp_input_schema": input_schema, "mcp_transport": "http", "mcp_url": url, + "mcp_connector_name": connector_name or None, + "mcp_is_generic": is_generic_mcp, "hitl": not is_readonly, "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), "mcp_original_tool_name": original_tool_name, @@ -376,6 +450,7 @@ async def _load_http_mcp_tools( allowed_tools: list[str] | None = None, readonly_tools: frozenset[str] | None = None, tool_name_prefix: str | None = None, + is_generic_mcp: bool = False, ) -> list[StructuredTool]: """Load tools from an HTTP-based MCP server. @@ -492,6 +567,7 @@ async def _load_http_mcp_tools( trusted_tools=trusted_tools, readonly_tools=readonly_tools, tool_name_prefix=tool_name_prefix, + is_generic_mcp=is_generic_mcp, ) tools.append(tool) except Exception as e: @@ -928,6 +1004,7 @@ async def load_mcp_tools( "readonly_tools": readonly_tools, "tool_name_prefix": tool_name_prefix, "transport": server_config.get("transport", "stdio"), + "is_generic_mcp": svc_cfg is None, }) except Exception as e: @@ -948,6 +1025,7 @@ async def load_mcp_tools( allowed_tools=task["allowed_tools"], readonly_tools=task["readonly_tools"], tool_name_prefix=task["tool_name_prefix"], + is_generic_mcp=task.get("is_generic_mcp", False), ), timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, ) From 749116e830122dcd4f6e3dbaf1a1e29d8ea6b726 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 15:02:58 +0530 Subject: [PATCH 111/299] feat(new-chat): add filesystem backend interfaces and selection helpers --- .../agents/new_chat/filesystem_backends.py | 38 +++++++++++++++++++ .../agents/new_chat/filesystem_selection.py | 33 ++++++++++++++++ .../agents/new_chat/middleware/__init__.py | 4 ++ 3 files changed, 75 insertions(+) create mode 100644 surfsense_backend/app/agents/new_chat/filesystem_backends.py create mode 100644 surfsense_backend/app/agents/new_chat/filesystem_selection.py diff --git a/surfsense_backend/app/agents/new_chat/filesystem_backends.py b/surfsense_backend/app/agents/new_chat/filesystem_backends.py new file mode 100644 index 000000000..8af7e8558 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/filesystem_backends.py @@ -0,0 +1,38 @@ +"""Filesystem backend resolver for cloud and desktop-local modes.""" + +from __future__ import annotations + +from collections.abc import Callable +from functools import lru_cache + +from deepagents.backends.state import StateBackend +from langgraph.prebuilt.tool_node import ToolRuntime + +from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection +from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend + + +@lru_cache(maxsize=64) +def _cached_local_backend(root_path: str) -> LocalFolderBackend: + return LocalFolderBackend(root_path) + + +def build_backend_resolver( + selection: FilesystemSelection, +) -> Callable[[ToolRuntime], StateBackend | LocalFolderBackend]: + """Create deepagents backend resolver for the selected filesystem mode.""" + + if ( + selection.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER + and selection.local_root_path is not None + ): + + def _resolve_local(_runtime: ToolRuntime) -> LocalFolderBackend: + return _cached_local_backend(selection.local_root_path or "") + + return _resolve_local + + def _resolve_cloud(runtime: ToolRuntime) -> StateBackend: + return StateBackend(runtime) + + return _resolve_cloud diff --git a/surfsense_backend/app/agents/new_chat/filesystem_selection.py b/surfsense_backend/app/agents/new_chat/filesystem_selection.py new file mode 100644 index 000000000..3094a0b29 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/filesystem_selection.py @@ -0,0 +1,33 @@ +"""Filesystem mode contracts and selection helpers for chat sessions.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum + + +class FilesystemMode(StrEnum): + """Supported filesystem backends for agent tool execution.""" + + CLOUD = "cloud" + DESKTOP_LOCAL_FOLDER = "desktop_local_folder" + + +class ClientPlatform(StrEnum): + """Client runtime reported by the caller.""" + + WEB = "web" + DESKTOP = "desktop" + + +@dataclass(slots=True) +class FilesystemSelection: + """Resolved filesystem selection for a single chat request.""" + + mode: FilesystemMode = FilesystemMode.CLOUD + client_platform: ClientPlatform = ClientPlatform.WEB + local_root_path: str | None = None + + @property + def is_local_mode(self) -> bool: + return self.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER diff --git a/surfsense_backend/app/agents/new_chat/middleware/__init__.py b/surfsense_backend/app/agents/new_chat/middleware/__init__.py index 1f6b12852..5a24b2f9e 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/__init__.py +++ b/surfsense_backend/app/agents/new_chat/middleware/__init__.py @@ -6,6 +6,9 @@ from app.agents.new_chat.middleware.dedup_tool_calls import ( from app.agents.new_chat.middleware.filesystem import ( SurfSenseFilesystemMiddleware, ) +from app.agents.new_chat.middleware.file_intent import ( + FileIntentMiddleware, +) from app.agents.new_chat.middleware.knowledge_search import ( KnowledgeBaseSearchMiddleware, ) @@ -15,6 +18,7 @@ from app.agents.new_chat.middleware.memory_injection import ( __all__ = [ "DedupHITLToolCallsMiddleware", + "FileIntentMiddleware", "KnowledgeBaseSearchMiddleware", "MemoryInjectionMiddleware", "SurfSenseFilesystemMiddleware", From 15a9e8b085f36ceb2065b32ed8f3f0238878617f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 15:03:32 +0530 Subject: [PATCH 112/299] feat(middleware): detect file intent in chat messages --- .../agents/new_chat/middleware/file_intent.py | 253 ++++++++++++++++++ .../middleware/test_file_intent_middleware.py | 116 ++++++++ 2 files changed, 369 insertions(+) create mode 100644 surfsense_backend/app/agents/new_chat/middleware/file_intent.py create mode 100644 surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py diff --git a/surfsense_backend/app/agents/new_chat/middleware/file_intent.py b/surfsense_backend/app/agents/new_chat/middleware/file_intent.py new file mode 100644 index 000000000..e264a939c --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/file_intent.py @@ -0,0 +1,253 @@ +"""Semantic file-intent routing middleware for new chat turns. + +This middleware classifies the latest human turn into a small intent set: +- chat_only +- file_write +- file_read + +For ``file_write`` turns it injects a strict system contract so the model +uses filesystem tools before claiming success, and provides a deterministic +fallback path when no filename is specified by the user. +""" + +from __future__ import annotations + +import json +import logging +import re +from datetime import UTC, datetime +from enum import StrEnum +from typing import Any + +from langchain.agents.middleware import AgentMiddleware, AgentState +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langgraph.runtime import Runtime +from pydantic import BaseModel, Field, ValidationError + +logger = logging.getLogger(__name__) + + +class FileOperationIntent(StrEnum): + CHAT_ONLY = "chat_only" + FILE_WRITE = "file_write" + FILE_READ = "file_read" + + +class FileIntentPlan(BaseModel): + intent: FileOperationIntent = Field( + description="Primary user intent for this turn." + ) + confidence: float = Field( + ge=0.0, + le=1.0, + default=0.5, + description="Model confidence in the selected intent.", + ) + suggested_filename: str | None = Field( + default=None, + description="Optional filename (e.g. notes.md) inferred from user request.", + ) + + +def _extract_text_from_message(message: BaseMessage) -> str: + content = getattr(message, "content", "") + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict) and item.get("type") == "text": + parts.append(str(item.get("text", ""))) + return "\n".join(part for part in parts if part) + return str(content) + + +def _extract_json_payload(text: str) -> str: + stripped = text.strip() + fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL) + if fenced: + return fenced.group(1) + start = stripped.find("{") + end = stripped.rfind("}") + if start != -1 and end != -1 and end > start: + return stripped[start : end + 1] + return stripped + + +def _sanitize_filename(value: str) -> str: + name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip() + name = re.sub(r"\s+", "-", name) + name = name.strip("._-") + if not name: + name = "note" + if len(name) > 80: + name = name[:80].rstrip("-_.") + return name + + +def _infer_text_file_extension(user_text: str) -> str: + lowered = user_text.lower() + if any(token in lowered for token in ("json", ".json")): + return ".json" + if any(token in lowered for token in ("yaml", "yml", ".yaml", ".yml")): + return ".yaml" + if any(token in lowered for token in ("csv", ".csv")): + return ".csv" + if any(token in lowered for token in ("python", ".py")): + return ".py" + if any(token in lowered for token in ("typescript", ".ts", ".tsx")): + return ".ts" + if any(token in lowered for token in ("javascript", ".js", ".mjs", ".cjs")): + return ".js" + if any(token in lowered for token in ("html", ".html")): + return ".html" + if any(token in lowered for token in ("css", ".css")): + return ".css" + if any(token in lowered for token in ("sql", ".sql")): + return ".sql" + if any(token in lowered for token in ("toml", ".toml")): + return ".toml" + if any(token in lowered for token in ("ini", ".ini")): + return ".ini" + if any(token in lowered for token in ("xml", ".xml")): + return ".xml" + if any(token in lowered for token in ("markdown", ".md", "readme")): + return ".md" + return ".md" + + +def _fallback_path(suggested_filename: str | None, *, user_text: str) -> str: + default_extension = _infer_text_file_extension(user_text) + if suggested_filename: + sanitized = _sanitize_filename(suggested_filename) + if sanitized.lower().endswith(".txt"): + sanitized = f"{sanitized[:-4]}.md" + if "." not in sanitized: + sanitized = f"{sanitized}{default_extension}" + return f"/{sanitized}" + return f"/notes{default_extension}" + + +def _build_classifier_prompt(*, recent_conversation: str, user_text: str) -> str: + return ( + "Classify the latest user request into a filesystem intent for an AI agent.\n" + "Return JSON only with this exact schema:\n" + '{"intent":"chat_only|file_write|file_read","confidence":0.0,"suggested_filename":"string or null"}\n\n' + "Rules:\n" + "- Use semantic intent, not literal keywords.\n" + "- file_write: user asks to create/save/write/update/edit content as a file.\n" + "- file_read: user asks to open/read/list/search existing files.\n" + "- chat_only: conversational/analysis responses without required file operations.\n" + "- For file_write, choose a concise semantic suggested_filename and match the requested format.\n" + "- Use extensions that match user intent (e.g. .md, .json, .yaml, .csv, .py, .ts, .js, .html, .css, .sql).\n" + "- Do not use .txt; prefer .md for generic text notes.\n" + "- Do not include dates or timestamps in suggested_filename unless explicitly requested.\n" + "- Never include markdown or explanation.\n\n" + f"Recent conversation:\n{recent_conversation or '(none)'}\n\n" + f"Latest user message:\n{user_text}" + ) + + +def _build_recent_conversation(messages: list[BaseMessage], *, max_messages: int = 6) -> str: + rows: list[str] = [] + for msg in messages[-max_messages:]: + role = "user" if isinstance(msg, HumanMessage) else "assistant" + text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip() + if text: + rows.append(f"{role}: {text[:280]}") + return "\n".join(rows) + + +class FileIntentMiddleware(AgentMiddleware): # type: ignore[type-arg] + """Classify file intent and inject a strict file-write contract.""" + + tools = () + + def __init__(self, *, llm: BaseChatModel | None = None) -> None: + self.llm = llm + + async def _classify_intent( + self, *, messages: list[BaseMessage], user_text: str + ) -> FileIntentPlan: + if self.llm is None: + return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0) + + prompt = _build_classifier_prompt( + recent_conversation=_build_recent_conversation(messages), + user_text=user_text, + ) + try: + response = await self.llm.ainvoke( + [HumanMessage(content=prompt)], + config={"tags": ["surfsense:internal"]}, + ) + payload = json.loads(_extract_json_payload(_extract_text_from_message(response))) + plan = FileIntentPlan.model_validate(payload) + return plan + except (json.JSONDecodeError, ValidationError, ValueError) as exc: + logger.warning("File intent classifier returned invalid output: %s", exc) + except Exception as exc: # pragma: no cover - defensive fallback + logger.warning("File intent classifier failed: %s", exc) + + return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0) + + async def abefore_agent( # type: ignore[override] + self, + state: AgentState, + runtime: Runtime[Any], + ) -> dict[str, Any] | None: + del runtime + messages = state.get("messages") or [] + if not messages: + return None + + last_human: HumanMessage | None = None + for msg in reversed(messages): + if isinstance(msg, HumanMessage): + last_human = msg + break + if last_human is None: + return None + + user_text = _extract_text_from_message(last_human).strip() + if not user_text: + return None + + plan = await self._classify_intent(messages=messages, user_text=user_text) + suggested_path = _fallback_path(plan.suggested_filename, user_text=user_text) + contract = { + "intent": plan.intent.value, + "confidence": plan.confidence, + "suggested_path": suggested_path, + "timestamp": datetime.now(UTC).isoformat(), + "turn_id": state.get("turn_id", ""), + } + + if plan.intent != FileOperationIntent.FILE_WRITE: + return {"file_operation_contract": contract} + + contract_msg = SystemMessage( + content=( + "\n" + "This turn intent is file_write.\n" + f"Suggested default path: {suggested_path}\n" + "Rules:\n" + "- You MUST call write_file or edit_file before claiming success.\n" + "- If no path is provided by the user, use the suggested default path.\n" + "- Do not claim a file was created/updated unless tool output confirms it.\n" + "- If the write/edit fails, clearly report failure instead of success.\n" + "- Do not include timestamps or dates in generated file content unless the user explicitly asks for them.\n" + "- For open-ended requests (e.g., random note), generate useful concrete content, not placeholders.\n" + "" + ) + ) + + # Insert just before the latest human turn so it applies to this request. + new_messages = list(messages) + insert_at = max(len(new_messages) - 1, 0) + new_messages.insert(insert_at, contract_msg) + return {"messages": new_messages, "file_operation_contract": contract} + diff --git a/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py b/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py new file mode 100644 index 000000000..68876dfeb --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py @@ -0,0 +1,116 @@ +import pytest +from langchain_core.messages import AIMessage, HumanMessage + +from app.agents.new_chat.middleware.file_intent import ( + FileIntentMiddleware, + FileOperationIntent, +) + +pytestmark = pytest.mark.unit + + +class _FakeLLM: + def __init__(self, response_text: str): + self._response_text = response_text + + async def ainvoke(self, *_args, **_kwargs): + return AIMessage(content=self._response_text) + + +@pytest.mark.asyncio +async def test_file_write_intent_injects_contract_message(): + llm = _FakeLLM( + '{"intent":"file_write","confidence":0.93,"suggested_filename":"ideas.md"}' + ) + middleware = FileIntentMiddleware(llm=llm) + state = { + "messages": [HumanMessage(content="Create another random note for me")], + "turn_id": "123:456", + } + + result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] + + assert result is not None + contract = result["file_operation_contract"] + assert contract["intent"] == FileOperationIntent.FILE_WRITE.value + assert contract["suggested_path"] == "/ideas.md" + assert contract["turn_id"] == "123:456" + assert any( + "file_operation_contract" in str(msg.content) + for msg in result["messages"] + if hasattr(msg, "content") + ) + + +@pytest.mark.asyncio +async def test_non_write_intent_does_not_inject_contract_message(): + llm = _FakeLLM( + '{"intent":"file_read","confidence":0.88,"suggested_filename":null}' + ) + middleware = FileIntentMiddleware(llm=llm) + original_messages = [HumanMessage(content="Read /notes.md")] + state = {"messages": original_messages, "turn_id": "abc:def"} + + result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] + + assert result is not None + assert result["file_operation_contract"]["intent"] == FileOperationIntent.FILE_READ.value + assert "messages" not in result + + +@pytest.mark.asyncio +async def test_file_write_null_filename_uses_semantic_default_path(): + llm = _FakeLLM( + '{"intent":"file_write","confidence":0.74,"suggested_filename":null}' + ) + middleware = FileIntentMiddleware(llm=llm) + state = { + "messages": [HumanMessage(content="create a random markdown file")], + "turn_id": "turn:1", + } + + result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] + + assert result is not None + contract = result["file_operation_contract"] + assert contract["intent"] == FileOperationIntent.FILE_WRITE.value + assert contract["suggested_path"] == "/notes.md" + + +@pytest.mark.asyncio +async def test_file_write_null_filename_infers_json_extension(): + llm = _FakeLLM( + '{"intent":"file_write","confidence":0.71,"suggested_filename":null}' + ) + middleware = FileIntentMiddleware(llm=llm) + state = { + "messages": [HumanMessage(content="create a sample json config file")], + "turn_id": "turn:2", + } + + result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] + + assert result is not None + contract = result["file_operation_contract"] + assert contract["intent"] == FileOperationIntent.FILE_WRITE.value + assert contract["suggested_path"] == "/notes.json" + + +@pytest.mark.asyncio +async def test_file_write_txt_suggestion_is_normalized_to_markdown(): + llm = _FakeLLM( + '{"intent":"file_write","confidence":0.82,"suggested_filename":"random.txt"}' + ) + middleware = FileIntentMiddleware(llm=llm) + state = { + "messages": [HumanMessage(content="create a random file")], + "turn_id": "turn:3", + } + + result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] + + assert result is not None + contract = result["file_operation_contract"] + assert contract["intent"] == FileOperationIntent.FILE_WRITE.value + assert contract["suggested_path"] == "/random.md" + From 739345671b06f5b13566e2a726973ecba0fe3f67 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Thu, 23 Apr 2026 11:40:21 +0200 Subject: [PATCH 113/299] fix: break circular import in llm_service and kb_sync_service files --- .../app/services/confluence/kb_sync_service.py | 5 ++++- .../app/services/dropbox/kb_sync_service.py | 3 ++- .../app/services/gmail/kb_sync_service.py | 3 ++- .../app/services/google_calendar/kb_sync_service.py | 5 ++++- .../app/services/google_drive/kb_sync_service.py | 3 ++- .../app/services/jira/kb_sync_service.py | 5 ++++- .../app/services/linear/kb_sync_service.py | 5 ++++- surfsense_backend/app/services/llm_service.py | 11 ++++++++++- .../app/services/notion/kb_sync_service.py | 5 ++++- .../app/services/onedrive/kb_sync_service.py | 3 ++- 10 files changed, 38 insertions(+), 10 deletions(-) diff --git a/surfsense_backend/app/services/confluence/kb_sync_service.py b/surfsense_backend/app/services/confluence/kb_sync_service.py index f786a9920..cae2bef88 100644 --- a/surfsense_backend/app/services/confluence/kb_sync_service.py +++ b/surfsense_backend/app/services/confluence/kb_sync_service.py @@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.confluence_history import ConfluenceHistoryConnector from app.db import Document, DocumentType -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -66,6 +65,8 @@ class ConfluenceKBSyncService: if dup: content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, @@ -184,6 +185,8 @@ class ConfluenceKBSyncService: space_id = (document.document_metadata or {}).get("space_id", "") + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, search_space_id, disable_streaming=True ) diff --git a/surfsense_backend/app/services/dropbox/kb_sync_service.py b/surfsense_backend/app/services/dropbox/kb_sync_service.py index 2a74bdf4b..9d1951013 100644 --- a/surfsense_backend/app/services/dropbox/kb_sync_service.py +++ b/surfsense_backend/app/services/dropbox/kb_sync_service.py @@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.db import Document, DocumentType from app.indexing_pipeline.document_hashing import compute_identifier_hash -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -73,6 +72,8 @@ class DropboxKBSyncService: ) content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, diff --git a/surfsense_backend/app/services/gmail/kb_sync_service.py b/surfsense_backend/app/services/gmail/kb_sync_service.py index b3b50d305..885ee4b94 100644 --- a/surfsense_backend/app/services/gmail/kb_sync_service.py +++ b/surfsense_backend/app/services/gmail/kb_sync_service.py @@ -4,7 +4,6 @@ from datetime import datetime from sqlalchemy.ext.asyncio import AsyncSession from app.db import Document, DocumentType -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -78,6 +77,8 @@ class GmailKBSyncService: ) content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, diff --git a/surfsense_backend/app/services/google_calendar/kb_sync_service.py b/surfsense_backend/app/services/google_calendar/kb_sync_service.py index 3cda02b9b..20426f3bc 100644 --- a/surfsense_backend/app/services/google_calendar/kb_sync_service.py +++ b/surfsense_backend/app/services/google_calendar/kb_sync_service.py @@ -14,7 +14,6 @@ from app.db import ( SearchSourceConnector, SearchSourceConnectorType, ) -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -91,6 +90,8 @@ class GoogleCalendarKBSyncService: ) content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, @@ -249,6 +250,8 @@ class GoogleCalendarKBSyncService: if not indexable_content: return {"status": "error", "message": "Event produced empty content"} + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, search_space_id, disable_streaming=True ) diff --git a/surfsense_backend/app/services/google_drive/kb_sync_service.py b/surfsense_backend/app/services/google_drive/kb_sync_service.py index 92a39f7b9..0a8eb47a6 100644 --- a/surfsense_backend/app/services/google_drive/kb_sync_service.py +++ b/surfsense_backend/app/services/google_drive/kb_sync_service.py @@ -4,7 +4,6 @@ from datetime import datetime from sqlalchemy.ext.asyncio import AsyncSession from app.db import Document, DocumentType -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -75,6 +74,8 @@ class GoogleDriveKBSyncService: ) content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, diff --git a/surfsense_backend/app/services/jira/kb_sync_service.py b/surfsense_backend/app/services/jira/kb_sync_service.py index 4d2a66e52..8e88bee81 100644 --- a/surfsense_backend/app/services/jira/kb_sync_service.py +++ b/surfsense_backend/app/services/jira/kb_sync_service.py @@ -6,7 +6,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.jira_history import JiraHistoryConnector from app.db import Document, DocumentType -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -75,6 +74,8 @@ class JiraKBSyncService: if dup: content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, @@ -190,6 +191,8 @@ class JiraKBSyncService: state = formatted.get("status", "Unknown") comment_count = len(formatted.get("comments", [])) + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, search_space_id, disable_streaming=True ) diff --git a/surfsense_backend/app/services/linear/kb_sync_service.py b/surfsense_backend/app/services/linear/kb_sync_service.py index dab42af55..471227602 100644 --- a/surfsense_backend/app/services/linear/kb_sync_service.py +++ b/surfsense_backend/app/services/linear/kb_sync_service.py @@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.linear_connector import LinearConnector from app.db import Document, DocumentType -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -85,6 +84,8 @@ class LinearKBSyncService: ) content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, @@ -226,6 +227,8 @@ class LinearKBSyncService: comment_count = len(formatted_issue.get("comments", [])) formatted_issue.get("description", "") + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, search_space_id, disable_streaming=True ) diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 79a72dd25..942a9b7af 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -7,7 +7,6 @@ from langchain_litellm import ChatLiteLLM from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.agents.new_chat.llm_config import SanitizedChatLiteLLM from app.config import config from app.db import NewLLMConfig, SearchSpace from app.services.llm_router_service import ( @@ -204,6 +203,8 @@ async def validate_llm_config( if litellm_params: litellm_kwargs.update(litellm_params) + from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + llm = SanitizedChatLiteLLM(**litellm_kwargs) # Run the test call in a worker thread with a hard timeout. Some @@ -377,6 +378,8 @@ async def get_search_space_llm_instance( if disable_streaming: litellm_kwargs["disable_streaming"] = True + from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + return SanitizedChatLiteLLM(**litellm_kwargs) # Get the LLM configuration from database (NewLLMConfig) @@ -454,6 +457,8 @@ async def get_search_space_llm_instance( if disable_streaming: litellm_kwargs["disable_streaming"] = True + from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + return SanitizedChatLiteLLM(**litellm_kwargs) except Exception as e: @@ -555,6 +560,8 @@ async def get_vision_llm( if global_cfg.get("litellm_params"): litellm_kwargs.update(global_cfg["litellm_params"]) + from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + return SanitizedChatLiteLLM(**litellm_kwargs) result = await session.execute( @@ -588,6 +595,8 @@ async def get_vision_llm( if vision_cfg.litellm_params: litellm_kwargs.update(vision_cfg.litellm_params) + from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + return SanitizedChatLiteLLM(**litellm_kwargs) except Exception as e: diff --git a/surfsense_backend/app/services/notion/kb_sync_service.py b/surfsense_backend/app/services/notion/kb_sync_service.py index be177c7ca..b10d1b157 100644 --- a/surfsense_backend/app/services/notion/kb_sync_service.py +++ b/surfsense_backend/app/services/notion/kb_sync_service.py @@ -4,7 +4,6 @@ from datetime import datetime from sqlalchemy.ext.asyncio import AsyncSession from app.db import Document, DocumentType -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -74,6 +73,8 @@ class NotionKBSyncService: ) content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, @@ -244,6 +245,8 @@ class NotionKBSyncService: f"Final content length: {len(full_content)} chars, verified={content_verified}" ) + from app.services.llm_service import get_user_long_context_llm + logger.debug("Generating summary and embeddings") user_llm = await get_user_long_context_llm( self.db_session, diff --git a/surfsense_backend/app/services/onedrive/kb_sync_service.py b/surfsense_backend/app/services/onedrive/kb_sync_service.py index 962c19fc9..e9b2e38ea 100644 --- a/surfsense_backend/app/services/onedrive/kb_sync_service.py +++ b/surfsense_backend/app/services/onedrive/kb_sync_service.py @@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.db import Document, DocumentType from app.indexing_pipeline.document_hashing import compute_identifier_hash -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -73,6 +72,8 @@ class OneDriveKBSyncService: ) content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, From 42d2d2222ecbe674ebfcb75a64cc97ef96fc5e9b Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 15:44:12 +0530 Subject: [PATCH 114/299] feat(filesystem): add local folder backend and verification coverage --- .../middleware/local_folder_backend.py | 316 ++++++++++++++++++ .../middleware/test_filesystem_backends.py | 37 ++ .../test_filesystem_verification.py | 64 ++++ .../middleware/test_local_folder_backend.py | 59 ++++ 4 files changed, 476 insertions(+) create mode 100644 surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py create mode 100644 surfsense_backend/tests/unit/middleware/test_filesystem_backends.py create mode 100644 surfsense_backend/tests/unit/middleware/test_filesystem_verification.py create mode 100644 surfsense_backend/tests/unit/middleware/test_local_folder_backend.py diff --git a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py new file mode 100644 index 000000000..60d967053 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py @@ -0,0 +1,316 @@ +"""Desktop local-folder filesystem backend for deepagents tools.""" + +from __future__ import annotations + +import asyncio +import fnmatch +import os +import threading +from pathlib import Path + +from deepagents.backends.protocol import ( + EditResult, + FileDownloadResponse, + FileInfo, + FileUploadResponse, + GrepMatch, + WriteResult, +) +from deepagents.backends.utils import ( + create_file_data, + format_read_response, + perform_string_replacement, +) + +_INVALID_PATH = "invalid_path" +_FILE_NOT_FOUND = "file_not_found" +_IS_DIRECTORY = "is_directory" + + +class LocalFolderBackend: + """Filesystem backend rooted to a single local folder.""" + + def __init__(self, root_path: str) -> None: + root = Path(root_path).expanduser().resolve() + if not root.exists() or not root.is_dir(): + msg = f"Local filesystem root does not exist or is not a directory: {root_path}" + raise ValueError(msg) + self._root = root + self._locks: dict[str, threading.Lock] = {} + self._locks_mu = threading.Lock() + + def _lock_for(self, path: str) -> threading.Lock: + with self._locks_mu: + if path not in self._locks: + self._locks[path] = threading.Lock() + return self._locks[path] + + def _resolve_virtual(self, virtual_path: str, *, allow_root: bool = False) -> Path: + if not virtual_path.startswith("/"): + msg = f"Invalid path (must be absolute): {virtual_path}" + raise ValueError(msg) + rel = virtual_path.lstrip("/") + candidate = self._root if rel == "" else (self._root / rel) + resolved = candidate.resolve() + if not allow_root and resolved == self._root: + msg = "Path must refer to a file or child directory under root" + raise ValueError(msg) + if not resolved.is_relative_to(self._root): + msg = f"Path escapes local filesystem root: {virtual_path}" + raise ValueError(msg) + return resolved + + @staticmethod + def _to_virtual(path: Path, root: Path) -> str: + rel = path.relative_to(root).as_posix() + return "/" if rel == "." else f"/{rel}" + + def _write_text_atomic(self, path: Path, content: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + temp_path = path.with_suffix(f"{path.suffix}.tmp") + temp_path.write_text(content, encoding="utf-8") + os.replace(temp_path, path) + + def ls_info(self, path: str) -> list[FileInfo]: + try: + target = self._resolve_virtual(path, allow_root=True) + except ValueError: + return [] + if not target.exists() or not target.is_dir(): + return [] + infos: list[FileInfo] = [] + for child in sorted(target.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())): + infos.append( + FileInfo( + path=self._to_virtual(child, self._root), + is_dir=child.is_dir(), + size=child.stat().st_size if child.is_file() else 0, + modified_at=str(child.stat().st_mtime), + ) + ) + return infos + + async def als_info(self, path: str) -> list[FileInfo]: + return await asyncio.to_thread(self.ls_info, path) + + def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str: + try: + path = self._resolve_virtual(file_path) + except ValueError: + return f"Error: Invalid path '{file_path}'" + if not path.exists(): + return f"Error: File '{file_path}' not found" + if not path.is_file(): + return f"Error: Path '{file_path}' is not a file" + content = path.read_text(encoding="utf-8", errors="replace") + file_data = create_file_data(content) + return format_read_response(file_data, offset, limit) + + async def aread(self, file_path: str, offset: int = 0, limit: int = 2000) -> str: + return await asyncio.to_thread(self.read, file_path, offset, limit) + + def read_raw(self, file_path: str) -> str: + """Read raw file text without line-number formatting.""" + try: + path = self._resolve_virtual(file_path) + except ValueError: + return f"Error: Invalid path '{file_path}'" + if not path.exists(): + return f"Error: File '{file_path}' not found" + if not path.is_file(): + return f"Error: Path '{file_path}' is not a file" + return path.read_text(encoding="utf-8", errors="replace") + + async def aread_raw(self, file_path: str) -> str: + """Async variant of read_raw.""" + return await asyncio.to_thread(self.read_raw, file_path) + + def write(self, file_path: str, content: str) -> WriteResult: + try: + path = self._resolve_virtual(file_path) + except ValueError: + return WriteResult(error=f"Error: Invalid path '{file_path}'") + lock = self._lock_for(file_path) + with lock: + if path.exists(): + return WriteResult( + error=( + f"Cannot write to {file_path} because it already exists. " + "Read and then make an edit, or write to a new path." + ) + ) + self._write_text_atomic(path, content) + return WriteResult(path=file_path, files_update=None) + + async def awrite(self, file_path: str, content: str) -> WriteResult: + return await asyncio.to_thread(self.write, file_path, content) + + def edit( + self, + file_path: str, + old_string: str, + new_string: str, + replace_all: bool = False, + ) -> EditResult: + try: + path = self._resolve_virtual(file_path) + except ValueError: + return EditResult(error=f"Error: Invalid path '{file_path}'") + lock = self._lock_for(file_path) + with lock: + if not path.exists() or not path.is_file(): + return EditResult(error=f"Error: File '{file_path}' not found") + content = path.read_text(encoding="utf-8", errors="replace") + result = perform_string_replacement(content, old_string, new_string, replace_all) + if isinstance(result, str): + return EditResult(error=result) + updated_content, occurrences = result + self._write_text_atomic(path, updated_content) + return EditResult(path=file_path, files_update=None, occurrences=int(occurrences)) + + async def aedit( + self, + file_path: str, + old_string: str, + new_string: str, + replace_all: bool = False, + ) -> EditResult: + return await asyncio.to_thread( + self.edit, file_path, old_string, new_string, replace_all + ) + + def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]: + try: + base = self._resolve_virtual(path, allow_root=True) + except ValueError: + return [] + + if pattern.startswith("/"): + search_base = self._root + normalized_pattern = pattern.lstrip("/") + else: + search_base = base + normalized_pattern = pattern + + matches: list[FileInfo] = [] + for hit in search_base.glob(normalized_pattern): + try: + resolved = hit.resolve() + if not resolved.is_relative_to(self._root): + continue + except Exception: + continue + matches.append( + FileInfo( + path=self._to_virtual(resolved, self._root), + is_dir=resolved.is_dir(), + size=resolved.stat().st_size if resolved.is_file() else 0, + modified_at=str(resolved.stat().st_mtime), + ) + ) + return matches + + async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]: + return await asyncio.to_thread(self.glob_info, pattern, path) + + def _iter_candidate_files(self, path: str | None, glob: str | None) -> list[Path]: + base_virtual = path or "/" + try: + base = self._resolve_virtual(base_virtual, allow_root=True) + except ValueError: + return [] + if not base.exists(): + return [] + + candidates = [p for p in base.rglob("*") if p.is_file()] + if glob: + candidates = [ + p + for p in candidates + if fnmatch.fnmatch(self._to_virtual(p, self._root), glob) + or fnmatch.fnmatch(p.name, glob) + ] + return candidates + + def grep_raw( + self, pattern: str, path: str | None = None, glob: str | None = None + ) -> list[GrepMatch] | str: + if not pattern: + return "Error: pattern cannot be empty" + matches: list[GrepMatch] = [] + for file_path in self._iter_candidate_files(path, glob): + try: + lines = file_path.read_text(encoding="utf-8", errors="replace").splitlines() + except Exception: + continue + for idx, line in enumerate(lines, start=1): + if pattern in line: + matches.append( + GrepMatch( + path=self._to_virtual(file_path, self._root), + line=idx, + text=line, + ) + ) + return matches + + async def agrep_raw( + self, pattern: str, path: str | None = None, glob: str | None = None + ) -> list[GrepMatch] | str: + return await asyncio.to_thread(self.grep_raw, pattern, path, glob) + + def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]: + responses: list[FileUploadResponse] = [] + for virtual_path, content in files: + try: + target = self._resolve_virtual(virtual_path) + target.parent.mkdir(parents=True, exist_ok=True) + temp_path = target.with_suffix(f"{target.suffix}.tmp") + temp_path.write_bytes(content) + os.replace(temp_path, target) + responses.append(FileUploadResponse(path=virtual_path, error=None)) + except FileNotFoundError: + responses.append( + FileUploadResponse(path=virtual_path, error=_FILE_NOT_FOUND) + ) + except IsADirectoryError: + responses.append(FileUploadResponse(path=virtual_path, error=_IS_DIRECTORY)) + except Exception: + responses.append(FileUploadResponse(path=virtual_path, error=_INVALID_PATH)) + return responses + + async def aupload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]: + return await asyncio.to_thread(self.upload_files, files) + + def download_files(self, paths: list[str]) -> list[FileDownloadResponse]: + responses: list[FileDownloadResponse] = [] + for virtual_path in paths: + try: + target = self._resolve_virtual(virtual_path) + if not target.exists(): + responses.append( + FileDownloadResponse( + path=virtual_path, content=None, error=_FILE_NOT_FOUND + ) + ) + continue + if target.is_dir(): + responses.append( + FileDownloadResponse( + path=virtual_path, content=None, error=_IS_DIRECTORY + ) + ) + continue + responses.append( + FileDownloadResponse( + path=virtual_path, content=target.read_bytes(), error=None + ) + ) + except Exception: + responses.append( + FileDownloadResponse(path=virtual_path, content=None, error=_INVALID_PATH) + ) + return responses + + async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]: + return await asyncio.to_thread(self.download_files, paths) diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py b/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py new file mode 100644 index 000000000..2377307f8 --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py @@ -0,0 +1,37 @@ +from pathlib import Path + +import pytest + +from app.agents.new_chat.filesystem_backends import build_backend_resolver +from app.agents.new_chat.filesystem_selection import ( + ClientPlatform, + FilesystemMode, + FilesystemSelection, +) +from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend + +pytestmark = pytest.mark.unit + + +class _RuntimeStub: + state = {"files": {}} + + +def test_backend_resolver_returns_local_backend_for_local_mode(tmp_path: Path): + selection = FilesystemSelection( + mode=FilesystemMode.DESKTOP_LOCAL_FOLDER, + client_platform=ClientPlatform.DESKTOP, + local_root_path=str(tmp_path), + ) + resolver = build_backend_resolver(selection) + + backend = resolver(_RuntimeStub()) + assert isinstance(backend, LocalFolderBackend) + + +def test_backend_resolver_uses_cloud_mode_by_default(): + resolver = build_backend_resolver(FilesystemSelection()) + backend = resolver(_RuntimeStub()) + # StateBackend class name check keeps this test decoupled + # from internal deepagents runtime class identity. + assert backend.__class__.__name__ == "StateBackend" diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py b/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py new file mode 100644 index 000000000..9f6b162aa --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py @@ -0,0 +1,64 @@ +import pytest + +from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware + +pytestmark = pytest.mark.unit + + +class _BackendWithRawRead: + def __init__(self, content: str) -> None: + self._content = content + + def read(self, file_path: str, offset: int = 0, limit: int = 200000) -> str: + del file_path, offset, limit + return " 1\tline1\n 2\tline2" + + async def aread(self, file_path: str, offset: int = 0, limit: int = 200000) -> str: + return self.read(file_path, offset, limit) + + def read_raw(self, file_path: str) -> str: + del file_path + return self._content + + async def aread_raw(self, file_path: str) -> str: + return self.read_raw(file_path) + + +class _RuntimeNoSuggestedPath: + state = {"file_operation_contract": {}} + + +def test_verify_written_content_prefers_raw_sync() -> None: + middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) + expected = "line1\nline2" + backend = _BackendWithRawRead(expected) + + verify_error = middleware._verify_written_content_sync( + backend=backend, + path="/note.md", + expected_content=expected, + ) + + assert verify_error is None + + +def test_contract_suggested_path_falls_back_to_notes_md() -> None: + suggested = SurfSenseFilesystemMiddleware._get_contract_suggested_path( + _RuntimeNoSuggestedPath() + ) + assert suggested == "/notes.md" + + +@pytest.mark.asyncio +async def test_verify_written_content_prefers_raw_async() -> None: + middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) + expected = "line1\nline2" + backend = _BackendWithRawRead(expected) + + verify_error = await middleware._verify_written_content_async( + backend=backend, + path="/note.md", + expected_content=expected, + ) + + assert verify_error is None diff --git a/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py b/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py new file mode 100644 index 000000000..3484a2cc4 --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py @@ -0,0 +1,59 @@ +from pathlib import Path + +import pytest + +from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend + +pytestmark = pytest.mark.unit + + +def test_local_backend_write_read_edit_roundtrip(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + + write = backend.write("/notes/test.md", "line1\nline2") + assert write.error is None + assert write.path == "/notes/test.md" + + read = backend.read("/notes/test.md", offset=0, limit=20) + assert "line1" in read + assert "line2" in read + + edit = backend.edit("/notes/test.md", "line2", "updated") + assert edit.error is None + assert edit.occurrences == 1 + + read_after = backend.read("/notes/test.md", offset=0, limit=20) + assert "updated" in read_after + + +def test_local_backend_blocks_path_escape(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + + result = backend.write("/../../etc/passwd", "bad") + assert result.error is not None + assert "Invalid path" in result.error + + +def test_local_backend_glob_and_grep(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "docs").mkdir() + (tmp_path / "docs" / "a.txt").write_text("hello world\n") + (tmp_path / "docs" / "b.md").write_text("hello markdown\n") + + infos = backend.glob_info("**/*.txt", "/docs") + paths = {info["path"] for info in infos} + assert "/docs/a.txt" in paths + + grep = backend.grep_raw("hello", "/docs", "*.md") + assert isinstance(grep, list) + assert any(match["path"] == "/docs/b.md" for match in grep) + + +def test_local_backend_read_raw_returns_exact_content(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + expected = "# Title\n\nline 1\nline 2\n" + write = backend.write("/notes/raw.md", expected) + assert write.error is None + + raw = backend.read_raw("/notes/raw.md") + assert raw == expected From 1eadecee235924c707221beed860d43994583830 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 15:45:33 +0530 Subject: [PATCH 115/299] feat(new-chat): integrate filesystem flow into agent pipeline --- .../app/agents/new_chat/chat_deepagent.py | 13 + .../app/agents/new_chat/context.py | 13 +- .../agents/new_chat/middleware/filesystem.py | 223 ++++++++++++++++-- .../new_chat/middleware/knowledge_search.py | 6 + surfsense_backend/app/app.py | 18 ++ surfsense_backend/app/config/__init__.py | 3 + .../app/routes/new_chat_routes.py | 80 +++++++ surfsense_backend/app/schemas/new_chat.py | 9 + .../app/tasks/chat/stream_new_chat.py | 186 ++++++++++++++- .../unit/test_stream_new_chat_contract.py | 48 ++++ 10 files changed, 574 insertions(+), 25 deletions(-) create mode 100644 surfsense_backend/tests/unit/test_stream_new_chat_contract.py diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index a901a7519..ff8215eff 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -33,9 +33,12 @@ from langgraph.types import Checkpointer from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.context import SurfSenseContextSchema +from app.agents.new_chat.filesystem_backends import build_backend_resolver +from app.agents.new_chat.filesystem_selection import FilesystemSelection from app.agents.new_chat.llm_config import AgentConfig from app.agents.new_chat.middleware import ( DedupHITLToolCallsMiddleware, + FileIntentMiddleware, KnowledgeBaseSearchMiddleware, MemoryInjectionMiddleware, SurfSenseFilesystemMiddleware, @@ -164,6 +167,7 @@ async def create_surfsense_deep_agent( thread_visibility: ChatVisibility | None = None, mentioned_document_ids: list[int] | None = None, anon_session_id: str | None = None, + filesystem_selection: FilesystemSelection | None = None, ): """ Create a SurfSense deep agent with configurable tools and prompts. @@ -238,6 +242,8 @@ async def create_surfsense_deep_agent( ) """ _t_agent_total = time.perf_counter() + filesystem_selection = filesystem_selection or FilesystemSelection() + backend_resolver = build_backend_resolver(filesystem_selection) # Discover available connectors and document types for this search space available_connectors: list[str] | None = None @@ -439,7 +445,10 @@ async def create_surfsense_deep_agent( gp_middleware = [ TodoListMiddleware(), _memory_middleware, + FileIntentMiddleware(llm=llm), SurfSenseFilesystemMiddleware( + backend=backend_resolver, + filesystem_mode=filesystem_selection.mode, search_space_id=search_space_id, created_by_id=user_id, thread_id=thread_id, @@ -460,15 +469,19 @@ async def create_surfsense_deep_agent( deepagent_middleware = [ TodoListMiddleware(), _memory_middleware, + FileIntentMiddleware(llm=llm), KnowledgeBaseSearchMiddleware( llm=llm, search_space_id=search_space_id, + filesystem_mode=filesystem_selection.mode, available_connectors=available_connectors, available_document_types=available_document_types, mentioned_document_ids=mentioned_document_ids, anon_session_id=anon_session_id, ), SurfSenseFilesystemMiddleware( + backend=backend_resolver, + filesystem_mode=filesystem_selection.mode, search_space_id=search_space_id, created_by_id=user_id, thread_id=thread_id, diff --git a/surfsense_backend/app/agents/new_chat/context.py b/surfsense_backend/app/agents/new_chat/context.py index da113adf4..c1fe45aaa 100644 --- a/surfsense_backend/app/agents/new_chat/context.py +++ b/surfsense_backend/app/agents/new_chat/context.py @@ -4,7 +4,15 @@ Context schema definitions for SurfSense agents. This module defines the custom state schema used by the SurfSense deep agent. """ -from typing import TypedDict +from typing import NotRequired, TypedDict + + +class FileOperationContractState(TypedDict): + intent: str + confidence: float + suggested_path: str + timestamp: str + turn_id: str class SurfSenseContextSchema(TypedDict): @@ -24,5 +32,8 @@ class SurfSenseContextSchema(TypedDict): """ search_space_id: int + file_operation_contract: NotRequired[FileOperationContractState] + turn_id: NotRequired[str] + request_id: NotRequired[str] # These are runtime-injected and won't be serialized # db_session and connector_service are passed when invoking the agent diff --git a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py index bcd544d61..0fa2085fc 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py +++ b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py @@ -32,6 +32,7 @@ from app.agents.new_chat.sandbox import ( get_or_create_sandbox, is_sandbox_enabled, ) +from app.agents.new_chat.filesystem_selection import FilesystemMode from app.db import Chunk, Document, DocumentType, Folder, shielded_async_session from app.indexing_pipeline.document_chunker import chunk_text from app.utils.document_converters import ( @@ -50,6 +51,8 @@ SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = """## Following Conventions - Read files before editing — understand existing content before making changes. - Mimic existing style, naming conventions, and patterns. +- Never claim a file was created/updated unless filesystem tool output confirms success. +- If a file write/edit fails, explicitly report the failure. ## Filesystem Tools @@ -109,13 +112,20 @@ Usage: - Use chunk IDs (``) as citations in answers. """ -SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new file to the in-memory filesystem (session-only). +SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new text file to the in-memory filesystem (session-only). Use this to create scratch/working files during the conversation. Files created here are ephemeral and will not be saved to the user's knowledge base. To permanently save a document to the user's knowledge base, use the `save_document` tool instead. + +Supported outputs include common LLM-friendly text formats like markdown, json, +yaml, csv, xml, html, css, sql, and code files. + +When creating content from open-ended prompts, produce concrete and useful text, +not placeholders. Avoid adding dates/timestamps unless the user explicitly asks +for them. """ SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files. @@ -182,11 +192,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): def __init__( self, *, + backend: Any = None, + filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, search_space_id: int | None = None, created_by_id: str | None = None, thread_id: int | str | None = None, tool_token_limit_before_evict: int | None = 20000, ) -> None: + self._filesystem_mode = filesystem_mode self._search_space_id = search_space_id self._created_by_id = created_by_id self._thread_id = thread_id @@ -204,8 +217,15 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): " extract the data, write it as a clean file (CSV, JSON, etc.)," " and then run your code against it." ) + if filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: + system_prompt += ( + "\n\n## Local Folder Mode" + "\n\nThis chat is running in desktop local-folder mode." + " Keep all file operations local. Do not use save_document." + ) super().__init__( + backend=backend, system_prompt=system_prompt, custom_tool_descriptions={ "ls": SURFSENSE_LIST_FILES_TOOL_DESCRIPTION, @@ -219,7 +239,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): max_execute_timeout=self._MAX_EXECUTE_TIMEOUT, ) self.tools = [t for t in self.tools if t.name != "execute"] - self.tools.append(self._create_save_document_tool()) + if self._should_persist_documents(): + self.tools.append(self._create_save_document_tool()) if self._sandbox_available: self.tools.append(self._create_execute_code_tool()) @@ -637,15 +658,25 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): runtime: ToolRuntime[None, FilesystemState], ) -> Command | str: resolved_backend = self._get_backend(runtime) + target_path = self._resolve_write_target_path(file_path, runtime) try: - validated_path = validate_path(file_path) + validated_path = validate_path(target_path) except ValueError as exc: return f"Error: {exc}" res: WriteResult = resolved_backend.write(validated_path, content) if res.error: return res.error + verify_error = self._verify_written_content_sync( + backend=resolved_backend, + path=validated_path, + expected_content=content, + ) + if verify_error: + return verify_error - if not self._is_kb_document(validated_path): + if self._should_persist_documents() and not self._is_kb_document( + validated_path + ): persist_result = self._run_async_blocking( self._persist_new_document( file_path=validated_path, content=content @@ -682,15 +713,25 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): runtime: ToolRuntime[None, FilesystemState], ) -> Command | str: resolved_backend = self._get_backend(runtime) + target_path = self._resolve_write_target_path(file_path, runtime) try: - validated_path = validate_path(file_path) + validated_path = validate_path(target_path) except ValueError as exc: return f"Error: {exc}" res: WriteResult = await resolved_backend.awrite(validated_path, content) if res.error: return res.error + verify_error = await self._verify_written_content_async( + backend=resolved_backend, + path=validated_path, + expected_content=content, + ) + if verify_error: + return verify_error - if not self._is_kb_document(validated_path): + if self._should_persist_documents() and not self._is_kb_document( + validated_path + ): persist_result = await self._persist_new_document( file_path=validated_path, content=content, @@ -726,6 +767,124 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): """Return True for paths under /documents/ (KB-sourced, XML-wrapped).""" return path.startswith("/documents/") + def _should_persist_documents(self) -> bool: + """Only cloud mode persists file content to Document/Chunk tables.""" + return self._filesystem_mode == FilesystemMode.CLOUD + + @staticmethod + def _get_contract_suggested_path(runtime: ToolRuntime[None, FilesystemState]) -> str: + contract = runtime.state.get("file_operation_contract") or {} + suggested = contract.get("suggested_path") + if isinstance(suggested, str) and suggested.strip(): + return suggested.strip() + return "/notes.md" + + def _resolve_write_target_path( + self, + file_path: str, + runtime: ToolRuntime[None, FilesystemState], + ) -> str: + candidate = file_path.strip() + if not candidate: + return self._get_contract_suggested_path(runtime) + if not candidate.startswith("/"): + return f"/{candidate.lstrip('/')}" + return candidate + + @staticmethod + def _is_error_text(value: str) -> bool: + return value.startswith("Error:") + + @staticmethod + def _read_for_verification_sync(backend: Any, path: str) -> str: + read_raw = getattr(backend, "read_raw", None) + if callable(read_raw): + return read_raw(path) + return backend.read(path, offset=0, limit=200000) + + @staticmethod + async def _read_for_verification_async(backend: Any, path: str) -> str: + aread_raw = getattr(backend, "aread_raw", None) + if callable(aread_raw): + return await aread_raw(path) + return await backend.aread(path, offset=0, limit=200000) + + def _verify_written_content_sync( + self, + *, + backend: Any, + path: str, + expected_content: str, + ) -> str | None: + actual = self._read_for_verification_sync(backend, path) + if self._is_error_text(actual): + return f"Error: could not verify written file '{path}'." + if actual.rstrip() != expected_content.rstrip(): + return ( + "Error: file write verification failed; expected content was not fully written " + f"to '{path}'." + ) + return None + + async def _verify_written_content_async( + self, + *, + backend: Any, + path: str, + expected_content: str, + ) -> str | None: + actual = await self._read_for_verification_async(backend, path) + if self._is_error_text(actual): + return f"Error: could not verify written file '{path}'." + if actual.rstrip() != expected_content.rstrip(): + return ( + "Error: file write verification failed; expected content was not fully written " + f"to '{path}'." + ) + return None + + def _verify_edited_content_sync( + self, + *, + backend: Any, + path: str, + new_string: str, + ) -> tuple[str | None, str | None]: + updated_content = self._read_for_verification_sync(backend, path) + if self._is_error_text(updated_content): + return ( + f"Error: could not verify edited file '{path}'.", + None, + ) + if new_string and new_string not in updated_content: + return ( + "Error: edit verification failed; updated content was not found in " + f"'{path}'.", + None, + ) + return None, updated_content + + async def _verify_edited_content_async( + self, + *, + backend: Any, + path: str, + new_string: str, + ) -> tuple[str | None, str | None]: + updated_content = await self._read_for_verification_async(backend, path) + if self._is_error_text(updated_content): + return ( + f"Error: could not verify edited file '{path}'.", + None, + ) + if new_string and new_string not in updated_content: + return ( + "Error: edit verification failed; updated content was not found in " + f"'{path}'.", + None, + ) + return None, updated_content + def _create_edit_file_tool(self) -> BaseTool: """Create edit_file with DB persistence (skipped for KB documents).""" tool_description = ( @@ -754,8 +913,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): ] = False, ) -> Command | str: resolved_backend = self._get_backend(runtime) + target_path = self._resolve_write_target_path(file_path, runtime) try: - validated_path = validate_path(file_path) + validated_path = validate_path(target_path) except ValueError as exc: return f"Error: {exc}" res: EditResult = resolved_backend.edit( @@ -767,13 +927,22 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): if res.error: return res.error - if not self._is_kb_document(validated_path): - read_result = resolved_backend.read( - validated_path, offset=0, limit=200000 - ) - if read_result.error or read_result.file_data is None: - return f"Error: could not reload edited file '{validated_path}' for persistence." - updated_content = read_result.file_data["content"] + verify_error, updated_content = self._verify_edited_content_sync( + backend=resolved_backend, + path=validated_path, + new_string=new_string, + ) + if verify_error: + return verify_error + + if self._should_persist_documents() and not self._is_kb_document( + validated_path + ): + if updated_content is None: + return ( + f"Error: could not reload edited file '{validated_path}' for " + "persistence." + ) persist_result = self._run_async_blocking( self._persist_edited_document( file_path=validated_path, @@ -818,8 +987,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): ] = False, ) -> Command | str: resolved_backend = self._get_backend(runtime) + target_path = self._resolve_write_target_path(file_path, runtime) try: - validated_path = validate_path(file_path) + validated_path = validate_path(target_path) except ValueError as exc: return f"Error: {exc}" res: EditResult = await resolved_backend.aedit( @@ -831,13 +1001,22 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): if res.error: return res.error - if not self._is_kb_document(validated_path): - read_result = await resolved_backend.aread( - validated_path, offset=0, limit=200000 - ) - if read_result.error or read_result.file_data is None: - return f"Error: could not reload edited file '{validated_path}' for persistence." - updated_content = read_result.file_data["content"] + verify_error, updated_content = await self._verify_edited_content_async( + backend=resolved_backend, + path=validated_path, + new_string=new_string, + ) + if verify_error: + return verify_error + + if self._should_persist_documents() and not self._is_kb_document( + validated_path + ): + if updated_content is None: + return ( + f"Error: could not reload edited file '{validated_path}' for " + "persistence." + ) persist_error = await self._persist_edited_document( file_path=validated_path, updated_content=updated_content, diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py index c7bbe62e0..51378a013 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py @@ -28,6 +28,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range +from app.agents.new_chat.filesystem_selection import FilesystemMode from app.db import ( NATIVE_TO_LEGACY_DOCTYPE, Chunk, @@ -857,6 +858,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] *, llm: BaseChatModel | None = None, search_space_id: int, + filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, available_connectors: list[str] | None = None, available_document_types: list[str] | None = None, top_k: int = 10, @@ -865,6 +867,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] ) -> None: self.llm = llm self.search_space_id = search_space_id + self.filesystem_mode = filesystem_mode self.available_connectors = available_connectors self.available_document_types = available_document_types self.top_k = top_k @@ -996,6 +999,9 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] messages = state.get("messages") or [] if not messages: return None + if self.filesystem_mode != FilesystemMode.CLOUD: + # Local-folder mode should not seed cloud KB documents into filesystem. + return None last_human = None for msg in reversed(messages): diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index a1795853a..016c2de42 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -141,6 +141,15 @@ def _http_exception_handler(request: Request, exc: HTTPException) -> JSONRespons exc.status_code, message, ) + elif exc.status_code >= 400: + _error_logger.warning( + "[%s] %s %s - HTTPException %d: %s", + rid, + request.method, + request.url.path, + exc.status_code, + message, + ) if should_sanitize: message = GENERIC_5XX_MESSAGE err_code = "INTERNAL_ERROR" @@ -170,6 +179,15 @@ def _http_exception_handler(request: Request, exc: HTTPException) -> JSONRespons exc.status_code, detail, ) + elif exc.status_code >= 400: + _error_logger.warning( + "[%s] %s %s - HTTPException %d: %s", + rid, + request.method, + request.url.path, + exc.status_code, + detail, + ) if should_sanitize: detail = GENERIC_5XX_MESSAGE code = _status_to_code(exc.status_code, detail) diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index a515e9044..bd97d2bb1 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -339,6 +339,9 @@ class Config: # self-hosted: Full access to local file system connectors (Obsidian, etc.) # cloud: Only cloud-based connectors available DEPLOYMENT_MODE = os.getenv("SURFSENSE_DEPLOYMENT_MODE", "self-hosted") + ENABLE_DESKTOP_LOCAL_FILESYSTEM = ( + os.getenv("ENABLE_DESKTOP_LOCAL_FILESYSTEM", "FALSE").upper() == "TRUE" + ) @classmethod def is_self_hosted(cls) -> bool: diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index b914b297e..5e8e24c4a 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -22,6 +22,12 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload +from app.agents.new_chat.filesystem_selection import ( + ClientPlatform, + FilesystemMode, + FilesystemSelection, +) +from app.config import config from app.db import ( ChatComment, ChatVisibility, @@ -63,6 +69,51 @@ _background_tasks: set[asyncio.Task] = set() router = APIRouter() +def _resolve_filesystem_selection( + *, + mode: str, + client_platform: str, + local_root: str | None, +) -> FilesystemSelection: + """Validate and normalize filesystem mode settings from request payload.""" + try: + resolved_mode = FilesystemMode(mode) + except ValueError as exc: + raise HTTPException(status_code=400, detail="Invalid filesystem_mode") from exc + try: + resolved_platform = ClientPlatform(client_platform) + except ValueError as exc: + raise HTTPException(status_code=400, detail="Invalid client_platform") from exc + + if resolved_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: + if not config.ENABLE_DESKTOP_LOCAL_FILESYSTEM: + raise HTTPException( + status_code=400, + detail="Desktop local filesystem mode is disabled on this deployment.", + ) + if resolved_platform != ClientPlatform.DESKTOP: + raise HTTPException( + status_code=400, + detail="desktop_local_folder mode is only available on desktop runtime.", + ) + if not local_root or not local_root.strip(): + raise HTTPException( + status_code=400, + detail="local_filesystem_root is required for desktop_local_folder mode.", + ) + return FilesystemSelection( + mode=resolved_mode, + client_platform=resolved_platform, + local_root_path=local_root.strip(), + ) + + return FilesystemSelection( + mode=FilesystemMode.CLOUD, + client_platform=resolved_platform, + local_root_path=None, + ) + + def _try_delete_sandbox(thread_id: int) -> None: """Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked.""" from app.agents.new_chat.sandbox import ( @@ -474,6 +525,11 @@ async def get_thread_messages( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + filesystem_selection = _resolve_filesystem_selection( + mode=request.filesystem_mode, + client_platform=request.client_platform, + local_root=request.local_filesystem_root, + ) # Get messages with their authors and token usage loaded messages_result = await session.execute( @@ -1098,6 +1154,7 @@ async def list_agent_tools( @router.post("/new_chat") async def handle_new_chat( request: NewChatRequest, + http_request: Request, session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): @@ -1133,6 +1190,11 @@ async def handle_new_chat( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + filesystem_selection = _resolve_filesystem_selection( + mode=request.filesystem_mode, + client_platform=request.client_platform, + local_root=request.local_filesystem_root, + ) # Get search space to check LLM config preferences search_space_result = await session.execute( @@ -1175,6 +1237,8 @@ async def handle_new_chat( thread_visibility=thread.visibility, current_user_display_name=user.display_name or "A team member", disabled_tools=request.disabled_tools, + filesystem_selection=filesystem_selection, + request_id=getattr(http_request.state, "request_id", "unknown"), ), media_type="text/event-stream", headers={ @@ -1202,6 +1266,7 @@ async def handle_new_chat( async def regenerate_response( thread_id: int, request: RegenerateRequest, + http_request: Request, session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): @@ -1247,6 +1312,11 @@ async def regenerate_response( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + filesystem_selection = _resolve_filesystem_selection( + mode=request.filesystem_mode, + client_platform=request.client_platform, + local_root=request.local_filesystem_root, + ) # Get the checkpointer and state history checkpointer = await get_checkpointer() @@ -1412,6 +1482,8 @@ async def regenerate_response( thread_visibility=thread.visibility, current_user_display_name=user.display_name or "A team member", disabled_tools=request.disabled_tools, + filesystem_selection=filesystem_selection, + request_id=getattr(http_request.state, "request_id", "unknown"), ): yield chunk streaming_completed = True @@ -1477,6 +1549,7 @@ async def regenerate_response( async def resume_chat( thread_id: int, request: ResumeRequest, + http_request: Request, session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): @@ -1498,6 +1571,11 @@ async def resume_chat( ) await check_thread_access(session, thread, user) + filesystem_selection = _resolve_filesystem_selection( + mode=request.filesystem_mode, + client_platform=request.client_platform, + local_root=request.local_filesystem_root, + ) search_space_result = await session.execute( select(SearchSpace).filter(SearchSpace.id == request.search_space_id) @@ -1526,6 +1604,8 @@ async def resume_chat( user_id=str(user.id), llm_config_id=llm_config_id, thread_visibility=thread.visibility, + filesystem_selection=filesystem_selection, + request_id=getattr(http_request.state, "request_id", "unknown"), ), media_type="text/event-stream", headers={ diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index e523657a4..593127c7e 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -184,6 +184,9 @@ class NewChatRequest(BaseModel): disabled_tools: list[str] | None = ( None # Optional list of tool names the user has disabled from the UI ) + filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud" + client_platform: Literal["web", "desktop"] = "web" + local_filesystem_root: str | None = None class RegenerateRequest(BaseModel): @@ -204,6 +207,9 @@ class RegenerateRequest(BaseModel): mentioned_document_ids: list[int] | None = None mentioned_surfsense_doc_ids: list[int] | None = None disabled_tools: list[str] | None = None + filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud" + client_platform: Literal["web", "desktop"] = "web" + local_filesystem_root: str | None = None # ============================================================================= @@ -227,6 +233,9 @@ class ResumeDecision(BaseModel): class ResumeRequest(BaseModel): search_space_id: int decisions: list[ResumeDecision] + filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud" + client_platform: Literal["web", "desktop"] = "web" + local_filesystem_root: str | None = None # ============================================================================= diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 4810f02e6..d551f3fd5 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -30,6 +30,8 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer +from app.agents.new_chat.filesystem_selection import FilesystemSelection +from app.config import config from app.agents.new_chat.llm_config import ( AgentConfig, create_chat_litellm_from_agent_config, @@ -145,6 +147,85 @@ class StreamResult: interrupt_value: dict[str, Any] | None = None sandbox_files: list[str] = field(default_factory=list) agent_called_update_memory: bool = False + request_id: str | None = None + turn_id: str = "" + filesystem_mode: str = "cloud" + client_platform: str = "web" + intent_detected: str = "chat_only" + intent_confidence: float = 0.0 + write_attempted: bool = False + write_succeeded: bool = False + verification_succeeded: bool = False + commit_gate_passed: bool = True + commit_gate_reason: str = "" + + +def _safe_float(value: Any, default: float = 0.0) -> float: + try: + return float(value) + except (TypeError, ValueError): + return default + + +def _tool_output_to_text(tool_output: Any) -> str: + if isinstance(tool_output, dict): + if isinstance(tool_output.get("result"), str): + return tool_output["result"] + if isinstance(tool_output.get("error"), str): + return tool_output["error"] + return json.dumps(tool_output, ensure_ascii=False) + return str(tool_output) + + +def _tool_output_has_error(tool_output: Any) -> bool: + if isinstance(tool_output, dict): + if tool_output.get("error"): + return True + result = tool_output.get("result") + if isinstance(result, str) and result.strip().lower().startswith("error:"): + return True + return False + if isinstance(tool_output, str): + return tool_output.strip().lower().startswith("error:") + return False + + +def _contract_enforcement_active(result: StreamResult) -> bool: + # Keep policy deterministic with no env-driven progression modes: + # enforce the file-operation contract only in desktop local-folder mode. + return result.filesystem_mode == "desktop_local_folder" + + +def _evaluate_file_contract_outcome(result: StreamResult) -> tuple[bool, str]: + if result.intent_detected != "file_write": + return True, "" + if not result.write_attempted: + return False, "no_write_attempt" + if not result.write_succeeded: + return False, "write_failed" + if not result.verification_succeeded: + return False, "verification_failed" + return True, "" + + +def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None: + payload: dict[str, Any] = { + "stage": stage, + "request_id": result.request_id or "unknown", + "turn_id": result.turn_id or "unknown", + "chat_id": result.turn_id.split(":", 1)[0] if ":" in result.turn_id else "unknown", + "filesystem_mode": result.filesystem_mode, + "client_platform": result.client_platform, + "intent_detected": result.intent_detected, + "intent_confidence": result.intent_confidence, + "write_attempted": result.write_attempted, + "write_succeeded": result.write_succeeded, + "verification_succeeded": result.verification_succeeded, + "commit_gate_passed": result.commit_gate_passed, + "commit_gate_reason": result.commit_gate_reason or None, + } + payload.update(extra) + _perf_log.info("[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False)) async def _stream_agent_events( @@ -239,6 +320,8 @@ async def _stream_agent_events( tool_name = event.get("name", "unknown_tool") run_id = event.get("run_id", "") tool_input = event.get("data", {}).get("input", {}) + if tool_name in ("write_file", "edit_file"): + result.write_attempted = True if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) @@ -514,6 +597,14 @@ async def _stream_agent_events( else: tool_output = {"result": str(raw_output) if raw_output else "completed"} + if tool_name in ("write_file", "edit_file"): + if _tool_output_has_error(tool_output): + # Keep successful evidence if a previous write/edit in this turn succeeded. + pass + else: + result.write_succeeded = True + result.verification_succeeded = True + tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown" original_step_id = tool_step_ids.get( run_id, f"{step_prefix}-unknown-{run_id[:8]}" @@ -1143,10 +1234,59 @@ async def _stream_agent_events( if completion_event: yield completion_event + state = await agent.aget_state(config) + state_values = getattr(state, "values", {}) or {} + contract_state = state_values.get("file_operation_contract") or {} + contract_turn_id = contract_state.get("turn_id") + current_turn_id = config.get("configurable", {}).get("turn_id", "") + intent_value = contract_state.get("intent") + if ( + isinstance(intent_value, str) + and intent_value in ("chat_only", "file_write", "file_read") + and contract_turn_id == current_turn_id + ): + result.intent_detected = intent_value + if ( + isinstance(intent_value, str) + and intent_value in ( + "chat_only", + "file_write", + "file_read", + ) + and contract_turn_id != current_turn_id + ): + # Ignore stale intent contracts from previous turns/checkpoints. + result.intent_detected = "chat_only" + result.intent_confidence = ( + _safe_float(contract_state.get("confidence"), default=0.0) + if contract_turn_id == current_turn_id + else 0.0 + ) + + if result.intent_detected == "file_write": + result.commit_gate_passed, result.commit_gate_reason = ( + _evaluate_file_contract_outcome(result) + ) + if not result.commit_gate_passed: + if _contract_enforcement_active(result): + gate_notice = ( + "I could not complete the requested file write because no successful " + "write_file/edit_file operation was confirmed." + ) + gate_text_id = streaming_service.generate_text_id() + yield streaming_service.format_text_start(gate_text_id) + yield streaming_service.format_text_delta(gate_text_id, gate_notice) + yield streaming_service.format_text_end(gate_text_id) + yield streaming_service.format_terminal_info(gate_notice, "error") + accumulated_text = gate_notice + else: + result.commit_gate_passed = True + result.commit_gate_reason = "" + result.accumulated_text = accumulated_text result.agent_called_update_memory = called_update_memory + _log_file_contract("turn_outcome", result) - state = await agent.aget_state(config) is_interrupted = state.tasks and any(task.interrupts for task in state.tasks) if is_interrupted: result.is_interrupted = True @@ -1167,6 +1307,8 @@ async def stream_new_chat( thread_visibility: ChatVisibility | None = None, current_user_display_name: str | None = None, disabled_tools: list[str] | None = None, + filesystem_selection: FilesystemSelection | None = None, + request_id: str | None = None, ) -> AsyncGenerator[str, None]: """ Stream chat responses from the new SurfSense deep agent. @@ -1194,6 +1336,20 @@ async def stream_new_chat( streaming_service = VercelStreamingService() stream_result = StreamResult() _t_total = time.perf_counter() + fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud" + fs_platform = ( + filesystem_selection.client_platform.value if filesystem_selection else "web" + ) + stream_result.request_id = request_id + stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}" + stream_result.filesystem_mode = fs_mode + stream_result.client_platform = fs_platform + _log_file_contract("turn_start", stream_result) + _perf_log.info( + "[stream_new_chat] filesystem_mode=%s client_platform=%s", + fs_mode, + fs_platform, + ) log_system_snapshot("stream_new_chat_START") from app.services.token_tracking_service import start_turn @@ -1329,6 +1485,7 @@ async def stream_new_chat( thread_visibility=visibility, disabled_tools=disabled_tools, mentioned_document_ids=mentioned_document_ids, + filesystem_selection=filesystem_selection, ) _perf_log.info( "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0 @@ -1435,6 +1592,8 @@ async def stream_new_chat( # We will use this to simulate group chat functionality in the future "messages": langchain_messages, "search_space_id": search_space_id, + "request_id": request_id or "unknown", + "turn_id": stream_result.turn_id, } _perf_log.info( @@ -1464,6 +1623,8 @@ async def stream_new_chat( # Configure LangGraph with thread_id for memory # If checkpoint_id is provided, fork from that checkpoint (for edit/reload) configurable = {"thread_id": str(chat_id)} + configurable["request_id"] = request_id or "unknown" + configurable["turn_id"] = stream_result.turn_id if checkpoint_id: configurable["checkpoint_id"] = checkpoint_id @@ -1871,10 +2032,26 @@ async def stream_resume_chat( user_id: str | None = None, llm_config_id: int = -1, thread_visibility: ChatVisibility | None = None, + filesystem_selection: FilesystemSelection | None = None, + request_id: str | None = None, ) -> AsyncGenerator[str, None]: streaming_service = VercelStreamingService() stream_result = StreamResult() _t_total = time.perf_counter() + fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud" + fs_platform = ( + filesystem_selection.client_platform.value if filesystem_selection else "web" + ) + stream_result.request_id = request_id + stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}" + stream_result.filesystem_mode = fs_mode + stream_result.client_platform = fs_platform + _log_file_contract("turn_start", stream_result) + _perf_log.info( + "[stream_resume] filesystem_mode=%s client_platform=%s", + fs_mode, + fs_platform, + ) from app.services.token_tracking_service import start_turn @@ -1991,6 +2168,7 @@ async def stream_resume_chat( agent_config=agent_config, firecrawl_api_key=firecrawl_api_key, thread_visibility=visibility, + filesystem_selection=filesystem_selection, ) _perf_log.info( "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0 @@ -2009,7 +2187,11 @@ async def stream_resume_chat( from langgraph.types import Command config = { - "configurable": {"thread_id": str(chat_id)}, + "configurable": { + "thread_id": str(chat_id), + "request_id": request_id or "unknown", + "turn_id": stream_result.turn_id, + }, "recursion_limit": 80, } diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py new file mode 100644 index 000000000..f4adc3d73 --- /dev/null +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -0,0 +1,48 @@ +import pytest + +from app.tasks.chat.stream_new_chat import ( + StreamResult, + _contract_enforcement_active, + _evaluate_file_contract_outcome, + _tool_output_has_error, +) + +pytestmark = pytest.mark.unit + + +def test_tool_output_error_detection(): + assert _tool_output_has_error("Error: failed to write file") + assert _tool_output_has_error({"error": "boom"}) + assert _tool_output_has_error({"result": "Error: disk is full"}) + assert not _tool_output_has_error({"result": "Updated file /notes.md"}) + + +def test_file_write_contract_outcome_reasons(): + result = StreamResult(intent_detected="file_write") + passed, reason = _evaluate_file_contract_outcome(result) + assert not passed + assert reason == "no_write_attempt" + + result.write_attempted = True + passed, reason = _evaluate_file_contract_outcome(result) + assert not passed + assert reason == "write_failed" + + result.write_succeeded = True + passed, reason = _evaluate_file_contract_outcome(result) + assert not passed + assert reason == "verification_failed" + + result.verification_succeeded = True + passed, reason = _evaluate_file_contract_outcome(result) + assert passed + assert reason == "" + + +def test_contract_enforcement_local_only(): + result = StreamResult(filesystem_mode="desktop_local_folder") + assert _contract_enforcement_active(result) + + result.filesystem_mode = "cloud" + assert not _contract_enforcement_active(result) + From 5c3a327a0cedc0717f89515e6cf804c792dd1689 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 15:45:59 +0530 Subject: [PATCH 116/299] feat(desktop): expose agent filesystem IPC APIs --- surfsense_desktop/src/ipc/channels.ts | 4 ++++ surfsense_desktop/src/ipc/handlers.ts | 19 +++++++++++++++++++ surfsense_desktop/src/preload.ts | 8 ++++++++ 3 files changed, 31 insertions(+) diff --git a/surfsense_desktop/src/ipc/channels.ts b/surfsense_desktop/src/ipc/channels.ts index 6731ecbfa..177a05fb4 100644 --- a/surfsense_desktop/src/ipc/channels.ts +++ b/surfsense_desktop/src/ipc/channels.ts @@ -51,4 +51,8 @@ export const IPC_CHANNELS = { ANALYTICS_RESET: 'analytics:reset', ANALYTICS_CAPTURE: 'analytics:capture', ANALYTICS_GET_CONTEXT: 'analytics:get-context', + // Agent filesystem mode + AGENT_FILESYSTEM_GET_SETTINGS: 'agent-filesystem:get-settings', + AGENT_FILESYSTEM_SET_SETTINGS: 'agent-filesystem:set-settings', + AGENT_FILESYSTEM_PICK_ROOT: 'agent-filesystem:pick-root', } as const; diff --git a/surfsense_desktop/src/ipc/handlers.ts b/surfsense_desktop/src/ipc/handlers.ts index 05c327436..3719a0b0f 100644 --- a/surfsense_desktop/src/ipc/handlers.ts +++ b/surfsense_desktop/src/ipc/handlers.ts @@ -36,6 +36,11 @@ import { resetUser as analyticsReset, trackEvent, } from '../modules/analytics'; +import { + getAgentFilesystemSettings, + pickAgentFilesystemRoot, + setAgentFilesystemSettings, +} from '../modules/agent-filesystem'; let authTokens: { bearer: string; refresh: string } | null = null; @@ -191,4 +196,18 @@ export function registerIpcHandlers(): void { platform: process.platform, }; }); + + ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS, () => + getAgentFilesystemSettings() + ); + + ipcMain.handle( + IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS, + (_event, settings: { mode?: 'cloud' | 'desktop_local_folder'; localRootPath?: string | null }) => + setAgentFilesystemSettings(settings) + ); + + ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_PICK_ROOT, () => + pickAgentFilesystemRoot() + ); } diff --git a/surfsense_desktop/src/preload.ts b/surfsense_desktop/src/preload.ts index 3a69f3239..f75cc240e 100644 --- a/surfsense_desktop/src/preload.ts +++ b/surfsense_desktop/src/preload.ts @@ -101,4 +101,12 @@ contextBridge.exposeInMainWorld('electronAPI', { analyticsCapture: (event: string, properties?: Record) => ipcRenderer.invoke(IPC_CHANNELS.ANALYTICS_CAPTURE, { event, properties }), getAnalyticsContext: () => ipcRenderer.invoke(IPC_CHANNELS.ANALYTICS_GET_CONTEXT), + // Agent filesystem mode + getAgentFilesystemSettings: () => + ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS), + setAgentFilesystemSettings: (settings: { + mode?: "cloud" | "desktop_local_folder"; + localRootPath?: string | null; + }) => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS, settings), + pickAgentFilesystemRoot: () => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_PICK_ROOT), }); From 4899588cd701f41155962a466373b2cfc89d6123 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 15:46:39 +0530 Subject: [PATCH 117/299] feat(web): connect new chat UI to agent filesystem APIs --- .../src/modules/agent-filesystem.ts | 74 ++++++++++++++ .../new-chat/[[...chat_id]]/page.tsx | 20 ++++ .../components/assistant-ui/thread.tsx | 98 ++++++++++++++++++- surfsense_web/lib/apis/base-api.service.ts | 3 + surfsense_web/types/window.d.ts | 15 +++ 5 files changed, 209 insertions(+), 1 deletion(-) create mode 100644 surfsense_desktop/src/modules/agent-filesystem.ts diff --git a/surfsense_desktop/src/modules/agent-filesystem.ts b/surfsense_desktop/src/modules/agent-filesystem.ts new file mode 100644 index 000000000..44f12a465 --- /dev/null +++ b/surfsense_desktop/src/modules/agent-filesystem.ts @@ -0,0 +1,74 @@ +import { app, dialog } from "electron"; +import { mkdir, readFile, writeFile } from "node:fs/promises"; +import { dirname, join } from "node:path"; + +export type AgentFilesystemMode = "cloud" | "desktop_local_folder"; + +export interface AgentFilesystemSettings { + mode: AgentFilesystemMode; + localRootPath: string | null; + updatedAt: string; +} + +const SETTINGS_FILENAME = "agent-filesystem-settings.json"; + +function getSettingsPath(): string { + return join(app.getPath("userData"), SETTINGS_FILENAME); +} + +function getDefaultSettings(): AgentFilesystemSettings { + return { + mode: "cloud", + localRootPath: null, + updatedAt: new Date().toISOString(), + }; +} + +export async function getAgentFilesystemSettings(): Promise { + try { + const raw = await readFile(getSettingsPath(), "utf8"); + const parsed = JSON.parse(raw) as Partial; + if (parsed.mode !== "cloud" && parsed.mode !== "desktop_local_folder") { + return getDefaultSettings(); + } + return { + mode: parsed.mode, + localRootPath: parsed.localRootPath ?? null, + updatedAt: parsed.updatedAt ?? new Date().toISOString(), + }; + } catch { + return getDefaultSettings(); + } +} + +export async function setAgentFilesystemSettings( + settings: Partial> +): Promise { + const current = await getAgentFilesystemSettings(); + const nextMode = + settings.mode === "cloud" || settings.mode === "desktop_local_folder" + ? settings.mode + : current.mode; + const next: AgentFilesystemSettings = { + mode: nextMode, + localRootPath: + settings.localRootPath === undefined ? current.localRootPath : settings.localRootPath, + updatedAt: new Date().toISOString(), + }; + + const settingsPath = getSettingsPath(); + await mkdir(dirname(settingsPath), { recursive: true }); + await writeFile(settingsPath, JSON.stringify(next, null, 2), "utf8"); + return next; +} + +export async function pickAgentFilesystemRoot(): Promise { + const result = await dialog.showOpenDialog({ + title: "Select local folder for Agent Filesystem", + properties: ["openDirectory"], + }); + if (result.canceled || result.filePaths.length === 0) { + return null; + } + return result.filePaths[0] ?? null; +} diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 6c94134b7..bdb77ade2 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -46,6 +46,7 @@ import { import { useChatSessionStateSync } from "@/hooks/use-chat-session-state"; import { useMessagesSync } from "@/hooks/use-messages-sync"; import { documentsApiService } from "@/lib/apis/documents-api.service"; +import { getAgentFilesystemSelection } from "@/lib/agent-filesystem"; import { getBearerToken } from "@/lib/auth-utils"; import { convertToThreadMessage } from "@/lib/chat/message-utils"; import { @@ -656,6 +657,14 @@ export default function NewChatPage() { try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + const selection = await getAgentFilesystemSelection(); + if ( + selection.filesystem_mode === "desktop_local_folder" && + !selection.local_filesystem_root + ) { + toast.error("Select a local folder before using Local Folder mode."); + return; + } // Build message history for context const messageHistory = messages @@ -691,6 +700,9 @@ export default function NewChatPage() { chat_id: currentThreadId, user_query: userQuery.trim(), search_space_id: searchSpaceId, + filesystem_mode: selection.filesystem_mode, + client_platform: selection.client_platform, + local_filesystem_root: selection.local_filesystem_root, messages: messageHistory, mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined, mentioned_surfsense_doc_ids: hasSurfsenseDocIds @@ -1074,6 +1086,7 @@ export default function NewChatPage() { try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + const selection = await getAgentFilesystemSelection(); const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { method: "POST", headers: { @@ -1083,6 +1096,9 @@ export default function NewChatPage() { body: JSON.stringify({ search_space_id: searchSpaceId, decisions, + filesystem_mode: selection.filesystem_mode, + client_platform: selection.client_platform, + local_filesystem_root: selection.local_filesystem_root, }), signal: controller.signal, }); @@ -1406,6 +1422,7 @@ export default function NewChatPage() { ]); try { + const selection = await getAgentFilesystemSelection(); const response = await fetch(getRegenerateUrl(threadId), { method: "POST", headers: { @@ -1416,6 +1433,9 @@ export default function NewChatPage() { search_space_id: searchSpaceId, user_query: newUserQuery || null, disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, + filesystem_mode: selection.filesystem_mode, + client_platform: selection.client_platform, + local_filesystem_root: selection.local_filesystem_root, }), signal: controller.signal, }); diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 8d60e2c5c..094d99a29 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -94,6 +94,12 @@ import { cn } from "@/lib/utils"; const COMPOSER_PLACEHOLDER = "Ask anything, type / for prompts, type @ to mention docs"; +type ComposerFilesystemSettings = { + mode: "cloud" | "desktop_local_folder"; + localRootPath: string | null; + updatedAt: string; +}; + export const Thread: FC = () => { return ; }; @@ -362,6 +368,9 @@ const Composer: FC = () => { }, []); const electronAPI = useElectronAPI(); + const [filesystemSettings, setFilesystemSettings] = useState( + null + ); const [clipboardInitialText, setClipboardInitialText] = useState(); const clipboardLoadedRef = useRef(false); useEffect(() => { @@ -374,6 +383,48 @@ const Composer: FC = () => { }); }, [electronAPI]); + useEffect(() => { + if (!electronAPI?.getAgentFilesystemSettings) return; + let mounted = true; + electronAPI + .getAgentFilesystemSettings() + .then((settings) => { + if (!mounted) return; + setFilesystemSettings(settings); + }) + .catch(() => { + if (!mounted) return; + setFilesystemSettings({ + mode: "cloud", + localRootPath: null, + updatedAt: new Date().toISOString(), + }); + }); + return () => { + mounted = false; + }; + }, [electronAPI]); + + const handleFilesystemModeChange = useCallback( + async (mode: "cloud" | "desktop_local_folder") => { + if (!electronAPI?.setAgentFilesystemSettings) return; + const updated = await electronAPI.setAgentFilesystemSettings({ mode }); + setFilesystemSettings(updated); + }, + [electronAPI] + ); + + const handlePickFilesystemRoot = useCallback(async () => { + if (!electronAPI?.pickAgentFilesystemRoot || !electronAPI?.setAgentFilesystemSettings) return; + const picked = await electronAPI.pickAgentFilesystemRoot(); + if (!picked) return; + const updated = await electronAPI.setAgentFilesystemSettings({ + mode: "desktop_local_folder", + localRootPath: picked, + }); + setFilesystemSettings(updated); + }, [electronAPI]); + const isThreadEmpty = useAuiState(({ thread }) => thread.isEmpty); const isThreadRunning = useAuiState(({ thread }) => thread.isRunning); @@ -668,6 +719,45 @@ const Composer: FC = () => { currentUserId={currentUser?.id ?? null} members={members ?? []} /> + {electronAPI && filesystemSettings ? ( +
+ + +
+ +
+ ) : null} {showDocumentPopover && (
= ({ isBlockedByOtherUser = false group.tools.flatMap((t, i) => i === 0 ? [t.description] - : [, t.description] + : [ + , + t.description, + ] )} diff --git a/surfsense_web/lib/apis/base-api.service.ts b/surfsense_web/lib/apis/base-api.service.ts index 04e9fad54..269fd916c 100644 --- a/surfsense_web/lib/apis/base-api.service.ts +++ b/surfsense_web/lib/apis/base-api.service.ts @@ -1,4 +1,5 @@ import type { ZodType } from "zod"; +import { getClientPlatform } from "../agent-filesystem"; import { getBearerToken, handleUnauthorized, refreshAccessToken } from "../auth-utils"; import { AbortedError, @@ -75,6 +76,8 @@ class BaseApiService { const defaultOptions: RequestOptions = { headers: { Authorization: `Bearer ${this.bearerToken || ""}`, + "X-SurfSense-Client-Platform": + typeof window === "undefined" ? "web" : getClientPlatform(), }, method: "GET", responseType: ResponseType.JSON, diff --git a/surfsense_web/types/window.d.ts b/surfsense_web/types/window.d.ts index a80520684..661c0f7d6 100644 --- a/surfsense_web/types/window.d.ts +++ b/surfsense_web/types/window.d.ts @@ -41,6 +41,14 @@ interface FolderFileEntry { mtimeMs: number; } +type AgentFilesystemMode = "cloud" | "desktop_local_folder"; + +interface AgentFilesystemSettings { + mode: AgentFilesystemMode; + localRootPath: string | null; + updatedAt: string; +} + interface ElectronAPI { versions: { electron: string; @@ -125,6 +133,13 @@ interface ElectronAPI { appVersion: string; platform: string; }>; + // Agent filesystem mode + getAgentFilesystemSettings: () => Promise; + setAgentFilesystemSettings: (settings: { + mode?: AgentFilesystemMode; + localRootPath?: string | null; + }) => Promise; + pickAgentFilesystemRoot: () => Promise; } declare global { From a2ddf4765012983c7071c569d2b0cdf995542ba1 Mon Sep 17 00:00:00 2001 From: Trevin Chow Date: Thu, 23 Apr 2026 03:26:42 -0700 Subject: [PATCH 118/299] refactor(anon-chat): route upload through anonymousChatApiService Fixes #1245. Deduplicate the anonymous-chat file upload request, which was inlined verbatim in DocumentsSidebar.tsx and free-composer.tsx while anonymousChatApiService.uploadDocument already existed. Key change: service now returns a discriminated result instead of throwing on 409. Callers need to distinguish 409 (quota exceeded, -> gate to login) from other non-OK responses (real errors, -> throw). export type AnonUploadResult = | { ok: true; data: { filename: string; size_bytes: number } } | { ok: false; reason: "quota_exceeded" }; Both call sites now do: const result = await anonymousChatApiService.uploadDocument(file); if (!result.ok) { if (result.reason === "quota_exceeded") gate("upload more documents"); return; } const data = result.data; Dropped the BACKEND_URL import in both files (no longer used). Verified zero remaining /api/v1/public/anon-chat/upload references in surfsense_web/. --- .../components/free-chat/free-composer.tsx | 22 +++++-------------- .../layout/ui/sidebar/DocumentsSidebar.tsx | 22 +++++-------------- .../lib/apis/anonymous-chat-api.service.ts | 12 ++++++++-- 3 files changed, 20 insertions(+), 36 deletions(-) diff --git a/surfsense_web/components/free-chat/free-composer.tsx b/surfsense_web/components/free-chat/free-composer.tsx index 57a3e8dd9..a22d2b205 100644 --- a/surfsense_web/components/free-chat/free-composer.tsx +++ b/surfsense_web/components/free-chat/free-composer.tsx @@ -9,7 +9,7 @@ import { Switch } from "@/components/ui/switch"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { useAnonymousMode } from "@/contexts/anonymous-mode"; import { useLoginGate } from "@/contexts/login-gate"; -import { BACKEND_URL } from "@/lib/env-config"; +import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service"; import { cn } from "@/lib/utils"; const ANON_ALLOWED_EXTENSIONS = new Set([ @@ -128,24 +128,12 @@ export const FreeComposer: FC = () => { } try { - const formData = new FormData(); - formData.append("file", file); - const res = await fetch(`${BACKEND_URL}/api/v1/public/anon-chat/upload`, { - method: "POST", - credentials: "include", - body: formData, - }); - - if (res.status === 409) { - gate("upload more documents"); + const result = await anonymousChatApiService.uploadDocument(file); + if (!result.ok) { + if (result.reason === "quota_exceeded") gate("upload more documents"); return; } - if (!res.ok) { - const body = await res.json().catch(() => ({})); - throw new Error(body.detail || `Upload failed: ${res.status}`); - } - - const data = await res.json(); + const data = result.data; if (anonMode.isAnonymous) { anonMode.setUploadedDoc({ filename: data.filename, diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index daed8747d..b7f4cff07 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -68,11 +68,11 @@ import type { DocumentTypeEnum } from "@/contracts/types/document.types"; import { useDebouncedValue } from "@/hooks/use-debounced-value"; import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI } from "@/hooks/use-platform"; +import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { foldersApiService } from "@/lib/apis/folders-api.service"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { BACKEND_URL } from "@/lib/env-config"; import { uploadFolderScan } from "@/lib/folder-sync-upload"; import { getSupportedExtensionsSet } from "@/lib/supported-extensions"; import { queries } from "@/zero/queries/index"; @@ -1312,24 +1312,12 @@ function AnonymousDocumentsSidebar({ setIsUploading(true); try { - const formData = new FormData(); - formData.append("file", file); - const res = await fetch(`${BACKEND_URL}/api/v1/public/anon-chat/upload`, { - method: "POST", - credentials: "include", - body: formData, - }); - - if (res.status === 409) { - gate("upload more documents"); + const result = await anonymousChatApiService.uploadDocument(file); + if (!result.ok) { + if (result.reason === "quota_exceeded") gate("upload more documents"); return; } - if (!res.ok) { - const body = await res.json().catch(() => ({})); - throw new Error(body.detail || `Upload failed: ${res.status}`); - } - - const data = await res.json(); + const data = result.data; if (anonMode.isAnonymous) { anonMode.setUploadedDoc({ filename: data.filename, diff --git a/surfsense_web/lib/apis/anonymous-chat-api.service.ts b/surfsense_web/lib/apis/anonymous-chat-api.service.ts index 968f58be2..843576a50 100644 --- a/surfsense_web/lib/apis/anonymous-chat-api.service.ts +++ b/surfsense_web/lib/apis/anonymous-chat-api.service.ts @@ -12,6 +12,10 @@ import { ValidationError } from "../error"; const BASE = "/api/v1/public/anon-chat"; +export type AnonUploadResult = + | { ok: true; data: { filename: string; size_bytes: number } } + | { ok: false; reason: "quota_exceeded" }; + class AnonymousChatApiService { private baseUrl: string; @@ -71,7 +75,7 @@ class AnonymousChatApiService { }); }; - uploadDocument = async (file: File): Promise<{ filename: string; size_bytes: number }> => { + uploadDocument = async (file: File): Promise => { const formData = new FormData(); formData.append("file", file); const res = await fetch(this.fullUrl("/upload"), { @@ -79,11 +83,15 @@ class AnonymousChatApiService { credentials: "include", body: formData, }); + if (res.status === 409) { + return { ok: false, reason: "quota_exceeded" }; + } if (!res.ok) { const body = await res.json().catch(() => ({})); throw new Error(body.detail || `Upload failed: ${res.status}`); } - return res.json(); + const data = await res.json(); + return { ok: true, data }; }; getDocument = async (): Promise<{ filename: string; size_bytes: number } | null> => { From 864f6f798ab25d6c4112b5b68b8ebd9aa0abf4ec Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 17:23:38 +0530 Subject: [PATCH 119/299] feat(filesystem): enhance local file handling in editor and IPC integration --- surfsense_backend/.env.example | 3 + .../app/routes/new_chat_routes.py | 5 - surfsense_desktop/src/ipc/channels.ts | 2 + surfsense_desktop/src/ipc/handlers.ts | 25 ++++ .../src/modules/agent-filesystem.ts | 61 +++++++- surfsense_desktop/src/preload.ts | 4 + .../atoms/editor/editor-panel.atom.ts | 38 ++++- .../components/assistant-ui/markdown-text.tsx | 46 ++++++ .../components/editor-panel/editor-panel.tsx | 139 ++++++++++++++---- .../layout/ui/right-panel/RightPanel.tsx | 18 ++- surfsense_web/lib/agent-filesystem.ts | 44 ++++++ surfsense_web/types/window.d.ts | 12 ++ 12 files changed, 350 insertions(+), 47 deletions(-) create mode 100644 surfsense_web/lib/agent-filesystem.ts diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index 7f6389521..86bac0aaf 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -239,6 +239,9 @@ LLAMA_CLOUD_API_KEY=llx-nnn # DAYTONA_TARGET=us # DAYTONA_SNAPSHOT_ID= +# Desktop local filesystem mode (chat file tools run against a local folder root) +# ENABLE_DESKTOP_LOCAL_FILESYSTEM=FALSE + # OPTIONAL: Add these for LangSmith Observability LANGSMITH_TRACING=true LANGSMITH_ENDPOINT=https://api.smith.langchain.com diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 5e8e24c4a..548bd1402 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -525,11 +525,6 @@ async def get_thread_messages( # Check thread-level access based on visibility await check_thread_access(session, thread, user) - filesystem_selection = _resolve_filesystem_selection( - mode=request.filesystem_mode, - client_platform=request.client_platform, - local_root=request.local_filesystem_root, - ) # Get messages with their authors and token usage loaded messages_result = await session.execute( diff --git a/surfsense_desktop/src/ipc/channels.ts b/surfsense_desktop/src/ipc/channels.ts index 177a05fb4..5cf6e9001 100644 --- a/surfsense_desktop/src/ipc/channels.ts +++ b/surfsense_desktop/src/ipc/channels.ts @@ -34,6 +34,8 @@ export const IPC_CHANNELS = { FOLDER_SYNC_SEED_MTIMES: 'folder-sync:seed-mtimes', BROWSE_FILES: 'browse:files', READ_LOCAL_FILES: 'browse:read-local-files', + READ_AGENT_LOCAL_FILE_TEXT: 'agent-filesystem:read-local-file-text', + WRITE_AGENT_LOCAL_FILE_TEXT: 'agent-filesystem:write-local-file-text', // Auth token sync across windows GET_AUTH_TOKENS: 'auth:get-tokens', SET_AUTH_TOKENS: 'auth:set-tokens', diff --git a/surfsense_desktop/src/ipc/handlers.ts b/surfsense_desktop/src/ipc/handlers.ts index 3719a0b0f..cc84a46e0 100644 --- a/surfsense_desktop/src/ipc/handlers.ts +++ b/surfsense_desktop/src/ipc/handlers.ts @@ -37,6 +37,8 @@ import { trackEvent, } from '../modules/analytics'; import { + readAgentLocalFileText, + writeAgentLocalFileText, getAgentFilesystemSettings, pickAgentFilesystemRoot, setAgentFilesystemSettings, @@ -123,6 +125,29 @@ export function registerIpcHandlers(): void { readLocalFiles(paths) ); + ipcMain.handle(IPC_CHANNELS.READ_AGENT_LOCAL_FILE_TEXT, async (_event, virtualPath: string) => { + try { + const result = await readAgentLocalFileText(virtualPath); + return { ok: true, path: result.path, content: result.content }; + } catch (error) { + const message = error instanceof Error ? error.message : 'Failed to read local file'; + return { ok: false, path: virtualPath, error: message }; + } + }); + + ipcMain.handle( + IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT, + async (_event, virtualPath: string, content: string) => { + try { + const result = await writeAgentLocalFileText(virtualPath, content); + return { ok: true, path: result.path }; + } catch (error) { + const message = error instanceof Error ? error.message : 'Failed to write local file'; + return { ok: false, path: virtualPath, error: message }; + } + } + ); + ipcMain.handle(IPC_CHANNELS.SET_AUTH_TOKENS, (_event, tokens: { bearer: string; refresh: string }) => { authTokens = tokens; }); diff --git a/surfsense_desktop/src/modules/agent-filesystem.ts b/surfsense_desktop/src/modules/agent-filesystem.ts index 44f12a465..9dfe79fb0 100644 --- a/surfsense_desktop/src/modules/agent-filesystem.ts +++ b/surfsense_desktop/src/modules/agent-filesystem.ts @@ -1,6 +1,6 @@ import { app, dialog } from "electron"; import { mkdir, readFile, writeFile } from "node:fs/promises"; -import { dirname, join } from "node:path"; +import { dirname, isAbsolute, join, relative, resolve } from "node:path"; export type AgentFilesystemMode = "cloud" | "desktop_local_folder"; @@ -72,3 +72,62 @@ export async function pickAgentFilesystemRoot(): Promise { } return result.filePaths[0] ?? null; } + +function resolveVirtualPath(rootPath: string, virtualPath: string): string { + if (!virtualPath.startsWith("/")) { + throw new Error("Path must start with '/'"); + } + const normalizedRoot = resolve(rootPath); + const relativePath = virtualPath.replace(/^\/+/, ""); + if (!relativePath) { + throw new Error("Path must refer to a file under the selected root"); + } + const absolutePath = resolve(normalizedRoot, relativePath); + const rel = relative(normalizedRoot, absolutePath); + if (!rel || rel.startsWith("..") || isAbsolute(rel)) { + throw new Error("Path escapes selected local root"); + } + return absolutePath; +} + +function toVirtualPath(rootPath: string, absolutePath: string): string { + const normalizedRoot = resolve(rootPath); + const rel = relative(normalizedRoot, absolutePath); + if (!rel || rel.startsWith("..") || isAbsolute(rel)) { + return "/"; + } + return `/${rel.replace(/\\/g, "/")}`; +} + +async function resolveCurrentRootPath(): Promise { + const settings = await getAgentFilesystemSettings(); + if (!settings.localRootPath) { + throw new Error("No local filesystem root selected"); + } + return settings.localRootPath; +} + +export async function readAgentLocalFileText( + virtualPath: string +): Promise<{ path: string; content: string }> { + const rootPath = await resolveCurrentRootPath(); + const absolutePath = resolveVirtualPath(rootPath, virtualPath); + const content = await readFile(absolutePath, "utf8"); + return { + path: toVirtualPath(rootPath, absolutePath), + content, + }; +} + +export async function writeAgentLocalFileText( + virtualPath: string, + content: string +): Promise<{ path: string }> { + const rootPath = await resolveCurrentRootPath(); + const absolutePath = resolveVirtualPath(rootPath, virtualPath); + await mkdir(dirname(absolutePath), { recursive: true }); + await writeFile(absolutePath, content, "utf8"); + return { + path: toVirtualPath(rootPath, absolutePath), + }; +} diff --git a/surfsense_desktop/src/preload.ts b/surfsense_desktop/src/preload.ts index f75cc240e..9fc213bfa 100644 --- a/surfsense_desktop/src/preload.ts +++ b/surfsense_desktop/src/preload.ts @@ -71,6 +71,10 @@ contextBridge.exposeInMainWorld('electronAPI', { // Browse files via native dialog browseFiles: () => ipcRenderer.invoke(IPC_CHANNELS.BROWSE_FILES), readLocalFiles: (paths: string[]) => ipcRenderer.invoke(IPC_CHANNELS.READ_LOCAL_FILES, paths), + readAgentLocalFileText: (virtualPath: string) => + ipcRenderer.invoke(IPC_CHANNELS.READ_AGENT_LOCAL_FILE_TEXT, virtualPath), + writeAgentLocalFileText: (virtualPath: string, content: string) => + ipcRenderer.invoke(IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT, virtualPath, content), // Auth token sync across windows getAuthTokens: () => ipcRenderer.invoke(IPC_CHANNELS.GET_AUTH_TOKENS), diff --git a/surfsense_web/atoms/editor/editor-panel.atom.ts b/surfsense_web/atoms/editor/editor-panel.atom.ts index 7dc6add28..28563e7d3 100644 --- a/surfsense_web/atoms/editor/editor-panel.atom.ts +++ b/surfsense_web/atoms/editor/editor-panel.atom.ts @@ -3,14 +3,18 @@ import { rightPanelCollapsedAtom, rightPanelTabAtom } from "@/atoms/layout/right interface EditorPanelState { isOpen: boolean; + kind: "document" | "local_file"; documentId: number | null; + localFilePath: string | null; searchSpaceId: number | null; title: string | null; } const initialState: EditorPanelState = { isOpen: false, + kind: "document", documentId: null, + localFilePath: null, searchSpaceId: null, title: null, }; @@ -26,20 +30,38 @@ export const openEditorPanelAtom = atom( ( get, set, - { - documentId, - searchSpaceId, - title, - }: { documentId: number; searchSpaceId: number; title?: string } + payload: + | { documentId: number; searchSpaceId: number; title?: string; kind?: "document" } + | { + kind: "local_file"; + localFilePath: string; + title?: string; + searchSpaceId?: number; + } ) => { if (!get(editorPanelAtom).isOpen) { set(preEditorCollapsedAtom, get(rightPanelCollapsedAtom)); } + if (payload.kind === "local_file") { + set(editorPanelAtom, { + isOpen: true, + kind: "local_file", + documentId: null, + localFilePath: payload.localFilePath, + searchSpaceId: payload.searchSpaceId ?? null, + title: payload.title ?? null, + }); + set(rightPanelTabAtom, "editor"); + set(rightPanelCollapsedAtom, false); + return; + } set(editorPanelAtom, { isOpen: true, - documentId, - searchSpaceId, - title: title ?? null, + kind: "document", + documentId: payload.documentId, + localFilePath: null, + searchSpaceId: payload.searchSpaceId, + title: payload.title ?? null, }); set(rightPanelTabAtom, "editor"); set(rightPanelCollapsedAtom, false); diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index 9d0c8a9ed..a2ce30111 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -7,16 +7,20 @@ import { unstable_memoizeMarkdownComponents as memoizeMarkdownComponents, useIsMarkdownCodeBlock, } from "@assistant-ui/react-markdown"; +import { useSetAtom } from "jotai"; import { ExternalLinkIcon } from "lucide-react"; import dynamic from "next/dynamic"; +import { useParams } from "next/navigation"; import { useTheme } from "next-themes"; import { memo, type ReactNode } from "react"; import rehypeKatex from "rehype-katex"; import remarkGfm from "remark-gfm"; import remarkMath from "remark-math"; +import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/image"; import "katex/dist/katex.min.css"; import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation"; +import { useElectronAPI } from "@/hooks/use-platform"; import { Skeleton } from "@/components/ui/skeleton"; import { Table, @@ -222,6 +226,12 @@ function extractDomain(url: string): string { } } +const LOCAL_FILE_PATH_REGEX = /^\/(?:[^/\s`]+\/)*[^/\s`]+\.[^/\s`]+$/; + +function isVirtualFilePathToken(value: string): boolean { + return LOCAL_FILE_PATH_REGEX.test(value); +} + function MarkdownImage({ src, alt }: { src?: string; alt?: string }) { if (!src) return null; @@ -392,7 +402,43 @@ const defaultComponents = memoizeMarkdownComponents({ code: function Code({ className, children, ...props }) { const isCodeBlock = useIsMarkdownCodeBlock(); const { resolvedTheme } = useTheme(); + const openEditorPanel = useSetAtom(openEditorPanelAtom); + const params = useParams(); + const electronAPI = useElectronAPI(); if (!isCodeBlock) { + const inlineValue = String(children ?? "").trim(); + const isLocalPath = + !!electronAPI && isVirtualFilePathToken(inlineValue) && !inlineValue.startsWith("//"); + const displayLocalPath = inlineValue.replace(/^\/+/, ""); + const searchSpaceIdParam = params?.search_space_id; + const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam) + ? Number(searchSpaceIdParam[0]) + : Number(searchSpaceIdParam); + if (isLocalPath) { + return ( + + ); + } return ( void; }) { + const electronAPI = useElectronAPI(); const [editorDoc, setEditorDoc] = useState(null); const [isLoading, setIsLoading] = useState(true); const [error, setError] = useState(null); @@ -75,6 +81,7 @@ export function EditorPanelContent({ const initialLoadDone = useRef(false); const changeCountRef = useRef(0); const [displayTitle, setDisplayTitle] = useState(title || "Untitled"); + const isLocalFileMode = kind === "local_file"; const isLargeDocument = (editorDoc?.content_size_bytes ?? 0) > LARGE_DOCUMENT_THRESHOLD; @@ -88,13 +95,40 @@ export function EditorPanelContent({ changeCountRef.current = 0; const doFetch = async () => { - const token = getBearerToken(); - if (!token) { - redirectToLogin(); - return; - } - try { + if (isLocalFileMode) { + if (!localFilePath) { + throw new Error("Missing local file path"); + } + if (!electronAPI?.readAgentLocalFileText) { + throw new Error("Local file editor is available only in desktop mode."); + } + const readResult = await electronAPI.readAgentLocalFileText(localFilePath); + if (!readResult.ok) { + throw new Error(readResult.error || "Failed to read local file"); + } + const inferredTitle = localFilePath.split("/").pop() || localFilePath; + const content: EditorContent = { + document_id: -1, + title: inferredTitle, + document_type: "NOTE", + source_markdown: readResult.content, + }; + markdownRef.current = content.source_markdown; + setDisplayTitle(title || inferredTitle); + setEditorDoc(content); + initialLoadDone.current = true; + return; + } + if (!documentId || !searchSpaceId) { + throw new Error("Missing document context"); + } + const token = getBearerToken(); + if (!token) { + redirectToLogin(); + return; + } + const url = new URL( `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content` ); @@ -136,7 +170,7 @@ export function EditorPanelContent({ doFetch().catch(() => {}); return () => controller.abort(); - }, [documentId, searchSpaceId, title]); + }, [documentId, electronAPI, isLocalFileMode, localFilePath, searchSpaceId, title]); const handleMarkdownChange = useCallback((md: string) => { markdownRef.current = md; @@ -147,15 +181,38 @@ export function EditorPanelContent({ }, []); const handleSave = useCallback(async () => { - const token = getBearerToken(); - if (!token) { - toast.error("Please login to save"); - redirectToLogin(); - return; - } - setSaving(true); try { + if (isLocalFileMode) { + if (!localFilePath) { + throw new Error("Missing local file path"); + } + if (!electronAPI?.writeAgentLocalFileText) { + throw new Error("Local file editor is available only in desktop mode."); + } + const writeResult = await electronAPI.writeAgentLocalFileText( + localFilePath, + markdownRef.current + ); + if (!writeResult.ok) { + throw new Error(writeResult.error || "Failed to save local file"); + } + setEditorDoc((prev) => + prev ? { ...prev, source_markdown: markdownRef.current } : prev + ); + setEditedMarkdown(null); + toast.success("File saved"); + return; + } + if (!searchSpaceId || !documentId) { + throw new Error("Missing document context"); + } + const token = getBearerToken(); + if (!token) { + toast.error("Please login to save"); + redirectToLogin(); + return; + } const response = await authenticatedFetch( `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/save`, { @@ -181,10 +238,11 @@ export function EditorPanelContent({ } finally { setSaving(false); } - }, [documentId, searchSpaceId]); + }, [documentId, electronAPI, isLocalFileMode, localFilePath, searchSpaceId]); const isEditableType = editorDoc - ? EDITABLE_DOCUMENT_TYPES.has(editorDoc.document_type ?? "") && !isLargeDocument + ? (isLocalFileMode || EDITABLE_DOCUMENT_TYPES.has(editorDoc.document_type ?? "")) && + !isLargeDocument : false; return ( @@ -197,7 +255,7 @@ export function EditorPanelContent({ )}
- {editorDoc?.document_type && ( + {!isLocalFileMode && editorDoc?.document_type && documentId && ( )} {onClose && ( @@ -234,7 +292,7 @@ export function EditorPanelContent({

- ) : isLargeDocument ? ( + ) : isLargeDocument && !isLocalFileMode ? (
@@ -252,6 +310,9 @@ export function EditorPanelContent({ onClick={async () => { setDownloading(true); try { + if (!searchSpaceId || !documentId) { + throw new Error("Missing document context"); + } const response = await authenticatedFetch( `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/download-markdown`, { method: "GET" } @@ -289,7 +350,7 @@ export function EditorPanelContent({
) : isEditableType ? ( document.removeEventListener("keydown", handleKeyDown); }, [closePanel]); - if (!panelState.isOpen || !panelState.documentId || !panelState.searchSpaceId) return null; + const hasTarget = + panelState.kind === "document" + ? !!panelState.documentId && !!panelState.searchSpaceId + : !!panelState.localFilePath; + if (!panelState.isOpen || !hasTarget) return null; return (
@@ -342,7 +409,11 @@ function MobileEditorDrawer() { const panelState = useAtomValue(editorPanelAtom); const closePanel = useSetAtom(closeEditorPanelAtom); - if (!panelState.documentId || !panelState.searchSpaceId) return null; + const hasTarget = + panelState.kind === "document" + ? !!panelState.documentId && !!panelState.searchSpaceId + : !!panelState.localFilePath; + if (!hasTarget) return null; return ( {panelState.title || "Editor"}
@@ -373,8 +446,12 @@ function MobileEditorDrawer() { export function EditorPanel() { const panelState = useAtomValue(editorPanelAtom); const isDesktop = useMediaQuery("(min-width: 1024px)"); + const hasTarget = + panelState.kind === "document" + ? !!panelState.documentId && !!panelState.searchSpaceId + : !!panelState.localFilePath; - if (!panelState.isOpen || !panelState.documentId) return null; + if (!panelState.isOpen || !hasTarget) return null; if (isDesktop) { return ; @@ -386,8 +463,12 @@ export function EditorPanel() { export function MobileEditorPanel() { const panelState = useAtomValue(editorPanelAtom); const isDesktop = useMediaQuery("(min-width: 1024px)"); + const hasTarget = + panelState.kind === "document" + ? !!panelState.documentId && !!panelState.searchSpaceId + : !!panelState.localFilePath; - if (isDesktop || !panelState.isOpen || !panelState.documentId) return null; + if (isDesktop || !panelState.isOpen || !hasTarget) return null; return ; } diff --git a/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx b/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx index febae35d3..f6debed34 100644 --- a/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx +++ b/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx @@ -70,7 +70,11 @@ export function RightPanelExpandButton() { const editorState = useAtomValue(editorPanelAtom); const hitlEditState = useAtomValue(hitlEditPanelAtom); const reportOpen = reportState.isOpen && !!reportState.reportId; - const editorOpen = editorState.isOpen && !!editorState.documentId; + const editorOpen = + editorState.isOpen && + (editorState.kind === "document" + ? !!editorState.documentId + : !!editorState.localFilePath); const hitlEditOpen = hitlEditState.isOpen && !!hitlEditState.onSave; const hasContent = documentsOpen || reportOpen || editorOpen || hitlEditOpen; @@ -110,7 +114,11 @@ export function RightPanel({ documentsPanel }: RightPanelProps) { const documentsOpen = documentsPanel?.open ?? false; const reportOpen = reportState.isOpen && !!reportState.reportId; - const editorOpen = editorState.isOpen && !!editorState.documentId; + const editorOpen = + editorState.isOpen && + (editorState.kind === "document" + ? !!editorState.documentId + : !!editorState.localFilePath); const hitlEditOpen = hitlEditState.isOpen && !!hitlEditState.onSave; useEffect(() => { @@ -179,8 +187,10 @@ export function RightPanel({ documentsPanel }: RightPanelProps) { {effectiveTab === "editor" && editorOpen && (
diff --git a/surfsense_web/lib/agent-filesystem.ts b/surfsense_web/lib/agent-filesystem.ts new file mode 100644 index 000000000..6bfb5d131 --- /dev/null +++ b/surfsense_web/lib/agent-filesystem.ts @@ -0,0 +1,44 @@ +export type AgentFilesystemMode = "cloud" | "desktop_local_folder"; +export type ClientPlatform = "web" | "desktop"; + +export interface AgentFilesystemSelection { + filesystem_mode: AgentFilesystemMode; + client_platform: ClientPlatform; + local_filesystem_root?: string; +} + +const DEFAULT_SELECTION: AgentFilesystemSelection = { + filesystem_mode: "cloud", + client_platform: "web", +}; + +export function getClientPlatform(): ClientPlatform { + if (typeof window === "undefined") return "web"; + return window.electronAPI ? "desktop" : "web"; +} + +export async function getAgentFilesystemSelection(): Promise { + const platform = getClientPlatform(); + if (platform !== "desktop" || !window.electronAPI?.getAgentFilesystemSettings) { + return { ...DEFAULT_SELECTION, client_platform: platform }; + } + try { + const settings = await window.electronAPI.getAgentFilesystemSettings(); + if (settings.mode === "desktop_local_folder" && settings.localRootPath) { + return { + filesystem_mode: "desktop_local_folder", + client_platform: "desktop", + local_filesystem_root: settings.localRootPath, + }; + } + return { + filesystem_mode: "cloud", + client_platform: "desktop", + }; + } catch { + return { + filesystem_mode: "cloud", + client_platform: "desktop", + }; + } +} diff --git a/surfsense_web/types/window.d.ts b/surfsense_web/types/window.d.ts index 661c0f7d6..fe80ef8c0 100644 --- a/surfsense_web/types/window.d.ts +++ b/surfsense_web/types/window.d.ts @@ -49,6 +49,13 @@ interface AgentFilesystemSettings { updatedAt: string; } +interface LocalTextFileResult { + ok: boolean; + path: string; + content?: string; + error?: string; +} + interface ElectronAPI { versions: { electron: string; @@ -102,6 +109,11 @@ interface ElectronAPI { // Browse files/folders via native dialogs browseFiles: () => Promise; readLocalFiles: (paths: string[]) => Promise; + readAgentLocalFileText: (virtualPath: string) => Promise; + writeAgentLocalFileText: ( + virtualPath: string, + content: string + ) => Promise; // Auth token sync across windows getAuthTokens: () => Promise<{ bearer: string; refresh: string } | null>; setAuthTokens: (bearer: string, refresh: string) => Promise; From bbc1c76c0d75432a85ade3cc2654d2ff0027e414 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 18:00:51 +0530 Subject: [PATCH 120/299] feat(editor): integrate Monaco Editor for local file editing and enhance language inference --- .../components/editor-panel/editor-panel.tsx | 20 +++++++ .../components/editor/local-file-monaco.tsx | 56 +++++++++++++++++++ surfsense_web/lib/editor-language.ts | 34 +++++++++++ surfsense_web/package.json | 2 + surfsense_web/pnpm-lock.yaml | 54 ++++++++++++++++++ 5 files changed, 166 insertions(+) create mode 100644 surfsense_web/components/editor/local-file-monaco.tsx create mode 100644 surfsense_web/lib/editor-language.ts diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index f7829d0cb..081359719 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -7,6 +7,7 @@ import { useCallback, useEffect, useRef, useState } from "react"; import { toast } from "sonner"; import { closeEditorPanelAtom, editorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { VersionHistoryButton } from "@/components/documents/version-history"; +import { LocalFileMonaco } from "@/components/editor/local-file-monaco"; import { MarkdownViewer } from "@/components/markdown-viewer"; import { Alert, AlertDescription } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; @@ -14,6 +15,7 @@ import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/u import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI } from "@/hooks/use-platform"; import { authenticatedFetch, getBearerToken, redirectToLogin } from "@/lib/auth-utils"; +import { inferMonacoLanguageFromPath } from "@/lib/editor-language"; const PlateEditor = dynamic( () => import("@/components/editor/plate-editor").then((m) => ({ default: m.PlateEditor })), @@ -77,6 +79,7 @@ export function EditorPanelContent({ const [downloading, setDownloading] = useState(false); const [editedMarkdown, setEditedMarkdown] = useState(null); + const [localFileContent, setLocalFileContent] = useState(""); const markdownRef = useRef(""); const initialLoadDone = useRef(false); const changeCountRef = useRef(0); @@ -91,6 +94,7 @@ export function EditorPanelContent({ setError(null); setEditorDoc(null); setEditedMarkdown(null); + setLocalFileContent(""); initialLoadDone.current = false; changeCountRef.current = 0; @@ -115,6 +119,7 @@ export function EditorPanelContent({ source_markdown: readResult.content, }; markdownRef.current = content.source_markdown; + setLocalFileContent(content.source_markdown); setDisplayTitle(title || inferredTitle); setEditorDoc(content); initialLoadDone.current = true; @@ -244,6 +249,7 @@ export function EditorPanelContent({ ? (isLocalFileMode || EDITABLE_DOCUMENT_TYPES.has(editorDoc.document_type ?? "")) && !isLargeDocument : false; + const localFileLanguage = inferMonacoLanguageFromPath(localFilePath); return ( <> @@ -348,6 +354,20 @@ export function EditorPanelContent({
+ ) : isLocalFileMode ? ( +
+ { + markdownRef.current = next; + setLocalFileContent(next); + if (!initialLoadDone.current) return; + setEditedMarkdown(next === (editorDoc?.source_markdown ?? "") ? null : next); + }} + /> +
) : isEditableType ? ( import("@monaco-editor/react"), { + ssr: false, +}); + +interface LocalFileMonacoProps { + filePath: string; + language: string; + value: string; + onChange: (next: string) => void; +} + +export function LocalFileMonaco({ filePath, language, value, onChange }: LocalFileMonacoProps) { + const { resolvedTheme } = useTheme(); + + return ( +
+ onChange(next ?? "")} + options={{ + automaticLayout: true, + minimap: { enabled: false }, + lineNumbers: "on", + lineNumbersMinChars: 3, + lineDecorationsWidth: 12, + glyphMargin: false, + folding: true, + overviewRulerLanes: 0, + hideCursorInOverviewRuler: true, + scrollBeyondLastLine: false, + wordWrap: "off", + scrollbar: { + vertical: "hidden", + horizontal: "hidden", + alwaysConsumeMouseWheel: false, + }, + tabSize: 2, + insertSpaces: true, + fontSize: 12, + fontFamily: + "ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, Liberation Mono, monospace", + renderWhitespace: "selection", + smoothScrolling: true, + }} + /> +
+ ); +} diff --git a/surfsense_web/lib/editor-language.ts b/surfsense_web/lib/editor-language.ts new file mode 100644 index 000000000..17227c15d --- /dev/null +++ b/surfsense_web/lib/editor-language.ts @@ -0,0 +1,34 @@ +const EXTENSION_TO_MONACO_LANGUAGE: Record = { + css: "css", + csv: "plaintext", + cjs: "javascript", + html: "html", + htm: "html", + ini: "ini", + js: "javascript", + json: "json", + markdown: "markdown", + md: "markdown", + mjs: "javascript", + py: "python", + sql: "sql", + toml: "plaintext", + ts: "typescript", + tsx: "typescript", + xml: "xml", + yaml: "yaml", + yml: "yaml", +}; + +export function inferMonacoLanguageFromPath(filePath: string | null | undefined): string { + if (!filePath) return "plaintext"; + + const fileName = filePath.split("/").pop() ?? filePath; + const extensionIndex = fileName.lastIndexOf("."); + if (extensionIndex <= 0 || extensionIndex === fileName.length - 1) { + return "plaintext"; + } + + const extension = fileName.slice(extensionIndex + 1).toLowerCase(); + return EXTENSION_TO_MONACO_LANGUAGE[extension] ?? "plaintext"; +} diff --git a/surfsense_web/package.json b/surfsense_web/package.json index a98c21f83..41175daeb 100644 --- a/surfsense_web/package.json +++ b/surfsense_web/package.json @@ -28,6 +28,7 @@ "@babel/standalone": "^7.29.2", "@hookform/resolvers": "^5.2.2", "@marsidev/react-turnstile": "^1.5.0", + "@monaco-editor/react": "^4.7.0", "@number-flow/react": "^0.5.10", "@platejs/autoformat": "^52.0.11", "@platejs/basic-nodes": "^52.0.11", @@ -106,6 +107,7 @@ "lenis": "^1.3.17", "lowlight": "^3.3.0", "lucide-react": "^0.577.0", + "monaco-editor": "^0.55.1", "motion": "^12.23.22", "next": "^16.1.0", "next-intl": "^4.6.1", diff --git a/surfsense_web/pnpm-lock.yaml b/surfsense_web/pnpm-lock.yaml index 1c3dd61e0..b1730e842 100644 --- a/surfsense_web/pnpm-lock.yaml +++ b/surfsense_web/pnpm-lock.yaml @@ -29,6 +29,9 @@ importers: '@marsidev/react-turnstile': specifier: ^1.5.0 version: 1.5.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + '@monaco-editor/react': + specifier: ^4.7.0 + version: 4.7.0(monaco-editor@0.55.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@number-flow/react': specifier: ^0.5.10 version: 0.5.14(react-dom@19.2.4(react@19.2.4))(react@19.2.4) @@ -263,6 +266,9 @@ importers: lucide-react: specifier: ^0.577.0 version: 0.577.0(react@19.2.4) + monaco-editor: + specifier: ^0.55.1 + version: 0.55.1 motion: specifier: ^12.23.22 version: 12.34.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4) @@ -1980,6 +1986,16 @@ packages: peerDependencies: mediabunny: ^1.0.0 + '@monaco-editor/loader@1.7.0': + resolution: {integrity: sha512-gIwR1HrJrrx+vfyOhYmCZ0/JcWqG5kbfG7+d3f/C1LXk2EvzAbHSg3MQ5lO2sMlo9izoAZ04shohfKLVT6crVA==} + + '@monaco-editor/react@4.7.0': + resolution: {integrity: sha512-cyzXQCtO47ydzxpQtCGSQGOC8Gk3ZUeBXFAxD+CWXYFo5OqZyZUonFl0DwUlTyAfRHntBfw2p3w4s9R6oe1eCA==} + peerDependencies: + monaco-editor: '>= 0.25.0 < 1' + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + '@napi-rs/canvas-android-arm64@0.1.97': resolution: {integrity: sha512-V1c/WVw+NzH8vk7ZK/O8/nyBSCQimU8sfMsB/9qeSvdkGKNU7+mxy/bIF0gTgeBFmHpj30S4E9WHMSrxXGQuVQ==} engines: {node: '>= 10'} @@ -5368,6 +5384,9 @@ packages: resolution: {integrity: sha512-cgwlv/1iFQiFnU96XXgROh8xTeetsnJiDsTc7TYCLFd9+/WNkIqPTxiM/8pSd8VIrhXGTf1Ny1q1hquVqDJB5w==} engines: {node: '>= 4'} + dompurify@3.2.7: + resolution: {integrity: sha512-WhL/YuveyGXJaerVlMYGWhvQswa7myDG17P7Vu65EWC05o8vfeNbvNf4d/BOvH99+ZW+LlQsc1GDKMa1vNK6dw==} + dompurify@3.3.1: resolution: {integrity: sha512-qkdCKzLNtrgPFP1Vo+98FRzJnBRGe4ffyCea9IwHB1fyxPOeNTHpLKYGd4Uk9xvNoH0ZoOjwZxNptyMwqrId1Q==} @@ -6745,6 +6764,11 @@ packages: markdown-table@3.0.4: resolution: {integrity: sha512-wiYz4+JrLyb/DqW2hkFJxP7Vd7JuTDm77fvbM8VfEQdmSMqcImWeeRbHwZjBjIFki/VaMK2BhFi7oUUZeM5bqw==} + marked@14.0.0: + resolution: {integrity: sha512-uIj4+faQ+MgHgwUW1l2PsPglZLOLOT1uErt06dAPtx2kjteLAkbsd/0FiYg/MGS+i7ZKLb7w2WClxHkzOOuryQ==} + engines: {node: '>= 18'} + hasBin: true + marked@15.0.12: resolution: {integrity: sha512-8dD6FusOQSrpv9Z1rdNMdlSgQOIP880DHqnohobOmYLElGEqAL/JvxvuxZO16r4HtjTlfPRDC1hbvxC9dPN2nA==} engines: {node: '>= 18'} @@ -6965,6 +6989,9 @@ packages: module-details-from-path@1.0.4: resolution: {integrity: sha512-EGWKgxALGMgzvxYF1UyGTy0HXX/2vHLkw6+NvDKW2jypWbHpjQuj4UMcqQWXHERJhVGKikolT06G3bcKe4fi7w==} + monaco-editor@0.55.1: + resolution: {integrity: sha512-jz4x+TJNFHwHtwuV9vA9rMujcZRb0CEilTEwG2rRSpe/A7Jdkuj8xPKttCgOh+v/lkHy7HsZ64oj+q3xoAFl9A==} + motion-dom@12.34.3: resolution: {integrity: sha512-sYgFe+pR9aIM7o4fhs2aXtOI+oqlUd33N9Yoxcgo1Fv7M20sRkHtCmzE/VRNIcq7uNJ+qio+Xubt1FXH3pQ+eQ==} @@ -7943,6 +7970,9 @@ packages: stable-hash@0.0.5: resolution: {integrity: sha512-+L3ccpzibovGXFK+Ap/f8LOS0ahMrHTf3xu7mMLSpEGU0EO9ucaysSylKo9eRDFNhWve/y275iPmIZ4z39a9iA==} + state-local@1.0.7: + resolution: {integrity: sha512-HTEHMNieakEnoe33shBYcZ7NX83ACUjCu8c40iOGEZsngj9zRnkqS9j1pqQPXwobB0ZcVTk27REb7COQ0UR59w==} + stop-iteration-iterator@1.1.0: resolution: {integrity: sha512-eLoXW/DHyl62zxY4SCaIgnRhuMr6ri4juEYARS8E6sCEqzKpOiE521Ucofdx+KnDZl5xmvGYaaKCk5FEOxJCoQ==} engines: {node: '>= 0.4'} @@ -10050,6 +10080,17 @@ snapshots: dependencies: mediabunny: 1.39.2 + '@monaco-editor/loader@1.7.0': + dependencies: + state-local: 1.0.7 + + '@monaco-editor/react@4.7.0(monaco-editor@0.55.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': + dependencies: + '@monaco-editor/loader': 1.7.0 + monaco-editor: 0.55.1 + react: 19.2.4 + react-dom: 19.2.4(react@19.2.4) + '@napi-rs/canvas-android-arm64@0.1.97': optional: true @@ -13748,6 +13789,10 @@ snapshots: dependencies: domelementtype: 2.3.0 + dompurify@3.2.7: + optionalDependencies: + '@types/trusted-types': 2.0.7 + dompurify@3.3.1: optionalDependencies: '@types/trusted-types': 2.0.7 @@ -15327,6 +15372,8 @@ snapshots: markdown-table@3.0.4: {} + marked@14.0.0: {} + marked@15.0.12: {} marked@17.0.3: {} @@ -15822,6 +15869,11 @@ snapshots: module-details-from-path@1.0.4: {} + monaco-editor@0.55.1: + dependencies: + dompurify: 3.2.7 + marked: 14.0.0 + motion-dom@12.34.3: dependencies: motion-utils: 12.29.2 @@ -17073,6 +17125,8 @@ snapshots: stable-hash@0.0.5: {} + state-local@1.0.7: {} + stop-iteration-iterator@1.1.0: dependencies: es-errors: 1.3.0 From d397fec54fd829561967c524ad08638eed263531 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 18:21:50 +0530 Subject: [PATCH 121/299] feat(editor): add SourceCodeEditor component for enhanced code editing experience --- .../components/editor-panel/editor-panel.tsx | 13 +++++---- ...file-monaco.tsx => source-code-editor.tsx} | 28 +++++++++++++++---- 2 files changed, 30 insertions(+), 11 deletions(-) rename surfsense_web/components/editor/{local-file-monaco.tsx => source-code-editor.tsx} (69%) diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 081359719..137ece5e2 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -7,7 +7,7 @@ import { useCallback, useEffect, useRef, useState } from "react"; import { toast } from "sonner"; import { closeEditorPanelAtom, editorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { VersionHistoryButton } from "@/components/documents/version-history"; -import { LocalFileMonaco } from "@/components/editor/local-file-monaco"; +import { SourceCodeEditor } from "@/components/editor/source-code-editor"; import { MarkdownViewer } from "@/components/markdown-viewer"; import { Alert, AlertDescription } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; @@ -35,6 +35,7 @@ interface EditorContent { } const EDITABLE_DOCUMENT_TYPES = new Set(["FILE", "NOTE"]); +type EditorRenderMode = "rich_markdown" | "source_code"; function EditorPanelSkeleton() { return ( @@ -85,6 +86,7 @@ export function EditorPanelContent({ const changeCountRef = useRef(0); const [displayTitle, setDisplayTitle] = useState(title || "Untitled"); const isLocalFileMode = kind === "local_file"; + const editorRenderMode: EditorRenderMode = isLocalFileMode ? "source_code" : "rich_markdown"; const isLargeDocument = (editorDoc?.content_size_bytes ?? 0) > LARGE_DOCUMENT_THRESHOLD; @@ -246,7 +248,8 @@ export function EditorPanelContent({ }, [documentId, electronAPI, isLocalFileMode, localFilePath, searchSpaceId]); const isEditableType = editorDoc - ? (isLocalFileMode || EDITABLE_DOCUMENT_TYPES.has(editorDoc.document_type ?? "")) && + ? (editorRenderMode === "source_code" || + EDITABLE_DOCUMENT_TYPES.has(editorDoc.document_type ?? "")) && !isLargeDocument : false; const localFileLanguage = inferMonacoLanguageFromPath(localFilePath); @@ -354,10 +357,10 @@ export function EditorPanelContent({
- ) : isLocalFileMode ? ( + ) : editorRenderMode === "source_code" ? (
- { diff --git a/surfsense_web/components/editor/local-file-monaco.tsx b/surfsense_web/components/editor/source-code-editor.tsx similarity index 69% rename from surfsense_web/components/editor/local-file-monaco.tsx rename to surfsense_web/components/editor/source-code-editor.tsx index b27203341..7bb7bee35 100644 --- a/surfsense_web/components/editor/local-file-monaco.tsx +++ b/surfsense_web/components/editor/source-code-editor.tsx @@ -2,29 +2,44 @@ import dynamic from "next/dynamic"; import { useTheme } from "next-themes"; +import { Spinner } from "@/components/ui/spinner"; const MonacoEditor = dynamic(() => import("@monaco-editor/react"), { ssr: false, }); -interface LocalFileMonacoProps { - filePath: string; - language: string; +interface SourceCodeEditorProps { value: string; onChange: (next: string) => void; + path?: string; + language?: string; + readOnly?: boolean; + fontSize?: number; } -export function LocalFileMonaco({ filePath, language, value, onChange }: LocalFileMonacoProps) { +export function SourceCodeEditor({ + value, + onChange, + path, + language = "plaintext", + readOnly = false, + fontSize = 12, +}: SourceCodeEditorProps) { const { resolvedTheme } = useTheme(); return (
onChange(next ?? "")} + loading={ +
+ +
+ } options={{ automaticLayout: true, minimap: { enabled: false }, @@ -44,11 +59,12 @@ export function LocalFileMonaco({ filePath, language, value, onChange }: LocalFi }, tabSize: 2, insertSpaces: true, - fontSize: 12, + fontSize, fontFamily: "ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, Liberation Mono, monospace", renderWhitespace: "selection", smoothScrolling: true, + readOnly, }} />
From 3f203f8c49cace8010d88d9dcf812852196fae16 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 18:29:32 +0530 Subject: [PATCH 122/299] feat(editor): implement auto-save functionality and manual save command in SourceCodeEditor --- .../components/editor-panel/editor-panel.tsx | 12 ++-- .../components/editor/source-code-editor.tsx | 55 +++++++++++++++++++ .../layout/ui/right-panel/RightPanel.tsx | 2 +- 3 files changed, 63 insertions(+), 6 deletions(-) diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 137ece5e2..739428df3 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -187,7 +187,7 @@ export function EditorPanelContent({ setEditedMarkdown(md); }, []); - const handleSave = useCallback(async () => { + const handleSave = useCallback(async (options?: { silent?: boolean }) => { setSaving(true); try { if (isLocalFileMode) { @@ -197,18 +197,18 @@ export function EditorPanelContent({ if (!electronAPI?.writeAgentLocalFileText) { throw new Error("Local file editor is available only in desktop mode."); } + const contentToSave = markdownRef.current; const writeResult = await electronAPI.writeAgentLocalFileText( localFilePath, - markdownRef.current + contentToSave ); if (!writeResult.ok) { throw new Error(writeResult.error || "Failed to save local file"); } setEditorDoc((prev) => - prev ? { ...prev, source_markdown: markdownRef.current } : prev + prev ? { ...prev, source_markdown: contentToSave } : prev ); - setEditedMarkdown(null); - toast.success("File saved"); + setEditedMarkdown(markdownRef.current === contentToSave ? null : markdownRef.current); return; } if (!searchSpaceId || !documentId) { @@ -363,6 +363,8 @@ export function EditorPanelContent({ path={localFilePath ?? "local-file.txt"} language={localFileLanguage} value={localFileContent} + onSave={() => handleSave({ silent: true })} + saveMode="auto" onChange={(next) => { markdownRef.current = next; setLocalFileContent(next); diff --git a/surfsense_web/components/editor/source-code-editor.tsx b/surfsense_web/components/editor/source-code-editor.tsx index 7bb7bee35..bd3728721 100644 --- a/surfsense_web/components/editor/source-code-editor.tsx +++ b/surfsense_web/components/editor/source-code-editor.tsx @@ -1,6 +1,7 @@ "use client"; import dynamic from "next/dynamic"; +import { useEffect, useRef } from "react"; import { useTheme } from "next-themes"; import { Spinner } from "@/components/ui/spinner"; @@ -15,6 +16,9 @@ interface SourceCodeEditorProps { language?: string; readOnly?: boolean; fontSize?: number; + onSave?: () => Promise | void; + saveMode?: "manual" | "auto" | "both"; + autoSaveDelayMs?: number; } export function SourceCodeEditor({ @@ -24,8 +28,50 @@ export function SourceCodeEditor({ language = "plaintext", readOnly = false, fontSize = 12, + onSave, + saveMode = "manual", + autoSaveDelayMs = 800, }: SourceCodeEditorProps) { const { resolvedTheme } = useTheme(); + const saveTimerRef = useRef | null>(null); + const onSaveRef = useRef(onSave); + const skipNextAutoSaveRef = useRef(true); + + useEffect(() => { + onSaveRef.current = onSave; + }, [onSave]); + + useEffect(() => { + skipNextAutoSaveRef.current = true; + }, [path]); + + useEffect(() => { + if (readOnly || !onSaveRef.current) return; + if (saveMode !== "auto" && saveMode !== "both") return; + + if (skipNextAutoSaveRef.current) { + skipNextAutoSaveRef.current = false; + return; + } + + if (saveTimerRef.current) { + clearTimeout(saveTimerRef.current); + } + + saveTimerRef.current = setTimeout(() => { + void onSaveRef.current?.(); + saveTimerRef.current = null; + }, autoSaveDelayMs); + + return () => { + if (saveTimerRef.current) { + clearTimeout(saveTimerRef.current); + saveTimerRef.current = null; + } + }; + }, [autoSaveDelayMs, readOnly, saveMode, value]); + + const isManualSaveEnabled = !!onSave && !readOnly && (saveMode === "manual" || saveMode === "both"); return (
@@ -40,6 +86,12 @@ export function SourceCodeEditor({
} + onMount={(editor, monaco) => { + if (!isManualSaveEnabled) return; + editor.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.KeyS, () => { + void onSaveRef.current?.(); + }); + }} options={{ automaticLayout: true, minimap: { enabled: false }, @@ -51,6 +103,9 @@ export function SourceCodeEditor({ overviewRulerLanes: 0, hideCursorInOverviewRuler: true, scrollBeyondLastLine: false, + renderLineHighlight: "none", + selectionHighlight: false, + occurrencesHighlight: "off", wordWrap: "off", scrollbar: { vertical: "hidden", diff --git a/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx b/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx index f6debed34..2394480b2 100644 --- a/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx +++ b/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx @@ -53,7 +53,7 @@ function CollapseButton({ onClick }: { onClick: () => void }) { Collapse panel - Collapse panel + Collapse panel ); } From fe9ffa1413557ce61244135728030ce72eca99ab Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 18:39:35 +0530 Subject: [PATCH 123/299] refactor(editor): improve SourceCodeEditor styling and enhance scrollbar behavior --- .../components/editor-panel/editor-panel.tsx | 5 +---- .../components/editor/source-code-editor.tsx | 13 ++++++++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 739428df3..30dcdeb2c 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -256,12 +256,9 @@ export function EditorPanelContent({ return ( <> -
+

{displayTitle}

- {isEditableType && editedMarkdown !== null && ( -

Unsaved changes

- )}
{!isLocalFileMode && editorDoc?.document_type && documentId && ( diff --git a/surfsense_web/components/editor/source-code-editor.tsx b/surfsense_web/components/editor/source-code-editor.tsx index bd3728721..2c1f52989 100644 --- a/surfsense_web/components/editor/source-code-editor.tsx +++ b/surfsense_web/components/editor/source-code-editor.tsx @@ -74,7 +74,7 @@ export function SourceCodeEditor({ const isManualSaveEnabled = !!onSave && !readOnly && (saveMode === "manual" || saveMode === "both"); return ( -
+
Date: Thu, 23 Apr 2026 19:25:59 +0530 Subject: [PATCH 124/299] refactor(editor): remove auto-save functionality and simplify SourceCodeEditor props --- .../components/editor-panel/editor-panel.tsx | 168 ++++++++++++++++-- .../components/editor/source-code-editor.tsx | 84 +++++---- .../layout/ui/right-panel/RightPanel.tsx | 2 +- 3 files changed, 198 insertions(+), 56 deletions(-) diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 30dcdeb2c..b83c4b1d7 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -1,7 +1,17 @@ "use client"; import { useAtomValue, useSetAtom } from "jotai"; -import { Download, FileQuestionMark, FileText, Loader2, RefreshCw, XIcon } from "lucide-react"; +import { + Check, + Copy, + Download, + FileQuestionMark, + FileText, + Loader2, + Pencil, + RefreshCw, + XIcon, +} from "lucide-react"; import dynamic from "next/dynamic"; import { useCallback, useEffect, useRef, useState } from "react"; import { toast } from "sonner"; @@ -78,10 +88,13 @@ export function EditorPanelContent({ const [error, setError] = useState(null); const [saving, setSaving] = useState(false); const [downloading, setDownloading] = useState(false); + const [isSourceEditing, setIsSourceEditing] = useState(false); const [editedMarkdown, setEditedMarkdown] = useState(null); const [localFileContent, setLocalFileContent] = useState(""); + const [hasCopied, setHasCopied] = useState(false); const markdownRef = useRef(""); + const copyResetTimeoutRef = useRef | null>(null); const initialLoadDone = useRef(false); const changeCountRef = useRef(0); const [displayTitle, setDisplayTitle] = useState(title || "Untitled"); @@ -97,6 +110,8 @@ export function EditorPanelContent({ setEditorDoc(null); setEditedMarkdown(null); setLocalFileContent(""); + setHasCopied(false); + setIsSourceEditing(false); initialLoadDone.current = false; changeCountRef.current = 0; @@ -179,6 +194,14 @@ export function EditorPanelContent({ return () => controller.abort(); }, [documentId, electronAPI, isLocalFileMode, localFilePath, searchSpaceId, title]); + useEffect(() => { + return () => { + if (copyResetTimeoutRef.current) { + clearTimeout(copyResetTimeoutRef.current); + } + }; + }, []); + const handleMarkdownChange = useCallback((md: string) => { markdownRef.current = md; if (!initialLoadDone.current) return; @@ -187,6 +210,22 @@ export function EditorPanelContent({ setEditedMarkdown(md); }, []); + const handleCopy = useCallback(async () => { + try { + const textToCopy = markdownRef.current ?? editorDoc?.source_markdown ?? ""; + await navigator.clipboard.writeText(textToCopy); + setHasCopied(true); + if (copyResetTimeoutRef.current) { + clearTimeout(copyResetTimeoutRef.current); + } + copyResetTimeoutRef.current = setTimeout(() => { + setHasCopied(false); + }, 1400); + } catch (err) { + console.error("Error copying content:", err); + } + }, [editorDoc?.source_markdown]); + const handleSave = useCallback(async (options?: { silent?: boolean }) => { setSaving(true); try { @@ -209,7 +248,7 @@ export function EditorPanelContent({ prev ? { ...prev, source_markdown: contentToSave } : prev ); setEditedMarkdown(markdownRef.current === contentToSave ? null : markdownRef.current); - return; + return true; } if (!searchSpaceId || !documentId) { throw new Error("Missing document context"); @@ -239,9 +278,11 @@ export function EditorPanelContent({ setEditorDoc((prev) => (prev ? { ...prev, source_markdown: markdownRef.current } : prev)); setEditedMarkdown(null); toast.success("Document saved! Reindexing in background..."); + return true; } catch (err) { console.error("Error saving document:", err); toast.error(err instanceof Error ? err.message : "Failed to save document"); + return false; } finally { setSaving(false); } @@ -252,26 +293,111 @@ export function EditorPanelContent({ EDITABLE_DOCUMENT_TYPES.has(editorDoc.document_type ?? "")) && !isLargeDocument : false; + const hasUnsavedChanges = editedMarkdown !== null; + const showDesktopHeader = !!onClose; + const isSourceCodeMode = editorRenderMode === "source_code"; + const showEditingActions = isSourceCodeMode && isSourceEditing; const localFileLanguage = inferMonacoLanguageFromPath(localFilePath); return ( <> -
-
-

{displayTitle}

+ {showDesktopHeader ? ( +
+
+

File

+
+ +
+
+
+
+

{displayTitle}

+
+
+ {showEditingActions ? ( + <> + + + + ) : ( + <> + + {isSourceCodeMode && ( + + )} + + )} + {!showEditingActions && !isLocalFileMode && editorDoc?.document_type && documentId && ( + + )} +
+
-
- {!isLocalFileMode && editorDoc?.document_type && documentId && ( - - )} - {onClose && ( - - )} + ) : ( +
+
+

{displayTitle}

+
+
+ {!isLocalFileMode && editorDoc?.document_type && documentId && ( + + )} +
-
+ )}
{isLoading ? ( @@ -360,8 +486,10 @@ export function EditorPanelContent({ path={localFilePath ?? "local-file.txt"} language={localFileLanguage} value={localFileContent} - onSave={() => handleSave({ silent: true })} - saveMode="auto" + onSave={() => { + void handleSave({ silent: true }); + }} + readOnly={!isSourceEditing} onChange={(next) => { markdownRef.current = next; setLocalFileContent(next); @@ -379,7 +507,9 @@ export function EditorPanelContent({ readOnly={false} placeholder="Start writing..." editorVariant="default" - onSave={handleSave} + onSave={() => { + void handleSave(); + }} hasUnsavedChanges={editedMarkdown !== null} isSaving={saving} defaultEditing={true} diff --git a/surfsense_web/components/editor/source-code-editor.tsx b/surfsense_web/components/editor/source-code-editor.tsx index 2c1f52989..11f9266b6 100644 --- a/surfsense_web/components/editor/source-code-editor.tsx +++ b/surfsense_web/components/editor/source-code-editor.tsx @@ -17,8 +17,6 @@ interface SourceCodeEditorProps { readOnly?: boolean; fontSize?: number; onSave?: () => Promise | void; - saveMode?: "manual" | "auto" | "both"; - autoSaveDelayMs?: number; } export function SourceCodeEditor({ @@ -29,64 +27,78 @@ export function SourceCodeEditor({ readOnly = false, fontSize = 12, onSave, - saveMode = "manual", - autoSaveDelayMs = 800, }: SourceCodeEditorProps) { const { resolvedTheme } = useTheme(); - const saveTimerRef = useRef | null>(null); const onSaveRef = useRef(onSave); - const skipNextAutoSaveRef = useRef(true); + const monacoRef = useRef(null); useEffect(() => { onSaveRef.current = onSave; }, [onSave]); - useEffect(() => { - skipNextAutoSaveRef.current = true; - }, [path]); + const resolveCssColorToHex = (cssColorValue: string): string | null => { + if (typeof document === "undefined") return null; + const probe = document.createElement("div"); + probe.style.color = cssColorValue; + probe.style.position = "absolute"; + probe.style.pointerEvents = "none"; + probe.style.opacity = "0"; + document.body.appendChild(probe); + const computedColor = getComputedStyle(probe).color; + probe.remove(); + const match = computedColor.match(/rgba?\((\d+),\s*(\d+),\s*(\d+)/i); + if (!match) return null; + const toHex = (value: string) => Number(value).toString(16).padStart(2, "0"); + return `#${toHex(match[1])}${toHex(match[2])}${toHex(match[3])}`; + }; + + const applySidebarTheme = (monaco: any) => { + const isDark = resolvedTheme === "dark"; + const themeName = isDark ? "surfsense-dark" : "surfsense-light"; + const fallbackBg = isDark ? "#1e1e1e" : "#ffffff"; + const sidebarBgHex = resolveCssColorToHex("var(--sidebar)") ?? fallbackBg; + monaco.editor.defineTheme(themeName, { + base: isDark ? "vs-dark" : "vs", + inherit: true, + rules: [], + colors: { + "editor.background": sidebarBgHex, + "editorGutter.background": sidebarBgHex, + "minimap.background": sidebarBgHex, + "editorLineNumber.background": sidebarBgHex, + "editor.lineHighlightBackground": "#00000000", + }, + }); + monaco.editor.setTheme(themeName); + }; useEffect(() => { - if (readOnly || !onSaveRef.current) return; - if (saveMode !== "auto" && saveMode !== "both") return; + if (!monacoRef.current) return; + applySidebarTheme(monacoRef.current); + }, [resolvedTheme]); - if (skipNextAutoSaveRef.current) { - skipNextAutoSaveRef.current = false; - return; - } - - if (saveTimerRef.current) { - clearTimeout(saveTimerRef.current); - } - - saveTimerRef.current = setTimeout(() => { - void onSaveRef.current?.(); - saveTimerRef.current = null; - }, autoSaveDelayMs); - - return () => { - if (saveTimerRef.current) { - clearTimeout(saveTimerRef.current); - saveTimerRef.current = null; - } - }; - }, [autoSaveDelayMs, readOnly, saveMode, value]); - - const isManualSaveEnabled = !!onSave && !readOnly && (saveMode === "manual" || saveMode === "both"); + const isManualSaveEnabled = !!onSave && !readOnly; return ( -
+
onChange(next ?? "")} loading={
} + beforeMount={(monaco) => { + monacoRef.current = monaco; + applySidebarTheme(monaco); + }} onMount={(editor, monaco) => { + monacoRef.current = monaco; + applySidebarTheme(monaco); if (!isManualSaveEnabled) return; editor.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.KeyS, () => { void onSaveRef.current?.(); diff --git a/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx b/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx index 2394480b2..c2422bf34 100644 --- a/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx +++ b/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx @@ -94,7 +94,7 @@ export function RightPanelExpandButton() { Expand panel - Expand panel + Expand panel
); From 06b509213cf506c7ddd6f6eaabd2a648f8d5dbca Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 19:52:55 +0530 Subject: [PATCH 125/299] feat(editor): add mode toggle functionality and improve editor state management --- .../components/editor-panel/editor-panel.tsx | 123 +++++++++++++----- .../components/editor/plate-editor.tsx | 5 +- .../editor/plugins/fixed-toolbar-kit.tsx | 24 +++- 3 files changed, 116 insertions(+), 36 deletions(-) diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index b83c4b1d7..0170d13da 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -88,7 +88,7 @@ export function EditorPanelContent({ const [error, setError] = useState(null); const [saving, setSaving] = useState(false); const [downloading, setDownloading] = useState(false); - const [isSourceEditing, setIsSourceEditing] = useState(false); + const [isEditing, setIsEditing] = useState(false); const [editedMarkdown, setEditedMarkdown] = useState(null); const [localFileContent, setLocalFileContent] = useState(""); @@ -111,7 +111,7 @@ export function EditorPanelContent({ setEditedMarkdown(null); setLocalFileContent(""); setHasCopied(false); - setIsSourceEditing(false); + setIsEditing(false); initialLoadDone.current = false; changeCountRef.current = 0; @@ -295,10 +295,18 @@ export function EditorPanelContent({ : false; const hasUnsavedChanges = editedMarkdown !== null; const showDesktopHeader = !!onClose; - const isSourceCodeMode = editorRenderMode === "source_code"; - const showEditingActions = isSourceCodeMode && isSourceEditing; + const showEditingActions = isEditableType && isEditing; const localFileLanguage = inferMonacoLanguageFromPath(localFilePath); + const handleCancelEditing = useCallback(() => { + const savedContent = editorDoc?.source_markdown ?? ""; + markdownRef.current = savedContent; + setLocalFileContent(savedContent); + setEditedMarkdown(null); + changeCountRef.current = 0; + setIsEditing(false); + }, [editorDoc?.source_markdown]); + return ( <> {showDesktopHeader ? ( @@ -323,13 +331,7 @@ export function EditorPanelContent({ variant="ghost" size="sm" className="h-6 px-2 text-xs" - onClick={() => { - const savedContent = editorDoc?.source_markdown ?? ""; - markdownRef.current = savedContent; - setLocalFileContent(savedContent); - setEditedMarkdown(null); - setIsSourceEditing(false); - }} + onClick={handleCancelEditing} disabled={saving} > Cancel @@ -340,7 +342,7 @@ export function EditorPanelContent({ className="relative h-6 w-[56px] px-0 text-xs" onClick={async () => { const saveSucceeded = await handleSave({ silent: true }); - if (saveSucceeded) setIsSourceEditing(false); + if (saveSucceeded) setIsEditing(false); }} disabled={saving || !hasUnsavedChanges} > @@ -364,15 +366,19 @@ export function EditorPanelContent({ {hasCopied ? "Copied file contents" : "Copy file contents"} - {isSourceCodeMode && ( + {isEditableType && ( )} @@ -389,11 +395,69 @@ export function EditorPanelContent({

{displayTitle}

- {!isLocalFileMode && editorDoc?.document_type && documentId && ( - + {showEditingActions ? ( + <> + + + + ) : ( + <> + + {isEditableType && ( + + )} + {!isLocalFileMode && editorDoc?.document_type && documentId && ( + + )} + )}
@@ -489,7 +553,7 @@ export function EditorPanelContent({ onSave={() => { void handleSave({ silent: true }); }} - readOnly={!isSourceEditing} + readOnly={!isEditing} onChange={(next) => { markdownRef.current = next; setLocalFileContent(next); @@ -500,19 +564,15 @@ export function EditorPanelContent({
) : isEditableType ? ( { - void handleSave(); - }} - hasUnsavedChanges={editedMarkdown !== null} - isSaving={saving} - defaultEditing={true} + allowModeToggle={false} + defaultEditing={isEditing} className="[&_[role=toolbar]]:!bg-sidebar" /> ) : ( @@ -561,6 +621,8 @@ function MobileEditorDrawer() { const panelState = useAtomValue(editorPanelAtom); const closePanel = useSetAtom(closeEditorPanelAtom); + if (panelState.kind === "local_file") return null; + const hasTarget = panelState.kind === "document" ? !!panelState.documentId && !!panelState.searchSpaceId @@ -604,6 +666,7 @@ export function EditorPanel() { : !!panelState.localFilePath; if (!panelState.isOpen || !hasTarget) return null; + if (!isDesktop && panelState.kind === "local_file") return null; if (isDesktop) { return ; @@ -620,7 +683,7 @@ export function MobileEditorPanel() { ? !!panelState.documentId && !!panelState.searchSpaceId : !!panelState.localFilePath; - if (isDesktop || !panelState.isOpen || !hasTarget) return null; + if (isDesktop || !panelState.isOpen || !hasTarget || panelState.kind === "local_file") return null; return ; } diff --git a/surfsense_web/components/editor/plate-editor.tsx b/surfsense_web/components/editor/plate-editor.tsx index 61f84126c..371326bd3 100644 --- a/surfsense_web/components/editor/plate-editor.tsx +++ b/surfsense_web/components/editor/plate-editor.tsx @@ -42,6 +42,8 @@ export interface PlateEditorProps { hasUnsavedChanges?: boolean; /** Whether a save is in progress */ isSaving?: boolean; + /** Whether edit/view mode toggle UI should be available in toolbars. */ + allowModeToggle?: boolean; /** Start the editor in editing mode instead of viewing mode. Ignored when readOnly is true. */ defaultEditing?: boolean; /** @@ -91,6 +93,7 @@ export function PlateEditor({ onSave, hasUnsavedChanges = false, isSaving = false, + allowModeToggle = true, defaultEditing = false, preset = "full", extraPlugins = [], @@ -174,7 +177,7 @@ export function PlateEditor({ }, [html, markdown, editor]); // When not forced read-only, the user can toggle between editing/viewing. - const canToggleMode = !readOnly; + const canToggleMode = !readOnly && allowModeToggle; const contextProviderValue = useMemo( () => ({ diff --git a/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx b/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx index 85e0a08f2..8b776a456 100644 --- a/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx +++ b/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx @@ -1,19 +1,33 @@ "use client"; import { createPlatePlugin } from "platejs/react"; +import { useEditorReadOnly } from "platejs/react"; +import { useEditorSave } from "@/components/editor/editor-save-context"; import { FixedToolbar } from "@/components/ui/fixed-toolbar"; import { FixedToolbarButtons } from "@/components/ui/fixed-toolbar-buttons"; +function ConditionalFixedToolbar() { + const readOnly = useEditorReadOnly(); + const { onSave, hasUnsavedChanges, canToggleMode } = useEditorSave(); + + const hasVisibleControls = + !readOnly || canToggleMode || (!!onSave && hasUnsavedChanges && !readOnly); + + if (!hasVisibleControls) return null; + + return ( + + + + ); +} + export const FixedToolbarKit = [ createPlatePlugin({ key: "fixed-toolbar", render: { - beforeEditable: () => ( - - - - ), + beforeEditable: () => , }, }), ]; From 0381632bc2a199bf1a93b7970aedba390b262e30 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 20:03:18 +0530 Subject: [PATCH 126/299] refactor(editor): replace Loader2 with Spinner component and enhance save button visibility --- .../components/editor-panel/editor-panel.tsx | 12 +-- .../components/report-panel/report-panel.tsx | 76 +++++++++++++++++-- 2 files changed, 76 insertions(+), 12 deletions(-) diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 0170d13da..50ee158c4 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -7,7 +7,6 @@ import { Download, FileQuestionMark, FileText, - Loader2, Pencil, RefreshCw, XIcon, @@ -22,6 +21,7 @@ import { MarkdownViewer } from "@/components/markdown-viewer"; import { Alert, AlertDescription } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer"; +import { Spinner } from "@/components/ui/spinner"; import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI } from "@/hooks/use-platform"; import { authenticatedFetch, getBearerToken, redirectToLogin } from "@/lib/auth-utils"; @@ -346,8 +346,8 @@ export function EditorPanelContent({ }} disabled={saving || !hasUnsavedChanges} > - Save - {saving && } + Save + {saving && } ) : ( @@ -416,8 +416,8 @@ export function EditorPanelContent({ }} disabled={saving || !hasUnsavedChanges} > - Save - {saving && } + Save + {saving && } ) : ( @@ -534,7 +534,7 @@ export function EditorPanelContent({ }} > {downloading ? ( - + ) : ( )} diff --git a/surfsense_web/components/report-panel/report-panel.tsx b/surfsense_web/components/report-panel/report-panel.tsx index 591155757..709b10467 100644 --- a/surfsense_web/components/report-panel/report-panel.tsx +++ b/surfsense_web/components/report-panel/report-panel.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtomValue, useSetAtom } from "jotai"; -import { ChevronDownIcon, XIcon } from "lucide-react"; +import { ChevronDownIcon, Pencil, XIcon } from "lucide-react"; import dynamic from "next/dynamic"; import { useCallback, useEffect, useRef, useState } from "react"; import { toast } from "sonner"; @@ -125,6 +125,7 @@ export function ReportPanelContent({ // Editor state — tracks the latest markdown from the Plate editor const [editedMarkdown, setEditedMarkdown] = useState(null); + const [isEditing, setIsEditing] = useState(false); // Read-only when public (shareToken) OR shared (SEARCH_SPACE visibility) const currentThreadState = useAtomValue(currentThreadAtom); @@ -188,6 +189,7 @@ export function ReportPanelContent({ // Reset edited markdown when switching versions or reports useEffect(() => { setEditedMarkdown(null); + setIsEditing(false); }, [activeReportId]); // Copy markdown content (uses latest editor content) @@ -257,7 +259,7 @@ export function ReportPanelContent({ // Save edited report content const handleSave = useCallback(async () => { - if (!currentMarkdown || !activeReportId) return; + if (!currentMarkdown || !activeReportId) return false; setSaving(true); try { const response = await authenticatedFetch( @@ -278,9 +280,11 @@ export function ReportPanelContent({ setReportContent((prev) => (prev ? { ...prev, content: currentMarkdown } : prev)); setEditedMarkdown(null); toast.success("Report saved successfully"); + return true; } catch (err) { console.error("Error saving report:", err); toast.error(err instanceof Error ? err.message : "Failed to save report"); + return false; } finally { setSaving(false); } @@ -289,6 +293,14 @@ export function ReportPanelContent({ const activeVersionIndex = versions.findIndex((v) => v.id === activeReportId); const isPublic = !!shareToken; const btnBg = isPublic ? "bg-main-panel" : "bg-sidebar"; + const isResume = reportContent?.content_type === "typst"; + const showReportEditingTier = !isResume; + const hasUnsavedChanges = editedMarkdown !== null; + + const handleCancelEditing = useCallback(() => { + setEditedMarkdown(null); + setIsEditing(false); + }, []); return ( <> @@ -383,6 +395,58 @@ export function ReportPanelContent({ )}
+ {showReportEditingTier && ( +
+
+

+ {reportContent?.title || title} +

+
+
+ {!isReadOnly && + (isEditing ? ( + <> + + + + ) : ( + + ))} +
+
+ )} + {/* Report content — skeleton/error/viewer/editor shown only in this area */}
{isLoading ? ( @@ -406,15 +470,15 @@ export function ReportPanelContent({
) : ( ) From a1d3356bf55b8ebb4f70d37e2987f115609afde6 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 20:13:29 +0530 Subject: [PATCH 127/299] feat(editor): add reserveToolbarSpace option to enhance toolbar visibility management --- .../components/editor-panel/editor-panel.tsx | 1 + .../components/editor/editor-save-context.tsx | 3 +++ .../components/editor/plate-editor.tsx | 6 +++++- .../editor/plugins/fixed-toolbar-kit.tsx | 11 +++++++++-- .../components/report-panel/report-panel.tsx | 19 ++++++++++++++++++- 5 files changed, 36 insertions(+), 4 deletions(-) diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 50ee158c4..d125ec143 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -572,6 +572,7 @@ export function EditorPanelContent({ placeholder="Start writing..." editorVariant="default" allowModeToggle={false} + reserveToolbarSpace defaultEditing={isEditing} className="[&_[role=toolbar]]:!bg-sidebar" /> diff --git a/surfsense_web/components/editor/editor-save-context.tsx b/surfsense_web/components/editor/editor-save-context.tsx index d53a4adce..b4b3935a4 100644 --- a/surfsense_web/components/editor/editor-save-context.tsx +++ b/surfsense_web/components/editor/editor-save-context.tsx @@ -11,12 +11,15 @@ interface EditorSaveContextValue { isSaving: boolean; /** Whether the user can toggle between editing and viewing modes */ canToggleMode: boolean; + /** Whether fixed-toolbar space should be reserved even when controls are hidden */ + reserveToolbarSpace: boolean; } export const EditorSaveContext = createContext({ hasUnsavedChanges: false, isSaving: false, canToggleMode: false, + reserveToolbarSpace: false, }); export function useEditorSave() { diff --git a/surfsense_web/components/editor/plate-editor.tsx b/surfsense_web/components/editor/plate-editor.tsx index 371326bd3..481a420fb 100644 --- a/surfsense_web/components/editor/plate-editor.tsx +++ b/surfsense_web/components/editor/plate-editor.tsx @@ -44,6 +44,8 @@ export interface PlateEditorProps { isSaving?: boolean; /** Whether edit/view mode toggle UI should be available in toolbars. */ allowModeToggle?: boolean; + /** Reserve fixed-toolbar vertical space even when controls are hidden. */ + reserveToolbarSpace?: boolean; /** Start the editor in editing mode instead of viewing mode. Ignored when readOnly is true. */ defaultEditing?: boolean; /** @@ -94,6 +96,7 @@ export function PlateEditor({ hasUnsavedChanges = false, isSaving = false, allowModeToggle = true, + reserveToolbarSpace = false, defaultEditing = false, preset = "full", extraPlugins = [], @@ -185,8 +188,9 @@ export function PlateEditor({ hasUnsavedChanges, isSaving, canToggleMode, + reserveToolbarSpace, }), - [onSave, hasUnsavedChanges, isSaving, canToggleMode] + [onSave, hasUnsavedChanges, isSaving, canToggleMode, reserveToolbarSpace] ); return ( diff --git a/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx b/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx index 8b776a456..bdda0263d 100644 --- a/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx +++ b/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx @@ -9,12 +9,19 @@ import { FixedToolbarButtons } from "@/components/ui/fixed-toolbar-buttons"; function ConditionalFixedToolbar() { const readOnly = useEditorReadOnly(); - const { onSave, hasUnsavedChanges, canToggleMode } = useEditorSave(); + const { onSave, hasUnsavedChanges, canToggleMode, reserveToolbarSpace } = useEditorSave(); const hasVisibleControls = !readOnly || canToggleMode || (!!onSave && hasUnsavedChanges && !readOnly); - if (!hasVisibleControls) return null; + if (!hasVisibleControls) { + if (!reserveToolbarSpace) return null; + return ( + +
+ + ); + } return ( diff --git a/surfsense_web/components/report-panel/report-panel.tsx b/surfsense_web/components/report-panel/report-panel.tsx index 709b10467..0f6614ebf 100644 --- a/surfsense_web/components/report-panel/report-panel.tsx +++ b/surfsense_web/components/report-panel/report-panel.tsx @@ -116,6 +116,7 @@ export function ReportPanelContent({ const [exporting, setExporting] = useState(null); const [saving, setSaving] = useState(false); const copyTimerRef = useRef | undefined>(undefined); + const changeCountRef = useRef(0); useEffect(() => { return () => { @@ -190,8 +191,21 @@ export function ReportPanelContent({ useEffect(() => { setEditedMarkdown(null); setIsEditing(false); + changeCountRef.current = 0; }, [activeReportId]); + const handleReportMarkdownChange = useCallback( + (nextMarkdown: string) => { + if (!isEditing) return; + changeCountRef.current += 1; + // Plate may emit an initial normalize/serialize change on mount. + if (changeCountRef.current <= 1) return; + const savedMarkdown = reportContent?.content ?? ""; + setEditedMarkdown(nextMarkdown === savedMarkdown ? null : nextMarkdown); + }, + [isEditing, reportContent?.content] + ); + // Copy markdown content (uses latest editor content) const handleCopy = useCallback(async () => { if (!currentMarkdown) return; @@ -299,6 +313,7 @@ export function ReportPanelContent({ const handleCancelEditing = useCallback(() => { setEditedMarkdown(null); + changeCountRef.current = 0; setIsEditing(false); }, []); @@ -436,6 +451,7 @@ export function ReportPanelContent({ className="size-6" onClick={() => { setEditedMarkdown(null); + changeCountRef.current = 0; setIsEditing(true); }} > @@ -473,11 +489,12 @@ export function ReportPanelContent({ key={`report-${activeReportId}-${isEditing ? "editing" : "viewing"}`} preset="full" markdown={reportContent.content} - onMarkdownChange={setEditedMarkdown} + onMarkdownChange={handleReportMarkdownChange} readOnly={!isEditing} placeholder="Report content..." editorVariant="default" allowModeToggle={false} + reserveToolbarSpace defaultEditing={isEditing} className="[&_[role=toolbar]]:!bg-sidebar" /> From b5921bf1399559c31c3a10afc03ba61af49b5fbf Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 20:47:00 +0530 Subject: [PATCH 128/299] feat(markdown): enhance code block rendering for local web files and improve inline code styling --- .../components/assistant-ui/markdown-text.tsx | 23 +++++++++++++++++-- .../components/editor-panel/editor-panel.tsx | 2 +- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index a2ce30111..8f2184bd3 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -405,6 +405,14 @@ const defaultComponents = memoizeMarkdownComponents({ const openEditorPanel = useSetAtom(openEditorPanelAtom); const params = useParams(); const electronAPI = useElectronAPI(); + const language = /language-(\w+)/.exec(className || "")?.[1] ?? "text"; + const codeString = String(children).replace(/\n$/, ""); + const isWebLocalFileCodeBlock = + isCodeBlock && + !electronAPI && + isVirtualFilePathToken(codeString.trim()) && + !codeString.trim().startsWith("//") && + !codeString.includes("\n"); if (!isCodeBlock) { const inlineValue = String(children ?? "").trim(); const isLocalPath = @@ -451,8 +459,19 @@ const defaultComponents = memoizeMarkdownComponents({ ); } - const language = /language-(\w+)/.exec(className || "")?.[1] ?? "text"; - const codeString = String(children).replace(/\n$/, ""); + if (isWebLocalFileCodeBlock) { + return ( + + {codeString.trim()} + + ); + } return (
-

File

+

File

diff --git a/surfsense_web/components/layout/ui/sidebar/SidebarCollapseButton.tsx b/surfsense_web/components/layout/ui/sidebar/SidebarCollapseButton.tsx index a01937cd6..0eb409349 100644 --- a/surfsense_web/components/layout/ui/sidebar/SidebarCollapseButton.tsx +++ b/surfsense_web/components/layout/ui/sidebar/SidebarCollapseButton.tsx @@ -1,6 +1,6 @@ "use client"; -import { PanelLeft, PanelLeftClose } from "lucide-react"; +import { PanelLeft } from "lucide-react"; import { useTranslations } from "next-intl"; import { Button } from "@/components/ui/button"; import { ShortcutKbd } from "@/components/ui/shortcut-kbd"; @@ -23,7 +23,7 @@ export function SidebarCollapseButton({ const button = ( ); diff --git a/surfsense_web/components/report-panel/report-panel.tsx b/surfsense_web/components/report-panel/report-panel.tsx index 0f6614ebf..c7a8509ed 100644 --- a/surfsense_web/components/report-panel/report-panel.tsx +++ b/surfsense_web/components/report-panel/report-panel.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtomValue, useSetAtom } from "jotai"; -import { ChevronDownIcon, Pencil, XIcon } from "lucide-react"; +import { Check, ChevronDownIcon, Copy, Pencil, XIcon } from "lucide-react"; import dynamic from "next/dynamic"; import { useCallback, useEffect, useRef, useState } from "react"; import { toast } from "sonner"; @@ -306,7 +306,6 @@ export function ReportPanelContent({ const activeVersionIndex = versions.findIndex((v) => v.id === activeReportId); const isPublic = !!shareToken; - const btnBg = isPublic ? "bg-main-panel" : "bg-sidebar"; const isResume = reportContent?.content_type === "typst"; const showReportEditingTier = !isResume; const hasUnsavedChanges = editedMarkdown !== null; @@ -322,19 +321,6 @@ export function ReportPanelContent({ {/* Action bar — always visible; buttons are disabled while loading */}
- {/* Copy button — hidden for Typst (resume) */} - {reportContent?.content_type !== "typst" && ( - - )} - {/* Export — plain button for resume (typst), dropdown for others */} {reportContent?.content_type === "typst" ? ( @@ -353,7 +339,7 @@ export function ReportPanelContent({ variant="outline" size="sm" disabled={isLoading || !reportContent?.content} - className={`h-8 px-3.5 py-4 text-[15px] gap-1.5 ${btnBg} select-none`} + className={`h-8 px-3.5 py-4 text-[15px] gap-1.5 ${isPublic ? "bg-main-panel" : "bg-sidebar"} select-none`} > Export @@ -379,7 +365,7 @@ export function ReportPanelContent({
+ {!isEditing && ( + + )} {!isReadOnly && (isEditing ? ( <> From 84145566e3e7666a1b0d8cb514dbd70dbf44a948 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 22:27:58 +0530 Subject: [PATCH 130/299] feat(editor): implement local filesystem trust dialog and enhance filesystem mode selection --- .../components/assistant-ui/thread.tsx | 265 +++++++++++++++--- 1 file changed, 222 insertions(+), 43 deletions(-) diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 094d99a29..9df41ee55 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -12,11 +12,15 @@ import { AlertCircle, ArrowDownIcon, ArrowUpIcon, + Check, ChevronDown, ChevronUp, Clipboard, Dot, + Folder, + FolderPlus, Globe, + Laptop, Plus, Settings2, SquareIcon, @@ -66,6 +70,16 @@ import { } from "@/components/new-chat/document-mention-picker"; import { PromptPicker, type PromptPickerRef } from "@/components/new-chat/prompt-picker"; import { Avatar, AvatarFallback, AvatarGroup } from "@/components/ui/avatar"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; import { Button } from "@/components/ui/button"; import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer"; import { @@ -100,6 +114,8 @@ type ComposerFilesystemSettings = { updatedAt: string; }; +const LOCAL_FILESYSTEM_TRUST_KEY = "surfsense.local-filesystem-trust.v1"; + export const Thread: FC = () => { return ; }; @@ -371,6 +387,8 @@ const Composer: FC = () => { const [filesystemSettings, setFilesystemSettings] = useState( null ); + const [localTrustDialogOpen, setLocalTrustDialogOpen] = useState(false); + const [pendingLocalPath, setPendingLocalPath] = useState(null); const [clipboardInitialText, setClipboardInitialText] = useState(); const clipboardLoadedRef = useRef(false); useEffect(() => { @@ -388,7 +406,7 @@ const Composer: FC = () => { let mounted = true; electronAPI .getAgentFilesystemSettings() - .then((settings) => { + .then((settings: ComposerFilesystemSettings) => { if (!mounted) return; setFilesystemSettings(settings); }) @@ -405,22 +423,66 @@ const Composer: FC = () => { }; }, [electronAPI]); - const handleFilesystemModeChange = useCallback( - async (mode: "cloud" | "desktop_local_folder") => { + const hasLocalFilesystemTrust = useCallback(() => { + try { + return window.localStorage.getItem(LOCAL_FILESYSTEM_TRUST_KEY) === "true"; + } catch { + return false; + } + }, []); + + const applyLocalRootPath = useCallback( + async (path: string) => { if (!electronAPI?.setAgentFilesystemSettings) return; - const updated = await electronAPI.setAgentFilesystemSettings({ mode }); + const updated = await electronAPI.setAgentFilesystemSettings({ + mode: "desktop_local_folder", + localRootPath: path, + }); setFilesystemSettings(updated); }, [electronAPI] ); - const handlePickFilesystemRoot = useCallback(async () => { - if (!electronAPI?.pickAgentFilesystemRoot || !electronAPI?.setAgentFilesystemSettings) return; + const runSwitchToLocalMode = useCallback(async () => { + if (!electronAPI?.setAgentFilesystemSettings) return; + const updated = await electronAPI.setAgentFilesystemSettings({ mode: "desktop_local_folder" }); + setFilesystemSettings(updated); + }, [electronAPI]); + + const runPickLocalRoot = useCallback(async () => { + if (!electronAPI?.pickAgentFilesystemRoot) return; const picked = await electronAPI.pickAgentFilesystemRoot(); if (!picked) return; + await applyLocalRootPath(picked); + }, [applyLocalRootPath, electronAPI]); + + const handleFilesystemModeChange = useCallback( + async (mode: "cloud" | "desktop_local_folder") => { + if (!electronAPI?.setAgentFilesystemSettings) return; + if (mode === "desktop_local_folder") return void runSwitchToLocalMode(); + const updated = await electronAPI.setAgentFilesystemSettings({ mode }); + setFilesystemSettings(updated); + }, + [electronAPI, runSwitchToLocalMode] + ); + + const handlePickFilesystemRoot = useCallback(async () => { + if (hasLocalFilesystemTrust()) { + await runPickLocalRoot(); + return; + } + if (!electronAPI?.pickAgentFilesystemRoot) return; + const picked = await electronAPI.pickAgentFilesystemRoot(); + if (!picked) return; + setPendingLocalPath(picked); + setLocalTrustDialogOpen(true); + }, [electronAPI, hasLocalFilesystemTrust, runPickLocalRoot]); + + const handleClearFilesystemRoot = useCallback(async () => { + if (!electronAPI?.setAgentFilesystemSettings) return; const updated = await electronAPI.setAgentFilesystemSettings({ mode: "desktop_local_folder", - localRootPath: picked, + localRootPath: null, }); setFilesystemSettings(updated); }, [electronAPI]); @@ -720,44 +782,161 @@ const Composer: FC = () => { members={members ?? []} /> {electronAPI && filesystemSettings ? ( -
- - -
- +
+ + + + + + handleFilesystemModeChange("cloud")} + className="flex items-center justify-between" + > + + + Cloud + + {filesystemSettings.mode === "cloud" && } + + handleFilesystemModeChange("desktop_local_folder")} + className="flex items-center justify-between" + > + + + Local + + {filesystemSettings.mode === "desktop_local_folder" && ( + + )} + + + + + {filesystemSettings.mode === "desktop_local_folder" && ( + <> +
+
+ {filesystemSettings.localRootPath ? ( + <> +
+ + + {filesystemSettings.localRootPath.split("/").at(-1) || + filesystemSettings.localRootPath} + + +
+ + + ) : ( + + )} +
+ + )}
) : null} + { + setLocalTrustDialogOpen(open); + if (!open) { + setPendingLocalPath(null); + } + }} + > + + + Trust this workspace? + + Local mode can read and edit files inside the folders you select. Continue only if + you trust this workspace and its contents. + + {(pendingLocalPath || filesystemSettings?.localRootPath) && ( + + Folder path: {pendingLocalPath || filesystemSettings?.localRootPath} + + )} + + + Cancel + { + try { + window.localStorage.setItem(LOCAL_FILESYSTEM_TRUST_KEY, "true"); + } catch {} + setLocalTrustDialogOpen(false); + const path = pendingLocalPath; + setPendingLocalPath(null); + if (path) { + await applyLocalRootPath(path); + } else { + await runPickLocalRoot(); + } + }} + > + I trust this workspace + + + + {showDocumentPopover && (
Date: Thu, 23 Apr 2026 22:49:59 +0530 Subject: [PATCH 131/299] feat(settings): add DesktopShortcutsContent component for managing hotkeys and update user settings dialog --- .../components/DesktopContent.tsx | 81 +------------ .../components/DesktopShortcutsContent.tsx | 108 ++++++++++++++++++ surfsense_web/app/desktop/login/page.tsx | 2 +- .../settings/user-settings-dialog.tsx | 19 ++- 4 files changed, 127 insertions(+), 83 deletions(-) create mode 100644 surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx index 63ca9f5df..3ec14076d 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx @@ -1,9 +1,7 @@ "use client"; -import { BrainCog, Power, Rocket, Zap } from "lucide-react"; import { useEffect, useState } from "react"; import { toast } from "sonner"; -import { DEFAULT_SHORTCUTS, ShortcutRecorder } from "@/components/desktop/shortcut-recorder"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Label } from "@/components/ui/label"; import { @@ -24,9 +22,6 @@ export function DesktopContent() { const [loading, setLoading] = useState(true); const [enabled, setEnabled] = useState(true); - const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS); - const [shortcutsLoaded, setShortcutsLoaded] = useState(false); - const [searchSpaces, setSearchSpaces] = useState([]); const [activeSpaceId, setActiveSpaceId] = useState(null); @@ -37,7 +32,6 @@ export function DesktopContent() { useEffect(() => { if (!api) { setLoading(false); - setShortcutsLoaded(true); return; } @@ -48,15 +42,13 @@ export function DesktopContent() { Promise.all([ api.getAutocompleteEnabled(), - api.getShortcuts?.() ?? Promise.resolve(null), api.getActiveSearchSpace?.() ?? Promise.resolve(null), searchSpacesApiService.getSearchSpaces(), hasAutoLaunchApi ? api.getAutoLaunch() : Promise.resolve(null), ]) - .then(([autoEnabled, config, spaceId, spaces, autoLaunch]) => { + .then(([autoEnabled, spaceId, spaces, autoLaunch]) => { if (!mounted) return; setEnabled(autoEnabled); - if (config) setShortcuts(config); setActiveSpaceId(spaceId); if (spaces) setSearchSpaces(spaces); if (autoLaunch) { @@ -65,12 +57,10 @@ export function DesktopContent() { setAutoLaunchSupported(autoLaunch.supported); } setLoading(false); - setShortcutsLoaded(true); }) .catch(() => { if (!mounted) return; setLoading(false); - setShortcutsLoaded(true); }); return () => { @@ -101,24 +91,6 @@ export function DesktopContent() { await api.setAutocompleteEnabled(checked); }; - const updateShortcut = ( - key: "generalAssist" | "quickAsk" | "autocomplete", - accelerator: string - ) => { - setShortcuts((prev) => { - const updated = { ...prev, [key]: accelerator }; - api.setShortcuts?.({ [key]: accelerator }).catch(() => { - toast.error("Failed to update shortcut"); - }); - return updated; - }); - toast.success("Shortcut updated"); - }; - - const resetShortcut = (key: "generalAssist" | "quickAsk" | "autocomplete") => { - updateShortcut(key, DEFAULT_SHORTCUTS[key]); - }; - const handleAutoLaunchToggle = async (checked: boolean) => { if (!autoLaunchSupported || !api.setAutoLaunch) { toast.error("Please update the desktop app to configure launch on startup"); @@ -196,7 +168,6 @@ export function DesktopContent() { - Launch on Startup @@ -245,56 +216,6 @@ export function DesktopContent() { - {/* Keyboard Shortcuts */} - - - Keyboard Shortcuts - - Customize the global keyboard shortcuts for desktop features. - - - - {shortcutsLoaded ? ( -
- updateShortcut("generalAssist", accel)} - onReset={() => resetShortcut("generalAssist")} - defaultValue={DEFAULT_SHORTCUTS.generalAssist} - label="General Assist" - description="Launch SurfSense instantly from any application" - icon={Rocket} - /> - updateShortcut("quickAsk", accel)} - onReset={() => resetShortcut("quickAsk")} - defaultValue={DEFAULT_SHORTCUTS.quickAsk} - label="Quick Assist" - description="Select text anywhere, then ask AI to explain, rewrite, or act on it" - icon={Zap} - /> - updateShortcut("autocomplete", accel)} - onReset={() => resetShortcut("autocomplete")} - defaultValue={DEFAULT_SHORTCUTS.autocomplete} - label="Extreme Assist" - description="AI drafts text using your screen context and knowledge base" - icon={BrainCog} - /> -

- Click a shortcut and press a new key combination to change it. -

-
- ) : ( -
- -
- )} -
-
- {/* Extreme Assist Toggle */} diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx new file mode 100644 index 000000000..773665e63 --- /dev/null +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx @@ -0,0 +1,108 @@ +"use client"; + +import { BrainCog, Info, Rocket, Zap } from "lucide-react"; +import { useEffect, useState } from "react"; +import { toast } from "sonner"; +import { DEFAULT_SHORTCUTS, ShortcutRecorder } from "@/components/desktop/shortcut-recorder"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { Spinner } from "@/components/ui/spinner"; +import { useElectronAPI } from "@/hooks/use-platform"; + +export function DesktopShortcutsContent() { + const api = useElectronAPI(); + const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS); + const [shortcutsLoaded, setShortcutsLoaded] = useState(false); + + useEffect(() => { + if (!api) { + setShortcutsLoaded(true); + return; + } + + let mounted = true; + (api.getShortcuts?.() ?? Promise.resolve(null)) + .then((config) => { + if (!mounted) return; + if (config) setShortcuts(config); + setShortcutsLoaded(true); + }) + .catch(() => { + if (!mounted) return; + setShortcutsLoaded(true); + }); + + return () => { + mounted = false; + }; + }, [api]); + + if (!api) { + return ( +
+

Hotkeys are only available in the SurfSense desktop app.

+
+ ); + } + + const updateShortcut = ( + key: "generalAssist" | "quickAsk" | "autocomplete", + accelerator: string + ) => { + setShortcuts((prev) => { + const updated = { ...prev, [key]: accelerator }; + api.setShortcuts?.({ [key]: accelerator }).catch(() => { + toast.error("Failed to update shortcut"); + }); + return updated; + }); + toast.success("Shortcut updated"); + }; + + const resetShortcut = (key: "generalAssist" | "quickAsk" | "autocomplete") => { + updateShortcut(key, DEFAULT_SHORTCUTS[key]); + }; + + return ( + shortcutsLoaded ? ( +
+ + + +

Click a shortcut and press a new key combination to change it.

+
+
+ updateShortcut("generalAssist", accel)} + onReset={() => resetShortcut("generalAssist")} + defaultValue={DEFAULT_SHORTCUTS.generalAssist} + label="General Assist" + description="Launch SurfSense instantly from any application" + icon={Rocket} + /> + updateShortcut("quickAsk", accel)} + onReset={() => resetShortcut("quickAsk")} + defaultValue={DEFAULT_SHORTCUTS.quickAsk} + label="Quick Assist" + description="Select text anywhere, then ask AI to explain, rewrite, or act on it" + icon={Zap} + /> + updateShortcut("autocomplete", accel)} + onReset={() => resetShortcut("autocomplete")} + defaultValue={DEFAULT_SHORTCUTS.autocomplete} + label="Extreme Assist" + description="AI drafts text using your screen context and knowledge base" + icon={BrainCog} + /> +
+ ) : ( +
+ +
+ ) + ); +} diff --git a/surfsense_web/app/desktop/login/page.tsx b/surfsense_web/app/desktop/login/page.tsx index 8f68d20c1..1b43f89c0 100644 --- a/surfsense_web/app/desktop/login/page.tsx +++ b/surfsense_web/app/desktop/login/page.tsx @@ -152,7 +152,7 @@ export default function DesktopLoginPage() { {shortcutsLoaded ? (

- Keyboard Shortcuts + Hotkeys

+ import("@/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent").then( + (m) => ({ default: m.DesktopShortcutsContent }) + ), + { ssr: false } +); const MemoryContent = dynamic( () => import("@/app/dashboard/[search_space_id]/user-settings/components/MemoryContent").then( @@ -93,7 +100,14 @@ export function UserSettingsDialog() { icon: , }, ...(isDesktop - ? [{ value: "desktop", label: "Desktop", icon: }] + ? [ + { value: "desktop", label: "Desktop", icon: }, + { + value: "desktop-shortcuts", + label: "Hotkeys", + icon: , + }, + ] : []), ], [t, isDesktop] @@ -116,6 +130,7 @@ export function UserSettingsDialog() { {state.initialTab === "memory" && } {state.initialTab === "purchases" && } {state.initialTab === "desktop" && } + {state.initialTab === "desktop-shortcuts" && }
); From 46056ee514cdd29e21dfac0b69eeaf63ce266b9a Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 23 Apr 2026 23:52:49 +0530 Subject: [PATCH 132/299] fix(settings): update user settings dialog labels and enhance DesktopShortcutsContent component for better hotkey management --- .../components/DesktopContent.tsx | 2 +- .../components/DesktopShortcutsContent.tsx | 194 ++++++++++++++---- .../settings/user-settings-dialog.tsx | 6 +- 3 files changed, 161 insertions(+), 41 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx index 3ec14076d..9861f5536 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx @@ -72,7 +72,7 @@ export function DesktopContent() { return (

- Desktop settings are only available in the SurfSense desktop app. + App preferences are only available in the SurfSense desktop app.

); diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx index 773665e63..f4981b8f0 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx @@ -1,17 +1,152 @@ "use client"; -import { BrainCog, Info, Rocket, Zap } from "lucide-react"; -import { useEffect, useState } from "react"; +import { ArrowBigUp, BrainCog, Command, Option, Rocket, RotateCcw, Zap } from "lucide-react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; -import { DEFAULT_SHORTCUTS, ShortcutRecorder } from "@/components/desktop/shortcut-recorder"; -import { Alert, AlertDescription } from "@/components/ui/alert"; +import { DEFAULT_SHORTCUTS, keyEventToAccelerator } from "@/components/desktop/shortcut-recorder"; +import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; import { useElectronAPI } from "@/hooks/use-platform"; +type ShortcutKey = "generalAssist" | "quickAsk" | "autocomplete"; +type ShortcutMap = typeof DEFAULT_SHORTCUTS; + +const HOTKEY_ROWS: Array<{ key: ShortcutKey; label: string; icon: React.ElementType }> = [ + { key: "generalAssist", label: "General Assist", icon: Rocket }, + { key: "quickAsk", label: "Quick Assist", icon: Zap }, + { key: "autocomplete", label: "Extreme Assist", icon: BrainCog }, +]; + +type ShortcutToken = + | { kind: "text"; value: string } + | { kind: "icon"; value: "command" | "option" | "shift" }; + +function acceleratorToTokens(accel: string, isMac: boolean): ShortcutToken[] { + if (!accel) return []; + return accel.split("+").map((part) => { + if (part === "CommandOrControl") { + return isMac ? { kind: "icon", value: "command" as const } : { kind: "text", value: "Ctrl" }; + } + if (part === "Alt") { + return isMac ? { kind: "icon", value: "option" as const } : { kind: "text", value: "Alt" }; + } + if (part === "Shift") { + return isMac ? { kind: "icon", value: "shift" as const } : { kind: "text", value: "Shift" }; + } + if (part === "Space") return { kind: "text", value: "Space" }; + return { kind: "text", value: part.length === 1 ? part.toUpperCase() : part }; + }); +} + +function HotkeyRow({ + label, + value, + defaultValue, + icon: Icon, + isMac, + onChange, + onReset, +}: { + label: string; + value: string; + defaultValue: string; + icon: React.ElementType; + isMac: boolean; + onChange: (accelerator: string) => void; + onReset: () => void; +}) { + const [recording, setRecording] = useState(false); + const inputRef = useRef(null); + const isDefault = value === defaultValue; + const displayTokens = useMemo(() => acceleratorToTokens(value, isMac), [value, isMac]); + + const handleKeyDown = useCallback( + (e: React.KeyboardEvent) => { + if (!recording) return; + e.preventDefault(); + e.stopPropagation(); + + if (e.key === "Escape") { + setRecording(false); + return; + } + + const accel = keyEventToAccelerator(e); + if (accel) { + onChange(accel); + setRecording(false); + } + }, + [onChange, recording] + ); + + return ( +
+
+
+ +
+

{label}

+
+
+ {!isDefault && ( + + )} + +
+
+ ); +} + export function DesktopShortcutsContent() { const api = useElectronAPI(); const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS); const [shortcutsLoaded, setShortcutsLoaded] = useState(false); + const isMac = api?.versions?.platform === "darwin"; useEffect(() => { if (!api) { @@ -21,7 +156,7 @@ export function DesktopShortcutsContent() { let mounted = true; (api.getShortcuts?.() ?? Promise.resolve(null)) - .then((config) => { + .then((config: ShortcutMap | null) => { if (!mounted) return; if (config) setShortcuts(config); setShortcutsLoaded(true); @@ -58,46 +193,27 @@ export function DesktopShortcutsContent() { toast.success("Shortcut updated"); }; - const resetShortcut = (key: "generalAssist" | "quickAsk" | "autocomplete") => { + const resetShortcut = (key: ShortcutKey) => { updateShortcut(key, DEFAULT_SHORTCUTS[key]); }; return ( shortcutsLoaded ? (
- - - -

Click a shortcut and press a new key combination to change it.

-
-
- updateShortcut("generalAssist", accel)} - onReset={() => resetShortcut("generalAssist")} - defaultValue={DEFAULT_SHORTCUTS.generalAssist} - label="General Assist" - description="Launch SurfSense instantly from any application" - icon={Rocket} - /> - updateShortcut("quickAsk", accel)} - onReset={() => resetShortcut("quickAsk")} - defaultValue={DEFAULT_SHORTCUTS.quickAsk} - label="Quick Assist" - description="Select text anywhere, then ask AI to explain, rewrite, or act on it" - icon={Zap} - /> - updateShortcut("autocomplete", accel)} - onReset={() => resetShortcut("autocomplete")} - defaultValue={DEFAULT_SHORTCUTS.autocomplete} - label="Extreme Assist" - description="AI drafts text using your screen context and knowledge base" - icon={BrainCog} - /> +
+ {HOTKEY_ROWS.map((row) => ( + updateShortcut(row.key, accel)} + onReset={() => resetShortcut(row.key)} + /> + ))} +
) : (
diff --git a/surfsense_web/components/settings/user-settings-dialog.tsx b/surfsense_web/components/settings/user-settings-dialog.tsx index a406f6352..cc36392ae 100644 --- a/surfsense_web/components/settings/user-settings-dialog.tsx +++ b/surfsense_web/components/settings/user-settings-dialog.tsx @@ -101,7 +101,11 @@ export function UserSettingsDialog() { }, ...(isDesktop ? [ - { value: "desktop", label: "Desktop", icon: }, + { + value: "desktop", + label: "App Preferences", + icon: , + }, { value: "desktop-shortcuts", label: "Hotkeys", From daac6b52691844edb530a47548606aeb2238f64e Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 24 Apr 2026 00:06:38 +0530 Subject: [PATCH 133/299] feat(login): implement customizable hotkey management in the login page with enhanced UI components --- surfsense_web/app/desktop/login/page.tsx | 241 +++++++++++++++++------ 1 file changed, 180 insertions(+), 61 deletions(-) diff --git a/surfsense_web/app/desktop/login/page.tsx b/surfsense_web/app/desktop/login/page.tsx index 1b43f89c0..6d5e2abd4 100644 --- a/surfsense_web/app/desktop/login/page.tsx +++ b/surfsense_web/app/desktop/login/page.tsx @@ -2,13 +2,13 @@ import { IconBrandGoogleFilled } from "@tabler/icons-react"; import { useAtom } from "jotai"; -import { BrainCog, Eye, EyeOff, Rocket, Zap } from "lucide-react"; +import { ArrowBigUp, BrainCog, Command, Eye, EyeOff, Option, Rocket, RotateCcw, Zap } from "lucide-react"; import Image from "next/image"; import { useRouter } from "next/navigation"; -import { useCallback, useEffect, useState } from "react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { loginMutationAtom } from "@/atoms/auth/auth-mutation.atoms"; -import { DEFAULT_SHORTCUTS, ShortcutRecorder } from "@/components/desktop/shortcut-recorder"; +import { DEFAULT_SHORTCUTS, keyEventToAccelerator } from "@/components/desktop/shortcut-recorder"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; @@ -20,6 +20,157 @@ import { setBearerToken } from "@/lib/auth-utils"; import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config"; const isGoogleAuth = AUTH_TYPE === "GOOGLE"; +type ShortcutKey = "generalAssist" | "quickAsk" | "autocomplete"; +type ShortcutMap = typeof DEFAULT_SHORTCUTS; + +type ShortcutToken = + | { kind: "text"; value: string } + | { kind: "icon"; value: "command" | "option" | "shift" }; + +const HOTKEY_ROWS: Array<{ key: ShortcutKey; label: string; description: string; icon: React.ElementType }> = [ + { + key: "generalAssist", + label: "General Assist", + description: "Launch SurfSense instantly from any application", + icon: Rocket, + }, + { + key: "quickAsk", + label: "Quick Assist", + description: "Select text anywhere, then ask AI to explain, rewrite, or act on it", + icon: Zap, + }, + { + key: "autocomplete", + label: "Extreme Assist", + description: "AI drafts text using your screen context and knowledge base", + icon: BrainCog, + }, +]; + +function acceleratorToTokens(accel: string, isMac: boolean): ShortcutToken[] { + if (!accel) return []; + return accel.split("+").map((part) => { + if (part === "CommandOrControl") { + return isMac ? { kind: "icon", value: "command" as const } : { kind: "text", value: "Ctrl" }; + } + if (part === "Alt") { + return isMac ? { kind: "icon", value: "option" as const } : { kind: "text", value: "Alt" }; + } + if (part === "Shift") { + return isMac ? { kind: "icon", value: "shift" as const } : { kind: "text", value: "Shift" }; + } + if (part === "Space") return { kind: "text", value: "Space" }; + return { kind: "text", value: part.length === 1 ? part.toUpperCase() : part }; + }); +} + +function HotkeyRow({ + label, + description, + value, + defaultValue, + icon: Icon, + isMac, + onChange, + onReset, +}: { + label: string; + description: string; + value: string; + defaultValue: string; + icon: React.ElementType; + isMac: boolean; + onChange: (accelerator: string) => void; + onReset: () => void; +}) { + const [recording, setRecording] = useState(false); + const inputRef = useRef(null); + const isDefault = value === defaultValue; + const displayTokens = useMemo(() => acceleratorToTokens(value, isMac), [value, isMac]); + + const handleKeyDown = useCallback( + (e: React.KeyboardEvent) => { + if (!recording) return; + e.preventDefault(); + e.stopPropagation(); + + if (e.key === "Escape") { + setRecording(false); + return; + } + + const accel = keyEventToAccelerator(e); + if (accel) { + onChange(accel); + setRecording(false); + } + }, + [onChange, recording] + ); + + return ( +
+
+
+ +
+
+

{label}

+

{description}

+
+
+
+ {!isDefault && ( + + )} + +
+
+ ); +} export default function DesktopLoginPage() { const router = useRouter(); @@ -33,6 +184,7 @@ export default function DesktopLoginPage() { const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS); const [shortcutsLoaded, setShortcutsLoaded] = useState(false); + const isMac = api?.versions?.platform === "darwin"; useEffect(() => { if (!api?.getShortcuts) { @@ -41,7 +193,7 @@ export default function DesktopLoginPage() { } api .getShortcuts() - .then((config) => { + .then((config: ShortcutMap | null) => { if (config) setShortcuts(config); setShortcutsLoaded(true); }) @@ -117,18 +269,8 @@ export default function DesktopLoginPage() { }; return ( -
- {/* Subtle radial glow */} -
-
-
- -
+
+
{/* Header */}

Welcome to SurfSense Desktop

- Configure shortcuts, then sign in to get started. + Configure shortcuts, then sign in to get started

@@ -151,41 +293,24 @@ export default function DesktopLoginPage() { {/* ---- Shortcuts ---- */} {shortcutsLoaded ? (
-

+ {/*

Hotkeys -

-
- updateShortcut("generalAssist", accel)} - onReset={() => resetShortcut("generalAssist")} - defaultValue={DEFAULT_SHORTCUTS.generalAssist} - label="General Assist" - description="Launch SurfSense instantly from any application" - icon={Rocket} - /> - updateShortcut("quickAsk", accel)} - onReset={() => resetShortcut("quickAsk")} - defaultValue={DEFAULT_SHORTCUTS.quickAsk} - label="Quick Assist" - description="Select text anywhere, then ask AI to explain, rewrite, or act on it" - icon={Zap} - /> - updateShortcut("autocomplete", accel)} - onReset={() => resetShortcut("autocomplete")} - defaultValue={DEFAULT_SHORTCUTS.autocomplete} - label="Extreme Assist" - description="AI drafts text using your screen context and knowledge base" - icon={BrainCog} - /> +

*/} +
+ {HOTKEY_ROWS.map((row) => ( + updateShortcut(row.key, accel)} + onReset={() => resetShortcut(row.key)} + /> + ))}
-

- Click a shortcut and press a new key combination to change it. -

) : (
@@ -197,9 +322,9 @@ export default function DesktopLoginPage() { {/* ---- Auth ---- */}
-

+ {/*

Sign In -

+

*/} {isGoogleAuth ? (
- )} From 6721919398241bbc2696b19ea526915d95807f50 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 24 Apr 2026 01:44:23 +0530 Subject: [PATCH 134/299] feat(filesystem): add multi-root local folder support in backend --- .../agents/new_chat/filesystem_backends.py | 21 +- .../agents/new_chat/filesystem_selection.py | 2 +- .../agents/new_chat/middleware/filesystem.py | 45 ++- .../multi_root_local_folder_backend.py | 328 ++++++++++++++++++ .../app/routes/new_chat_routes.py | 24 +- surfsense_backend/app/schemas/new_chat.py | 6 +- .../middleware/test_filesystem_backends.py | 26 +- 7 files changed, 422 insertions(+), 30 deletions(-) create mode 100644 surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py diff --git a/surfsense_backend/app/agents/new_chat/filesystem_backends.py b/surfsense_backend/app/agents/new_chat/filesystem_backends.py index 8af7e8558..0c32ef845 100644 --- a/surfsense_backend/app/agents/new_chat/filesystem_backends.py +++ b/surfsense_backend/app/agents/new_chat/filesystem_backends.py @@ -9,26 +9,27 @@ from deepagents.backends.state import StateBackend from langgraph.prebuilt.tool_node import ToolRuntime from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection -from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend +from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( + MultiRootLocalFolderBackend, +) @lru_cache(maxsize=64) -def _cached_local_backend(root_path: str) -> LocalFolderBackend: - return LocalFolderBackend(root_path) +def _cached_multi_root_backend( + root_paths: tuple[str, ...], +) -> MultiRootLocalFolderBackend: + return MultiRootLocalFolderBackend(root_paths) def build_backend_resolver( selection: FilesystemSelection, -) -> Callable[[ToolRuntime], StateBackend | LocalFolderBackend]: +) -> Callable[[ToolRuntime], StateBackend | MultiRootLocalFolderBackend]: """Create deepagents backend resolver for the selected filesystem mode.""" - if ( - selection.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER - and selection.local_root_path is not None - ): + if selection.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER and selection.local_root_paths: - def _resolve_local(_runtime: ToolRuntime) -> LocalFolderBackend: - return _cached_local_backend(selection.local_root_path or "") + def _resolve_local(_runtime: ToolRuntime) -> MultiRootLocalFolderBackend: + return _cached_multi_root_backend(selection.local_root_paths) return _resolve_local diff --git a/surfsense_backend/app/agents/new_chat/filesystem_selection.py b/surfsense_backend/app/agents/new_chat/filesystem_selection.py index 3094a0b29..4b8f42847 100644 --- a/surfsense_backend/app/agents/new_chat/filesystem_selection.py +++ b/surfsense_backend/app/agents/new_chat/filesystem_selection.py @@ -26,7 +26,7 @@ class FilesystemSelection: mode: FilesystemMode = FilesystemMode.CLOUD client_platform: ClientPlatform = ClientPlatform.WEB - local_root_path: str | None = None + local_root_paths: tuple[str, ...] = () @property def is_local_mode(self) -> bool: diff --git a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py index 0fa2085fc..6c30b20ef 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py +++ b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py @@ -26,13 +26,16 @@ from langchain_core.tools import BaseTool, StructuredTool from langgraph.types import Command from sqlalchemy import delete, select +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( + MultiRootLocalFolderBackend, +) from app.agents.new_chat.sandbox import ( _evict_sandbox_cache, delete_sandbox, get_or_create_sandbox, is_sandbox_enabled, ) -from app.agents.new_chat.filesystem_selection import FilesystemMode from app.db import Chunk, Document, DocumentType, Folder, shielded_async_session from app.indexing_pipeline.document_chunker import chunk_text from app.utils.document_converters import ( @@ -222,6 +225,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): "\n\n## Local Folder Mode" "\n\nThis chat is running in desktop local-folder mode." " Keep all file operations local. Do not use save_document." + " Always use mount-prefixed absolute paths like //file.ext." + " If you are unsure which mounts are available, call ls('/') first." ) super().__init__( @@ -771,12 +776,30 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): """Only cloud mode persists file content to Document/Chunk tables.""" return self._filesystem_mode == FilesystemMode.CLOUD - @staticmethod - def _get_contract_suggested_path(runtime: ToolRuntime[None, FilesystemState]) -> str: + def _default_mount_prefix(self, runtime: ToolRuntime[None, FilesystemState]) -> str: + backend = self._get_backend(runtime) + if isinstance(backend, MultiRootLocalFolderBackend): + return f"/{backend.default_mount()}" + return "" + + def _get_contract_suggested_path( + self, runtime: ToolRuntime[None, FilesystemState] + ) -> str: contract = runtime.state.get("file_operation_contract") or {} suggested = contract.get("suggested_path") if isinstance(suggested, str) and suggested.strip(): - return suggested.strip() + cleaned = suggested.strip() + if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: + mount_prefix = self._default_mount_prefix(runtime) + if mount_prefix and cleaned.startswith("/") and not cleaned.startswith( + f"{mount_prefix}/" + ): + return f"{mount_prefix}{cleaned}" + return cleaned + if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: + mount_prefix = self._default_mount_prefix(runtime) + if mount_prefix: + return f"{mount_prefix}/notes.md" return "/notes.md" def _resolve_write_target_path( @@ -787,6 +810,20 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): candidate = file_path.strip() if not candidate: return self._get_contract_suggested_path(runtime) + if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: + backend = self._get_backend(runtime) + mount_prefix = self._default_mount_prefix(runtime) + if mount_prefix and not candidate.startswith("/"): + return f"{mount_prefix}/{candidate.lstrip('/')}" + if ( + mount_prefix + and isinstance(backend, MultiRootLocalFolderBackend) + and candidate.startswith("/") + ): + mount_names = backend.list_mounts() + first_segment = candidate.lstrip("/").split("/", 1)[0] + if first_segment not in mount_names: + return f"{mount_prefix}{candidate}" if not candidate.startswith("/"): return f"/{candidate.lstrip('/')}" return candidate diff --git a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py new file mode 100644 index 000000000..2eb4e78dc --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py @@ -0,0 +1,328 @@ +"""Aggregate multiple LocalFolderBackend roots behind mount-prefixed virtual paths.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any + +from deepagents.backends.protocol import ( + EditResult, + FileDownloadResponse, + FileInfo, + FileUploadResponse, + GrepMatch, + WriteResult, +) + +from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend + +_INVALID_PATH = "invalid_path" +_FILE_NOT_FOUND = "file_not_found" +_IS_DIRECTORY = "is_directory" + + +class MultiRootLocalFolderBackend: + """Route filesystem operations to one of several mounted local roots. + + Virtual paths are namespaced as: + - `//...` + where `` is derived from each selected root folder name. + """ + + def __init__(self, root_paths: tuple[str, ...]) -> None: + if not root_paths: + msg = "At least one local root path is required" + raise ValueError(msg) + self._mount_to_backend: dict[str, LocalFolderBackend] = {} + for raw_root in root_paths: + normalized_root = str(Path(raw_root).expanduser().resolve()) + base_mount = Path(normalized_root).name or "root" + mount = base_mount + suffix = 2 + while mount in self._mount_to_backend: + mount = f"{base_mount}-{suffix}" + suffix += 1 + self._mount_to_backend[mount] = LocalFolderBackend(normalized_root) + self._mount_order = tuple(self._mount_to_backend.keys()) + + def list_mounts(self) -> tuple[str, ...]: + return self._mount_order + + def default_mount(self) -> str: + return self._mount_order[0] + + def _mount_error(self) -> str: + mounts = ", ".join(f"/{mount}" for mount in self._mount_order) + return ( + "Path must start with one of the selected folders: " + f"{mounts}. Example: /{self._mount_order[0]}/file.txt" + ) + + def _split_mount_path(self, virtual_path: str) -> tuple[str, str]: + if not virtual_path.startswith("/"): + msg = f"Invalid path (must be absolute): {virtual_path}" + raise ValueError(msg) + rel = virtual_path.lstrip("/") + if not rel: + raise ValueError(self._mount_error()) + mount, _, remainder = rel.partition("/") + backend = self._mount_to_backend.get(mount) + if backend is None: + raise ValueError(self._mount_error()) + local_path = f"/{remainder}" if remainder else "/" + return mount, local_path + + @staticmethod + def _prefix_mount_path(mount: str, local_path: str) -> str: + if local_path == "/": + return f"/{mount}" + return f"/{mount}{local_path}" + + @staticmethod + def _get_value(item: Any, key: str) -> Any: + if isinstance(item, dict): + return item.get(key) + return getattr(item, key, None) + + @classmethod + def _get_str(cls, item: Any, key: str) -> str: + value = cls._get_value(item, key) + return value if isinstance(value, str) else "" + + @classmethod + def _get_int(cls, item: Any, key: str) -> int: + value = cls._get_value(item, key) + return int(value) if isinstance(value, int | float) else 0 + + @classmethod + def _get_bool(cls, item: Any, key: str) -> bool: + value = cls._get_value(item, key) + return bool(value) + + def _list_mount_roots(self) -> list[FileInfo]: + return [ + FileInfo(path=f"/{mount}", is_dir=True, size=0, modified_at="0") + for mount in self._mount_order + ] + + def _transform_infos(self, mount: str, infos: list[FileInfo]) -> list[FileInfo]: + transformed: list[FileInfo] = [] + for info in infos: + transformed.append( + FileInfo( + path=self._prefix_mount_path(mount, self._get_str(info, "path")), + is_dir=self._get_bool(info, "is_dir"), + size=self._get_int(info, "size"), + modified_at=self._get_str(info, "modified_at"), + ) + ) + return transformed + + def ls_info(self, path: str) -> list[FileInfo]: + if path == "/": + return self._list_mount_roots() + try: + mount, local_path = self._split_mount_path(path) + except ValueError: + return [] + return self._transform_infos(mount, self._mount_to_backend[mount].ls_info(local_path)) + + async def als_info(self, path: str) -> list[FileInfo]: + return await asyncio.to_thread(self.ls_info, path) + + def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str: + try: + mount, local_path = self._split_mount_path(file_path) + except ValueError as exc: + return f"Error: {exc}" + return self._mount_to_backend[mount].read(local_path, offset, limit) + + async def aread(self, file_path: str, offset: int = 0, limit: int = 2000) -> str: + return await asyncio.to_thread(self.read, file_path, offset, limit) + + def read_raw(self, file_path: str) -> str: + try: + mount, local_path = self._split_mount_path(file_path) + except ValueError as exc: + return f"Error: {exc}" + return self._mount_to_backend[mount].read_raw(local_path) + + async def aread_raw(self, file_path: str) -> str: + return await asyncio.to_thread(self.read_raw, file_path) + + def write(self, file_path: str, content: str) -> WriteResult: + try: + mount, local_path = self._split_mount_path(file_path) + except ValueError as exc: + return WriteResult(error=f"Error: {exc}") + result = self._mount_to_backend[mount].write(local_path, content) + if result.path: + result.path = self._prefix_mount_path(mount, result.path) + return result + + async def awrite(self, file_path: str, content: str) -> WriteResult: + return await asyncio.to_thread(self.write, file_path, content) + + def edit( + self, + file_path: str, + old_string: str, + new_string: str, + replace_all: bool = False, + ) -> EditResult: + try: + mount, local_path = self._split_mount_path(file_path) + except ValueError as exc: + return EditResult(error=f"Error: {exc}") + result = self._mount_to_backend[mount].edit( + local_path, old_string, new_string, replace_all + ) + if result.path: + result.path = self._prefix_mount_path(mount, result.path) + return result + + async def aedit( + self, + file_path: str, + old_string: str, + new_string: str, + replace_all: bool = False, + ) -> EditResult: + return await asyncio.to_thread( + self.edit, file_path, old_string, new_string, replace_all + ) + + def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]: + if path == "/": + prefixed_results: list[FileInfo] = [] + if pattern.startswith("/"): + mount, _, remainder = pattern.lstrip("/").partition("/") + backend = self._mount_to_backend.get(mount) + if not backend: + return [] + local_pattern = f"/{remainder}" if remainder else "/" + return self._transform_infos( + mount, backend.glob_info(local_pattern, path="/") + ) + for mount, backend in self._mount_to_backend.items(): + prefixed_results.extend( + self._transform_infos(mount, backend.glob_info(pattern, path="/")) + ) + return prefixed_results + + try: + mount, local_path = self._split_mount_path(path) + except ValueError: + return [] + return self._transform_infos( + mount, self._mount_to_backend[mount].glob_info(pattern, path=local_path) + ) + + async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]: + return await asyncio.to_thread(self.glob_info, pattern, path) + + def grep_raw( + self, pattern: str, path: str | None = None, glob: str | None = None + ) -> list[GrepMatch] | str: + if not pattern: + return "Error: pattern cannot be empty" + if path is None or path == "/": + all_matches: list[GrepMatch] = [] + for mount, backend in self._mount_to_backend.items(): + result = backend.grep_raw(pattern, path="/", glob=glob) + if isinstance(result, str): + return result + all_matches.extend( + [ + GrepMatch( + path=self._prefix_mount_path(mount, self._get_str(match, "path")), + line=self._get_int(match, "line"), + text=self._get_str(match, "text"), + ) + for match in result + ] + ) + return all_matches + try: + mount, local_path = self._split_mount_path(path) + except ValueError as exc: + return f"Error: {exc}" + + result = self._mount_to_backend[mount].grep_raw( + pattern, path=local_path, glob=glob + ) + if isinstance(result, str): + return result + return [ + GrepMatch( + path=self._prefix_mount_path(mount, self._get_str(match, "path")), + line=self._get_int(match, "line"), + text=self._get_str(match, "text"), + ) + for match in result + ] + + async def agrep_raw( + self, pattern: str, path: str | None = None, glob: str | None = None + ) -> list[GrepMatch] | str: + return await asyncio.to_thread(self.grep_raw, pattern, path, glob) + + def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]: + grouped: dict[str, list[tuple[str, bytes]]] = {} + invalid: list[FileUploadResponse] = [] + for virtual_path, content in files: + try: + mount, local_path = self._split_mount_path(virtual_path) + except ValueError: + invalid.append(FileUploadResponse(path=virtual_path, error=_INVALID_PATH)) + continue + grouped.setdefault(mount, []).append((local_path, content)) + + responses = list(invalid) + for mount, mount_files in grouped.items(): + result = self._mount_to_backend[mount].upload_files(mount_files) + responses.extend( + [ + FileUploadResponse( + path=self._prefix_mount_path(mount, self._get_str(item, "path")), + error=self._get_str(item, "error") or None, + ) + for item in result + ] + ) + return responses + + async def aupload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]: + return await asyncio.to_thread(self.upload_files, files) + + def download_files(self, paths: list[str]) -> list[FileDownloadResponse]: + grouped: dict[str, list[str]] = {} + invalid: list[FileDownloadResponse] = [] + for virtual_path in paths: + try: + mount, local_path = self._split_mount_path(virtual_path) + except ValueError: + invalid.append( + FileDownloadResponse(path=virtual_path, content=None, error=_INVALID_PATH) + ) + continue + grouped.setdefault(mount, []).append(local_path) + + responses = list(invalid) + for mount, mount_paths in grouped.items(): + result = self._mount_to_backend[mount].download_files(mount_paths) + responses.extend( + [ + FileDownloadResponse( + path=self._prefix_mount_path(mount, self._get_str(item, "path")), + content=self._get_value(item, "content"), + error=self._get_str(item, "error") or None, + ) + for item in result + ] + ) + return responses + + async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]: + return await asyncio.to_thread(self.download_files, paths) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 548bd1402..e1a26ba04 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -73,7 +73,7 @@ def _resolve_filesystem_selection( *, mode: str, client_platform: str, - local_root: str | None, + local_roots: list[str] | None, ) -> FilesystemSelection: """Validate and normalize filesystem mode settings from request payload.""" try: @@ -96,21 +96,29 @@ def _resolve_filesystem_selection( status_code=400, detail="desktop_local_folder mode is only available on desktop runtime.", ) - if not local_root or not local_root.strip(): + normalized_roots: list[str] = [] + for root in local_roots or []: + trimmed = root.strip() + if trimmed and trimmed not in normalized_roots: + normalized_roots.append(trimmed) + if not normalized_roots: raise HTTPException( status_code=400, - detail="local_filesystem_root is required for desktop_local_folder mode.", + detail=( + "local_filesystem_roots must include at least one root for " + "desktop_local_folder mode." + ), ) return FilesystemSelection( mode=resolved_mode, client_platform=resolved_platform, - local_root_path=local_root.strip(), + local_root_paths=tuple(normalized_roots), ) return FilesystemSelection( mode=FilesystemMode.CLOUD, client_platform=resolved_platform, - local_root_path=None, + local_root_paths=(), ) @@ -1188,7 +1196,7 @@ async def handle_new_chat( filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, - local_root=request.local_filesystem_root, + local_roots=request.local_filesystem_roots, ) # Get search space to check LLM config preferences @@ -1310,7 +1318,7 @@ async def regenerate_response( filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, - local_root=request.local_filesystem_root, + local_roots=request.local_filesystem_roots, ) # Get the checkpointer and state history @@ -1569,7 +1577,7 @@ async def resume_chat( filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, - local_root=request.local_filesystem_root, + local_roots=request.local_filesystem_roots, ) search_space_result = await session.execute( diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index 593127c7e..38cdf0b28 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -186,7 +186,7 @@ class NewChatRequest(BaseModel): ) filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud" client_platform: Literal["web", "desktop"] = "web" - local_filesystem_root: str | None = None + local_filesystem_roots: list[str] | None = None class RegenerateRequest(BaseModel): @@ -209,7 +209,7 @@ class RegenerateRequest(BaseModel): disabled_tools: list[str] | None = None filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud" client_platform: Literal["web", "desktop"] = "web" - local_filesystem_root: str | None = None + local_filesystem_roots: list[str] | None = None # ============================================================================= @@ -235,7 +235,7 @@ class ResumeRequest(BaseModel): decisions: list[ResumeDecision] filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud" client_platform: Literal["web", "desktop"] = "web" - local_filesystem_root: str | None = None + local_filesystem_roots: list[str] | None = None # ============================================================================= diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py b/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py index 2377307f8..a1867ff6c 100644 --- a/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py @@ -8,7 +8,9 @@ from app.agents.new_chat.filesystem_selection import ( FilesystemMode, FilesystemSelection, ) -from app.agents.new_chat.middleware.local_folder_backend import LocalFolderBackend +from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( + MultiRootLocalFolderBackend, +) pytestmark = pytest.mark.unit @@ -17,16 +19,16 @@ class _RuntimeStub: state = {"files": {}} -def test_backend_resolver_returns_local_backend_for_local_mode(tmp_path: Path): +def test_backend_resolver_returns_multi_root_backend_for_single_root(tmp_path: Path): selection = FilesystemSelection( mode=FilesystemMode.DESKTOP_LOCAL_FOLDER, client_platform=ClientPlatform.DESKTOP, - local_root_path=str(tmp_path), + local_root_paths=(str(tmp_path),), ) resolver = build_backend_resolver(selection) backend = resolver(_RuntimeStub()) - assert isinstance(backend, LocalFolderBackend) + assert isinstance(backend, MultiRootLocalFolderBackend) def test_backend_resolver_uses_cloud_mode_by_default(): @@ -35,3 +37,19 @@ def test_backend_resolver_uses_cloud_mode_by_default(): # StateBackend class name check keeps this test decoupled # from internal deepagents runtime class identity. assert backend.__class__.__name__ == "StateBackend" + + +def test_backend_resolver_returns_multi_root_backend_for_multiple_roots(tmp_path: Path): + root_one = tmp_path / "resume" + root_two = tmp_path / "notes" + root_one.mkdir() + root_two.mkdir() + selection = FilesystemSelection( + mode=FilesystemMode.DESKTOP_LOCAL_FOLDER, + client_platform=ClientPlatform.DESKTOP, + local_root_paths=(str(root_one), str(root_two)), + ) + resolver = build_backend_resolver(selection) + + backend = resolver(_RuntimeStub()) + assert isinstance(backend, MultiRootLocalFolderBackend) From 3ee2683391fb82c50c74b73a5b0522845e682100 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 24 Apr 2026 01:45:13 +0530 Subject: [PATCH 135/299] feat(filesystem): propagate localRootPaths across desktop and web API --- surfsense_desktop/src/ipc/handlers.ts | 2 +- .../src/modules/agent-filesystem.ts | 100 ++++++++++++++---- surfsense_desktop/src/preload.ts | 2 +- .../new-chat/[[...chat_id]]/page.tsx | 8 +- surfsense_web/lib/agent-filesystem.ts | 7 +- surfsense_web/types/window.d.ts | 4 +- 6 files changed, 93 insertions(+), 30 deletions(-) diff --git a/surfsense_desktop/src/ipc/handlers.ts b/surfsense_desktop/src/ipc/handlers.ts index cc84a46e0..247d171f5 100644 --- a/surfsense_desktop/src/ipc/handlers.ts +++ b/surfsense_desktop/src/ipc/handlers.ts @@ -228,7 +228,7 @@ export function registerIpcHandlers(): void { ipcMain.handle( IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS, - (_event, settings: { mode?: 'cloud' | 'desktop_local_folder'; localRootPath?: string | null }) => + (_event, settings: { mode?: 'cloud' | 'desktop_local_folder'; localRootPaths?: string[] | null }) => setAgentFilesystemSettings(settings) ); diff --git a/surfsense_desktop/src/modules/agent-filesystem.ts b/surfsense_desktop/src/modules/agent-filesystem.ts index 9dfe79fb0..afad98f24 100644 --- a/surfsense_desktop/src/modules/agent-filesystem.ts +++ b/surfsense_desktop/src/modules/agent-filesystem.ts @@ -1,16 +1,17 @@ import { app, dialog } from "electron"; -import { mkdir, readFile, writeFile } from "node:fs/promises"; +import { access, mkdir, readFile, writeFile } from "node:fs/promises"; import { dirname, isAbsolute, join, relative, resolve } from "node:path"; export type AgentFilesystemMode = "cloud" | "desktop_local_folder"; export interface AgentFilesystemSettings { mode: AgentFilesystemMode; - localRootPath: string | null; + localRootPaths: string[]; updatedAt: string; } const SETTINGS_FILENAME = "agent-filesystem-settings.json"; +const MAX_LOCAL_ROOTS = 5; function getSettingsPath(): string { return join(app.getPath("userData"), SETTINGS_FILENAME); @@ -19,11 +20,28 @@ function getSettingsPath(): string { function getDefaultSettings(): AgentFilesystemSettings { return { mode: "cloud", - localRootPath: null, + localRootPaths: [], updatedAt: new Date().toISOString(), }; } +function normalizeLocalRootPaths(paths: unknown): string[] { + if (!Array.isArray(paths)) { + return []; + } + const uniquePaths = new Set(); + for (const path of paths) { + if (typeof path !== "string") continue; + const trimmed = path.trim(); + if (!trimmed) continue; + uniquePaths.add(trimmed); + if (uniquePaths.size >= MAX_LOCAL_ROOTS) { + break; + } + } + return [...uniquePaths]; +} + export async function getAgentFilesystemSettings(): Promise { try { const raw = await readFile(getSettingsPath(), "utf8"); @@ -33,7 +51,7 @@ export async function getAgentFilesystemSettings(): Promise> + settings: { + mode?: AgentFilesystemMode; + localRootPaths?: string[] | null; + } ): Promise { const current = await getAgentFilesystemSettings(); const nextMode = @@ -51,8 +72,10 @@ export async function setAgentFilesystemSettings( : current.mode; const next: AgentFilesystemSettings = { mode: nextMode, - localRootPath: - settings.localRootPath === undefined ? current.localRootPath : settings.localRootPath, + localRootPaths: + settings.localRootPaths === undefined + ? current.localRootPaths + : normalizeLocalRootPaths(settings.localRootPaths ?? []), updatedAt: new Date().toISOString(), }; @@ -101,20 +124,45 @@ function toVirtualPath(rootPath: string, absolutePath: string): string { async function resolveCurrentRootPath(): Promise { const settings = await getAgentFilesystemSettings(); - if (!settings.localRootPath) { - throw new Error("No local filesystem root selected"); + if (settings.localRootPaths.length === 0) { + throw new Error("No local filesystem roots selected"); } - return settings.localRootPath; + return settings.localRootPaths[0]; +} + +async function resolveCurrentRootPaths(): Promise { + const settings = await getAgentFilesystemSettings(); + if (settings.localRootPaths.length === 0) { + throw new Error("No local filesystem roots selected"); + } + return settings.localRootPaths; } export async function readAgentLocalFileText( virtualPath: string ): Promise<{ path: string; content: string }> { - const rootPath = await resolveCurrentRootPath(); - const absolutePath = resolveVirtualPath(rootPath, virtualPath); - const content = await readFile(absolutePath, "utf8"); + const rootPaths = await resolveCurrentRootPaths(); + for (const rootPath of rootPaths) { + const absolutePath = resolveVirtualPath(rootPath, virtualPath); + try { + const content = await readFile(absolutePath, "utf8"); + return { + path: toVirtualPath(rootPath, absolutePath), + content, + }; + } catch (error) { + if ((error as NodeJS.ErrnoException).code === "ENOENT") { + continue; + } + throw error; + } + } + // Keep the same relative virtual path in the error context. + const fallbackRootPath = await resolveCurrentRootPath(); + const fallbackAbsolutePath = resolveVirtualPath(fallbackRootPath, virtualPath); + const content = await readFile(fallbackAbsolutePath, "utf8"); return { - path: toVirtualPath(rootPath, absolutePath), + path: toVirtualPath(fallbackRootPath, fallbackAbsolutePath), content, }; } @@ -123,11 +171,25 @@ export async function writeAgentLocalFileText( virtualPath: string, content: string ): Promise<{ path: string }> { - const rootPath = await resolveCurrentRootPath(); - const absolutePath = resolveVirtualPath(rootPath, virtualPath); - await mkdir(dirname(absolutePath), { recursive: true }); - await writeFile(absolutePath, content, "utf8"); + const rootPaths = await resolveCurrentRootPaths(); + let selectedRootPath = rootPaths[0]; + let selectedAbsolutePath = resolveVirtualPath(selectedRootPath, virtualPath); + + for (const rootPath of rootPaths) { + const absolutePath = resolveVirtualPath(rootPath, virtualPath); + try { + await access(absolutePath); + selectedRootPath = rootPath; + selectedAbsolutePath = absolutePath; + break; + } catch { + // Keep searching for an existing file path across selected roots. + } + } + + await mkdir(dirname(selectedAbsolutePath), { recursive: true }); + await writeFile(selectedAbsolutePath, content, "utf8"); return { - path: toVirtualPath(rootPath, absolutePath), + path: toVirtualPath(selectedRootPath, selectedAbsolutePath), }; } diff --git a/surfsense_desktop/src/preload.ts b/surfsense_desktop/src/preload.ts index 9fc213bfa..f7aaf9633 100644 --- a/surfsense_desktop/src/preload.ts +++ b/surfsense_desktop/src/preload.ts @@ -110,7 +110,7 @@ contextBridge.exposeInMainWorld('electronAPI', { ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS), setAgentFilesystemSettings: (settings: { mode?: "cloud" | "desktop_local_folder"; - localRootPath?: string | null; + localRootPaths?: string[] | null; }) => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS, settings), pickAgentFilesystemRoot: () => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_PICK_ROOT), }); diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index bdb77ade2..616637a49 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -660,7 +660,7 @@ export default function NewChatPage() { const selection = await getAgentFilesystemSelection(); if ( selection.filesystem_mode === "desktop_local_folder" && - !selection.local_filesystem_root + (!selection.local_filesystem_roots || selection.local_filesystem_roots.length === 0) ) { toast.error("Select a local folder before using Local Folder mode."); return; @@ -702,7 +702,7 @@ export default function NewChatPage() { search_space_id: searchSpaceId, filesystem_mode: selection.filesystem_mode, client_platform: selection.client_platform, - local_filesystem_root: selection.local_filesystem_root, + local_filesystem_roots: selection.local_filesystem_roots, messages: messageHistory, mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined, mentioned_surfsense_doc_ids: hasSurfsenseDocIds @@ -1098,7 +1098,7 @@ export default function NewChatPage() { decisions, filesystem_mode: selection.filesystem_mode, client_platform: selection.client_platform, - local_filesystem_root: selection.local_filesystem_root, + local_filesystem_roots: selection.local_filesystem_roots, }), signal: controller.signal, }); @@ -1435,7 +1435,7 @@ export default function NewChatPage() { disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, filesystem_mode: selection.filesystem_mode, client_platform: selection.client_platform, - local_filesystem_root: selection.local_filesystem_root, + local_filesystem_roots: selection.local_filesystem_roots, }), signal: controller.signal, }); diff --git a/surfsense_web/lib/agent-filesystem.ts b/surfsense_web/lib/agent-filesystem.ts index 6bfb5d131..c9096a294 100644 --- a/surfsense_web/lib/agent-filesystem.ts +++ b/surfsense_web/lib/agent-filesystem.ts @@ -4,7 +4,7 @@ export type ClientPlatform = "web" | "desktop"; export interface AgentFilesystemSelection { filesystem_mode: AgentFilesystemMode; client_platform: ClientPlatform; - local_filesystem_root?: string; + local_filesystem_roots?: string[]; } const DEFAULT_SELECTION: AgentFilesystemSelection = { @@ -24,11 +24,12 @@ export async function getAgentFilesystemSelection(): Promise Promise; setAgentFilesystemSettings: (settings: { mode?: AgentFilesystemMode; - localRootPath?: string | null; + localRootPaths?: string[] | null; }) => Promise; pickAgentFilesystemRoot: () => Promise; } From a250f971622e5f3a3cd4c32eb7ecf45c3682f186 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 24 Apr 2026 01:46:32 +0530 Subject: [PATCH 136/299] feat(thread): support selecting and managing multiple local folders --- .../components/assistant-ui/thread.tsx | 128 +++++++++++++++--- 1 file changed, 110 insertions(+), 18 deletions(-) diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 9df41ee55..6fde33061 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -110,11 +110,15 @@ const COMPOSER_PLACEHOLDER = "Ask anything, type / for prompts, type @ to mentio type ComposerFilesystemSettings = { mode: "cloud" | "desktop_local_folder"; - localRootPath: string | null; + localRootPaths: string[]; updatedAt: string; }; const LOCAL_FILESYSTEM_TRUST_KEY = "surfsense.local-filesystem-trust.v1"; +const MAX_LOCAL_FILESYSTEM_ROOTS = 5; + +const getFolderDisplayName = (rootPath: string): string => + rootPath.split(/[\\/]/).at(-1) || rootPath; export const Thread: FC = () => { return ; @@ -388,6 +392,7 @@ const Composer: FC = () => { null ); const [localTrustDialogOpen, setLocalTrustDialogOpen] = useState(false); + const [localFoldersOpen, setLocalFoldersOpen] = useState(false); const [pendingLocalPath, setPendingLocalPath] = useState(null); const [clipboardInitialText, setClipboardInitialText] = useState(); const clipboardLoadedRef = useRef(false); @@ -414,7 +419,7 @@ const Composer: FC = () => { if (!mounted) return; setFilesystemSettings({ mode: "cloud", - localRootPath: null, + localRootPaths: [], updatedAt: new Date().toISOString(), }); }); @@ -431,16 +436,27 @@ const Composer: FC = () => { } }, []); + const localRootPaths = filesystemSettings?.localRootPaths ?? []; + const primaryLocalRootPath = localRootPaths[0] ?? null; + const extraLocalRootCount = Math.max(0, localRootPaths.length - 1); + const canAddMoreLocalRoots = localRootPaths.length < MAX_LOCAL_FILESYSTEM_ROOTS; + const applyLocalRootPath = useCallback( async (path: string) => { if (!electronAPI?.setAgentFilesystemSettings) return; + const nextLocalRootPaths = [...localRootPaths, path] + .filter((rootPath, index, allPaths) => allPaths.indexOf(rootPath) === index) + .slice(0, MAX_LOCAL_FILESYSTEM_ROOTS); + if (nextLocalRootPaths.length === localRootPaths.length) { + return; + } const updated = await electronAPI.setAgentFilesystemSettings({ mode: "desktop_local_folder", - localRootPath: path, + localRootPaths: nextLocalRootPaths, }); setFilesystemSettings(updated); }, - [electronAPI] + [electronAPI, localRootPaths] ); const runSwitchToLocalMode = useCallback(async () => { @@ -467,6 +483,7 @@ const Composer: FC = () => { ); const handlePickFilesystemRoot = useCallback(async () => { + if (!canAddMoreLocalRoots) return; if (hasLocalFilesystemTrust()) { await runPickLocalRoot(); return; @@ -476,13 +493,25 @@ const Composer: FC = () => { if (!picked) return; setPendingLocalPath(picked); setLocalTrustDialogOpen(true); - }, [electronAPI, hasLocalFilesystemTrust, runPickLocalRoot]); + }, [canAddMoreLocalRoots, electronAPI, hasLocalFilesystemTrust, runPickLocalRoot]); - const handleClearFilesystemRoot = useCallback(async () => { + const handleRemoveFilesystemRoot = useCallback( + async (rootPathToRemove: string) => { + if (!electronAPI?.setAgentFilesystemSettings) return; + const updated = await electronAPI.setAgentFilesystemSettings({ + mode: "desktop_local_folder", + localRootPaths: localRootPaths.filter((rootPath) => rootPath !== rootPathToRemove), + }); + setFilesystemSettings(updated); + }, + [electronAPI, localRootPaths] + ); + + const handleClearFilesystemRoots = useCallback(async () => { if (!electronAPI?.setAgentFilesystemSettings) return; const updated = await electronAPI.setAgentFilesystemSettings({ mode: "desktop_local_folder", - localRootPath: null, + localRootPaths: [], }); setFilesystemSettings(updated); }, [electronAPI]); @@ -833,31 +862,89 @@ const Composer: FC = () => { {filesystemSettings.mode === "desktop_local_folder" && ( <>
-
- {filesystemSettings.localRootPath ? ( +
+ {primaryLocalRootPath ? ( <>
- {filesystemSettings.localRootPath.split("/").at(-1) || - filesystemSettings.localRootPath} + {getFolderDisplayName(primaryLocalRootPath)}
+ {extraLocalRootCount > 0 && ( + + + + + +
+ {localRootPaths.map((rootPath) => ( +
+ + + {getFolderDisplayName(rootPath)} + + +
+ ))} +
+ +
+
+
+
+ )} @@ -909,9 +1001,9 @@ const Composer: FC = () => { Local mode can read and edit files inside the folders you select. Continue only if you trust this workspace and its contents. - {(pendingLocalPath || filesystemSettings?.localRootPath) && ( + {(pendingLocalPath || primaryLocalRootPath) && ( - Folder path: {pendingLocalPath || filesystemSettings?.localRootPath} + Folder path: {pendingLocalPath || primaryLocalRootPath} )} From c1a07a093e46c760c370df05c01f5c126d198286 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 24 Apr 2026 01:46:44 +0530 Subject: [PATCH 137/299] refactor(sidebar): use Monitor icon for system theme option --- .../components/layout/ui/sidebar/SidebarUserProfile.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx b/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx index 81fbeef91..acece2d5c 100644 --- a/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx +++ b/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx @@ -7,8 +7,8 @@ import { ExternalLink, Info, Languages, - Laptop, LogOut, + Monitor, Moon, Sun, UserCog, @@ -49,7 +49,7 @@ const LANGUAGES = [ const THEMES = [ { value: "light" as const, name: "Light", icon: Sun }, { value: "dark" as const, name: "Dark", icon: Moon }, - { value: "system" as const, name: "System", icon: Laptop }, + { value: "system" as const, name: "System", icon: Monitor }, ]; const LEARN_MORE_LINKS = [ From 1e9db6f26f12f399f9b94eed51184782ab7f7ae4 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 24 Apr 2026 02:12:30 +0530 Subject: [PATCH 138/299] feat(filesystem): enhance local mount path normalization and improve virtual path handling in agent filesystem --- .../agents/new_chat/middleware/filesystem.py | 41 ++++--- .../src/modules/agent-filesystem.ts | 110 ++++++++++++------ .../components/editor/source-code-editor.tsx | 2 +- 3 files changed, 96 insertions(+), 57 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py index 6c30b20ef..a086357af 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py +++ b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py @@ -782,6 +782,27 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): return f"/{backend.default_mount()}" return "" + def _normalize_local_mount_path( + self, candidate: str, runtime: ToolRuntime[None, FilesystemState] + ) -> str: + backend = self._get_backend(runtime) + mount_prefix = self._default_mount_prefix(runtime) + if not mount_prefix or not isinstance(backend, MultiRootLocalFolderBackend): + return candidate if candidate.startswith("/") else f"/{candidate.lstrip('/')}" + + mount_names = set(backend.list_mounts()) + if candidate.startswith("/"): + first_segment = candidate.lstrip("/").split("/", 1)[0] + if first_segment in mount_names: + return candidate + return f"{mount_prefix}{candidate}" + + relative = candidate.lstrip("/") + first_segment = relative.split("/", 1)[0] + if first_segment in mount_names: + return f"/{relative}" + return f"{mount_prefix}/{relative}" + def _get_contract_suggested_path( self, runtime: ToolRuntime[None, FilesystemState] ) -> str: @@ -790,11 +811,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): if isinstance(suggested, str) and suggested.strip(): cleaned = suggested.strip() if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - mount_prefix = self._default_mount_prefix(runtime) - if mount_prefix and cleaned.startswith("/") and not cleaned.startswith( - f"{mount_prefix}/" - ): - return f"{mount_prefix}{cleaned}" + return self._normalize_local_mount_path(cleaned, runtime) return cleaned if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: mount_prefix = self._default_mount_prefix(runtime) @@ -811,19 +828,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): if not candidate: return self._get_contract_suggested_path(runtime) if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - backend = self._get_backend(runtime) - mount_prefix = self._default_mount_prefix(runtime) - if mount_prefix and not candidate.startswith("/"): - return f"{mount_prefix}/{candidate.lstrip('/')}" - if ( - mount_prefix - and isinstance(backend, MultiRootLocalFolderBackend) - and candidate.startswith("/") - ): - mount_names = backend.list_mounts() - first_segment = candidate.lstrip("/").split("/", 1)[0] - if first_segment not in mount_names: - return f"{mount_prefix}{candidate}" + return self._normalize_local_mount_path(candidate, runtime) if not candidate.startswith("/"): return f"/{candidate.lstrip('/')}" return candidate diff --git a/surfsense_desktop/src/modules/agent-filesystem.ts b/surfsense_desktop/src/modules/agent-filesystem.ts index afad98f24..2bf0101d6 100644 --- a/surfsense_desktop/src/modules/agent-filesystem.ts +++ b/surfsense_desktop/src/modules/agent-filesystem.ts @@ -122,12 +122,55 @@ function toVirtualPath(rootPath: string, absolutePath: string): string { return `/${rel.replace(/\\/g, "/")}`; } -async function resolveCurrentRootPath(): Promise { - const settings = await getAgentFilesystemSettings(); - if (settings.localRootPaths.length === 0) { - throw new Error("No local filesystem roots selected"); +type LocalRootMount = { + mount: string; + rootPath: string; +}; + +function buildRootMounts(rootPaths: string[]): LocalRootMount[] { + const mounts: LocalRootMount[] = []; + const usedMounts = new Set(); + for (const rawRootPath of rootPaths) { + const normalizedRoot = resolve(rawRootPath); + const baseMount = normalizedRoot.split(/[\\/]/).at(-1) || "root"; + let mount = baseMount; + let suffix = 2; + while (usedMounts.has(mount)) { + mount = `${baseMount}-${suffix}`; + suffix += 1; + } + usedMounts.add(mount); + mounts.push({ mount, rootPath: normalizedRoot }); } - return settings.localRootPaths[0]; + return mounts; +} + +function parseMountedVirtualPath(virtualPath: string): { + mount: string; + subPath: string; +} { + if (!virtualPath.startsWith("/")) { + throw new Error("Path must start with '/'"); + } + const trimmed = virtualPath.replace(/^\/+/, ""); + if (!trimmed) { + throw new Error("Path must include a mounted root segment"); + } + const [mount, ...rest] = trimmed.split("/"); + const remainder = rest.join("/"); + if (!remainder) { + throw new Error("Path must include a file path under the mounted root"); + } + return { mount, subPath: `/${remainder}` }; +} + +function findMountByName(mounts: LocalRootMount[], mountName: string): LocalRootMount | undefined { + return mounts.find((entry) => entry.mount === mountName); +} + +function toMountedVirtualPath(mount: string, rootPath: string, absolutePath: string): string { + const relativePath = toVirtualPath(rootPath, absolutePath); + return `/${mount}${relativePath}`; } async function resolveCurrentRootPaths(): Promise { @@ -142,27 +185,18 @@ export async function readAgentLocalFileText( virtualPath: string ): Promise<{ path: string; content: string }> { const rootPaths = await resolveCurrentRootPaths(); - for (const rootPath of rootPaths) { - const absolutePath = resolveVirtualPath(rootPath, virtualPath); - try { - const content = await readFile(absolutePath, "utf8"); - return { - path: toVirtualPath(rootPath, absolutePath), - content, - }; - } catch (error) { - if ((error as NodeJS.ErrnoException).code === "ENOENT") { - continue; - } - throw error; - } + const mounts = buildRootMounts(rootPaths); + const { mount, subPath } = parseMountedVirtualPath(virtualPath); + const rootMount = findMountByName(mounts, mount); + if (!rootMount) { + throw new Error( + `Unknown mounted root '${mount}'. Available roots: ${mounts.map((entry) => `/${entry.mount}`).join(", ")}` + ); } - // Keep the same relative virtual path in the error context. - const fallbackRootPath = await resolveCurrentRootPath(); - const fallbackAbsolutePath = resolveVirtualPath(fallbackRootPath, virtualPath); - const content = await readFile(fallbackAbsolutePath, "utf8"); + const absolutePath = resolveVirtualPath(rootMount.rootPath, subPath); + const content = await readFile(absolutePath, "utf8"); return { - path: toVirtualPath(fallbackRootPath, fallbackAbsolutePath), + path: toMountedVirtualPath(rootMount.mount, rootMount.rootPath, absolutePath), content, }; } @@ -172,24 +206,24 @@ export async function writeAgentLocalFileText( content: string ): Promise<{ path: string }> { const rootPaths = await resolveCurrentRootPaths(); - let selectedRootPath = rootPaths[0]; - let selectedAbsolutePath = resolveVirtualPath(selectedRootPath, virtualPath); - - for (const rootPath of rootPaths) { - const absolutePath = resolveVirtualPath(rootPath, virtualPath); - try { - await access(absolutePath); - selectedRootPath = rootPath; - selectedAbsolutePath = absolutePath; - break; - } catch { - // Keep searching for an existing file path across selected roots. - } + const mounts = buildRootMounts(rootPaths); + const { mount, subPath } = parseMountedVirtualPath(virtualPath); + const rootMount = findMountByName(mounts, mount); + if (!rootMount) { + throw new Error( + `Unknown mounted root '${mount}'. Available roots: ${mounts.map((entry) => `/${entry.mount}`).join(", ")}` + ); } + let selectedAbsolutePath = resolveVirtualPath(rootMount.rootPath, subPath); + try { + await access(selectedAbsolutePath); + } catch { + // New files are created under the selected mounted root. + } await mkdir(dirname(selectedAbsolutePath), { recursive: true }); await writeFile(selectedAbsolutePath, content, "utf8"); return { - path: toVirtualPath(selectedRootPath, selectedAbsolutePath), + path: toMountedVirtualPath(rootMount.mount, rootMount.rootPath, selectedAbsolutePath), }; } diff --git a/surfsense_web/components/editor/source-code-editor.tsx b/surfsense_web/components/editor/source-code-editor.tsx index 11f9266b6..c2d77be60 100644 --- a/surfsense_web/components/editor/source-code-editor.tsx +++ b/surfsense_web/components/editor/source-code-editor.tsx @@ -89,7 +89,7 @@ export function SourceCodeEditor({ onChange={(next) => onChange(next ?? "")} loading={
- +
} beforeMount={(monaco) => { From 17f9ee4b592d3ba696333c818dbcf51f6320a59d Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 24 Apr 2026 02:33:57 +0530 Subject: [PATCH 139/299] refactor(icons): replace 'Pen' icon with 'Pencil' across various components for consistency --- .../user-settings/components/MemoryContent.tsx | 4 ++-- .../user-settings/components/PromptsContent.tsx | 4 ++-- surfsense_web/components/assistant-ui/user-message.tsx | 4 ++-- .../chat-comments/comment-item/comment-actions.tsx | 4 ++-- surfsense_web/components/documents/DocumentNode.tsx | 6 +++--- surfsense_web/components/documents/FolderNode.tsx | 6 +++--- .../components/layout/ui/sidebar/AllPrivateChatsSidebar.tsx | 4 ++-- .../components/layout/ui/sidebar/AllSharedChatsSidebar.tsx | 4 ++-- surfsense_web/components/layout/ui/sidebar/ChatListItem.tsx | 4 ++-- surfsense_web/components/layout/ui/sidebar/Sidebar.tsx | 4 ++-- .../components/layout/ui/tabs/DocumentTabContent.tsx | 4 ++-- surfsense_web/components/settings/team-memory-manager.tsx | 4 ++-- .../tool-ui/confluence/create-confluence-page.tsx | 4 ++-- .../tool-ui/confluence/update-confluence-page.tsx | 4 ++-- surfsense_web/components/tool-ui/dropbox/create-file.tsx | 4 ++-- surfsense_web/components/tool-ui/generic-hitl-approval.tsx | 4 ++-- surfsense_web/components/tool-ui/gmail/create-draft.tsx | 4 ++-- surfsense_web/components/tool-ui/gmail/send-email.tsx | 4 ++-- surfsense_web/components/tool-ui/gmail/update-draft.tsx | 4 ++-- .../components/tool-ui/google-calendar/create-event.tsx | 4 ++-- .../components/tool-ui/google-calendar/update-event.tsx | 4 ++-- .../components/tool-ui/google-drive/create-file.tsx | 4 ++-- surfsense_web/components/tool-ui/jira/create-jira-issue.tsx | 4 ++-- surfsense_web/components/tool-ui/jira/update-jira-issue.tsx | 4 ++-- .../components/tool-ui/linear/create-linear-issue.tsx | 4 ++-- .../components/tool-ui/linear/update-linear-issue.tsx | 4 ++-- .../components/tool-ui/notion/create-notion-page.tsx | 4 ++-- .../components/tool-ui/notion/update-notion-page.tsx | 4 ++-- surfsense_web/components/tool-ui/onedrive/create-file.tsx | 4 ++-- surfsense_web/components/ui/mode-toolbar-button.tsx | 4 ++-- 30 files changed, 62 insertions(+), 62 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/MemoryContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/MemoryContent.tsx index ef17e5a89..3d0550b6c 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/MemoryContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/MemoryContent.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtomValue } from "jotai"; -import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pen } from "lucide-react"; +import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pencil } from "lucide-react"; import { useCallback, useEffect, useRef, useState } from "react"; import { toast } from "sonner"; import { z } from "zod"; @@ -241,7 +241,7 @@ export function MemoryContent() { onClick={openInput} className="absolute bottom-3 right-3 z-10 h-[54px] w-[54px] rounded-full border bg-muted/60 backdrop-blur-sm shadow-sm" > - + )}
diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PromptsContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PromptsContent.tsx index 1e7087afc..c78d4f9f0 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PromptsContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PromptsContent.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtomValue } from "jotai"; -import { AlertTriangle, Globe, Lock, PenLine, Sparkles, Trash2 } from "lucide-react"; +import { AlertTriangle, Globe, Lock, Pencil, Sparkles, Trash2 } from "lucide-react"; import { useCallback, useState } from "react"; import { toast } from "sonner"; import { @@ -308,7 +308,7 @@ export function PromptsContent() { className="size-7" onClick={() => handleEdit(prompt)} > - + )} diff --git a/surfsense_web/components/settings/team-memory-manager.tsx b/surfsense_web/components/settings/team-memory-manager.tsx index 67369879b..371527530 100644 --- a/surfsense_web/components/settings/team-memory-manager.tsx +++ b/surfsense_web/components/settings/team-memory-manager.tsx @@ -2,7 +2,7 @@ import { useQuery, useQueryClient } from "@tanstack/react-query"; import { useAtomValue } from "jotai"; -import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pen } from "lucide-react"; +import { ArrowUp, ChevronDown, ClipboardCopy, Download, Info, Pencil } from "lucide-react"; import { useEffect, useRef, useState } from "react"; import { toast } from "sonner"; import { z } from "zod"; @@ -247,7 +247,7 @@ export function TeamMemoryManager({ searchSpaceId }: TeamMemoryManagerProps) { onClick={openInput} className="absolute bottom-3 right-3 z-10 h-[54px] w-[54px] rounded-full border bg-muted/60 backdrop-blur-sm shadow-sm" > - + )}
diff --git a/surfsense_web/components/tool-ui/confluence/create-confluence-page.tsx b/surfsense_web/components/tool-ui/confluence/create-confluence-page.tsx index 5344527f9..1bef1f008 100644 --- a/surfsense_web/components/tool-ui/confluence/create-confluence-page.tsx +++ b/surfsense_web/components/tool-ui/confluence/create-confluence-page.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { CornerDownLeftIcon, Pen } from "lucide-react"; +import { CornerDownLeftIcon, Pencil } from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; import { PlateEditor } from "@/components/editor/plate-editor"; @@ -222,7 +222,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/confluence/update-confluence-page.tsx b/surfsense_web/components/tool-ui/confluence/update-confluence-page.tsx index 2038f7a0e..c30357fb6 100644 --- a/surfsense_web/components/tool-ui/confluence/update-confluence-page.tsx +++ b/surfsense_web/components/tool-ui/confluence/update-confluence-page.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { CornerDownLeftIcon, Pen } from "lucide-react"; +import { CornerDownLeftIcon, Pencil } from "lucide-react"; import { useCallback, useEffect, useState } from "react"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; import { PlateEditor } from "@/components/editor/plate-editor"; @@ -241,7 +241,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/dropbox/create-file.tsx b/surfsense_web/components/tool-ui/dropbox/create-file.tsx index 02eae2c83..f76a45f62 100644 --- a/surfsense_web/components/tool-ui/dropbox/create-file.tsx +++ b/surfsense_web/components/tool-ui/dropbox/create-file.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { CornerDownLeftIcon, FileIcon, Pen } from "lucide-react"; +import { CornerDownLeftIcon, FileIcon, Pencil } from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; import { PlateEditor } from "@/components/editor/plate-editor"; @@ -224,7 +224,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx index 809b76c38..d4ee61eeb 100644 --- a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx +++ b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx @@ -1,7 +1,7 @@ "use client"; import type { ToolCallMessagePartComponent } from "@assistant-ui/react"; -import { CornerDownLeftIcon, Pen } from "lucide-react"; +import { CornerDownLeftIcon, Pencil } from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; import { TextShimmerLoader } from "@/components/prompt-kit/loader"; import { Button } from "@/components/ui/button"; @@ -167,7 +167,7 @@ function GenericApprovalCard({ className="rounded-lg text-muted-foreground -mt-1 -mr-2" onClick={() => setIsEditing(true)} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/gmail/create-draft.tsx b/surfsense_web/components/tool-ui/gmail/create-draft.tsx index cfe61351a..a00760ca3 100644 --- a/surfsense_web/components/tool-ui/gmail/create-draft.tsx +++ b/surfsense_web/components/tool-ui/gmail/create-draft.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { CornerDownLeftIcon, Pen, UserIcon, UsersIcon } from "lucide-react"; +import { CornerDownLeftIcon, Pencil, UserIcon, UsersIcon } from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; import type { ExtraField } from "@/atoms/chat/hitl-edit-panel.atom"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; @@ -251,7 +251,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/gmail/send-email.tsx b/surfsense_web/components/tool-ui/gmail/send-email.tsx index a21ece7b3..c22045fa1 100644 --- a/surfsense_web/components/tool-ui/gmail/send-email.tsx +++ b/surfsense_web/components/tool-ui/gmail/send-email.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { CornerDownLeftIcon, MailIcon, Pen, UserIcon, UsersIcon } from "lucide-react"; +import { CornerDownLeftIcon, MailIcon, Pencil, UserIcon, UsersIcon } from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; import type { ExtraField } from "@/atoms/chat/hitl-edit-panel.atom"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; @@ -250,7 +250,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/gmail/update-draft.tsx b/surfsense_web/components/tool-ui/gmail/update-draft.tsx index 0cbf338d7..b8c8c10f6 100644 --- a/surfsense_web/components/tool-ui/gmail/update-draft.tsx +++ b/surfsense_web/components/tool-ui/gmail/update-draft.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { CornerDownLeftIcon, MailIcon, Pen, UserIcon, UsersIcon } from "lucide-react"; +import { CornerDownLeftIcon, MailIcon, Pencil, UserIcon, UsersIcon } from "lucide-react"; import { useCallback, useEffect, useState } from "react"; import type { ExtraField } from "@/atoms/chat/hitl-edit-panel.atom"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; @@ -283,7 +283,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/google-calendar/create-event.tsx b/surfsense_web/components/tool-ui/google-calendar/create-event.tsx index 40a9f0106..9427c989b 100644 --- a/surfsense_web/components/tool-ui/google-calendar/create-event.tsx +++ b/surfsense_web/components/tool-ui/google-calendar/create-event.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { ClockIcon, CornerDownLeftIcon, GlobeIcon, MapPinIcon, Pen, UsersIcon } from "lucide-react"; +import { ClockIcon, CornerDownLeftIcon, GlobeIcon, MapPinIcon, Pencil, UsersIcon } from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; import type { ExtraField } from "@/atoms/chat/hitl-edit-panel.atom"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; @@ -332,7 +332,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/google-calendar/update-event.tsx b/surfsense_web/components/tool-ui/google-calendar/update-event.tsx index cd6ec0618..649174245 100644 --- a/surfsense_web/components/tool-ui/google-calendar/update-event.tsx +++ b/surfsense_web/components/tool-ui/google-calendar/update-event.tsx @@ -7,7 +7,7 @@ import { ClockIcon, CornerDownLeftIcon, MapPinIcon, - Pen, + Pencil, UsersIcon, } from "lucide-react"; import { useCallback, useEffect, useState } from "react"; @@ -415,7 +415,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/google-drive/create-file.tsx b/surfsense_web/components/tool-ui/google-drive/create-file.tsx index 638db3db9..b13089877 100644 --- a/surfsense_web/components/tool-ui/google-drive/create-file.tsx +++ b/surfsense_web/components/tool-ui/google-drive/create-file.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { CornerDownLeftIcon, FileIcon, Pen } from "lucide-react"; +import { CornerDownLeftIcon, FileIcon, Pencil } from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; import { PlateEditor } from "@/components/editor/plate-editor"; @@ -240,7 +240,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/jira/create-jira-issue.tsx b/surfsense_web/components/tool-ui/jira/create-jira-issue.tsx index 91041d15e..6916f9fa0 100644 --- a/surfsense_web/components/tool-ui/jira/create-jira-issue.tsx +++ b/surfsense_web/components/tool-ui/jira/create-jira-issue.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { CornerDownLeftIcon, Pen } from "lucide-react"; +import { CornerDownLeftIcon, Pencil } from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; import { PlateEditor } from "@/components/editor/plate-editor"; @@ -257,7 +257,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/jira/update-jira-issue.tsx b/surfsense_web/components/tool-ui/jira/update-jira-issue.tsx index f377563da..72e697532 100644 --- a/surfsense_web/components/tool-ui/jira/update-jira-issue.tsx +++ b/surfsense_web/components/tool-ui/jira/update-jira-issue.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { CornerDownLeftIcon, Pen } from "lucide-react"; +import { CornerDownLeftIcon, Pencil } from "lucide-react"; import { useCallback, useEffect, useState } from "react"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; import { PlateEditor } from "@/components/editor/plate-editor"; @@ -273,7 +273,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/linear/create-linear-issue.tsx b/surfsense_web/components/tool-ui/linear/create-linear-issue.tsx index 8abc7b50b..7d5098c3e 100644 --- a/surfsense_web/components/tool-ui/linear/create-linear-issue.tsx +++ b/surfsense_web/components/tool-ui/linear/create-linear-issue.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { CornerDownLeftIcon, Pen } from "lucide-react"; +import { CornerDownLeftIcon, Pencil } from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; import { PlateEditor } from "@/components/editor/plate-editor"; @@ -269,7 +269,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/linear/update-linear-issue.tsx b/surfsense_web/components/tool-ui/linear/update-linear-issue.tsx index daadfbc63..2d6846cea 100644 --- a/surfsense_web/components/tool-ui/linear/update-linear-issue.tsx +++ b/surfsense_web/components/tool-ui/linear/update-linear-issue.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { CornerDownLeftIcon, Pen } from "lucide-react"; +import { CornerDownLeftIcon, Pencil } from "lucide-react"; import { useCallback, useEffect, useState } from "react"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; import { PlateEditor } from "@/components/editor/plate-editor"; @@ -332,7 +332,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/notion/create-notion-page.tsx b/surfsense_web/components/tool-ui/notion/create-notion-page.tsx index 8c93c7648..b16a1d8cd 100644 --- a/surfsense_web/components/tool-ui/notion/create-notion-page.tsx +++ b/surfsense_web/components/tool-ui/notion/create-notion-page.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { CornerDownLeftIcon, Pen } from "lucide-react"; +import { CornerDownLeftIcon, Pencil } from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; import { PlateEditor } from "@/components/editor/plate-editor"; @@ -219,7 +219,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/notion/update-notion-page.tsx b/surfsense_web/components/tool-ui/notion/update-notion-page.tsx index cf714b1b4..ef75c5d92 100644 --- a/surfsense_web/components/tool-ui/notion/update-notion-page.tsx +++ b/surfsense_web/components/tool-ui/notion/update-notion-page.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { CornerDownLeftIcon, Pen } from "lucide-react"; +import { CornerDownLeftIcon, Pencil } from "lucide-react"; import { useCallback, useEffect, useState } from "react"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; import { PlateEditor } from "@/components/editor/plate-editor"; @@ -196,7 +196,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/tool-ui/onedrive/create-file.tsx b/surfsense_web/components/tool-ui/onedrive/create-file.tsx index 8a64a6cf8..7621f152f 100644 --- a/surfsense_web/components/tool-ui/onedrive/create-file.tsx +++ b/surfsense_web/components/tool-ui/onedrive/create-file.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { CornerDownLeftIcon, FileIcon, Pen } from "lucide-react"; +import { CornerDownLeftIcon, FileIcon, Pencil } from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; import { PlateEditor } from "@/components/editor/plate-editor"; @@ -209,7 +209,7 @@ function ApprovalCard({ }); }} > - + Edit )} diff --git a/surfsense_web/components/ui/mode-toolbar-button.tsx b/surfsense_web/components/ui/mode-toolbar-button.tsx index 37231991f..394eaf97c 100644 --- a/surfsense_web/components/ui/mode-toolbar-button.tsx +++ b/surfsense_web/components/ui/mode-toolbar-button.tsx @@ -1,6 +1,6 @@ "use client"; -import { BookOpenIcon, PenLineIcon } from "lucide-react"; +import { BookOpenIcon, Pencil } from "lucide-react"; import { usePlateState } from "platejs/react"; import { ToolbarButton } from "./toolbar"; @@ -13,7 +13,7 @@ export function ModeToolbarButton() { tooltip={readOnly ? "Click to edit" : "Click to view"} onClick={() => setReadOnly(!readOnly)} > - {readOnly ? : } + {readOnly ? : } ); } From 2618205749ebcbe532561a97868e04164c57bdd8 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 24 Apr 2026 03:52:39 +0530 Subject: [PATCH 140/299] refactor(thread): remove unused filesystem settings and related logic from Composer component --- .../components/assistant-ui/thread.tsx | 361 ------------------ 1 file changed, 361 deletions(-) diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 6fde33061..2ec422fbf 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -12,15 +12,11 @@ import { AlertCircle, ArrowDownIcon, ArrowUpIcon, - Check, ChevronDown, ChevronUp, Clipboard, Dot, - Folder, - FolderPlus, Globe, - Laptop, Plus, Settings2, SquareIcon, @@ -70,16 +66,6 @@ import { } from "@/components/new-chat/document-mention-picker"; import { PromptPicker, type PromptPickerRef } from "@/components/new-chat/prompt-picker"; import { Avatar, AvatarFallback, AvatarGroup } from "@/components/ui/avatar"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, -} from "@/components/ui/alert-dialog"; import { Button } from "@/components/ui/button"; import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer"; import { @@ -108,18 +94,6 @@ import { cn } from "@/lib/utils"; const COMPOSER_PLACEHOLDER = "Ask anything, type / for prompts, type @ to mention docs"; -type ComposerFilesystemSettings = { - mode: "cloud" | "desktop_local_folder"; - localRootPaths: string[]; - updatedAt: string; -}; - -const LOCAL_FILESYSTEM_TRUST_KEY = "surfsense.local-filesystem-trust.v1"; -const MAX_LOCAL_FILESYSTEM_ROOTS = 5; - -const getFolderDisplayName = (rootPath: string): string => - rootPath.split(/[\\/]/).at(-1) || rootPath; - export const Thread: FC = () => { return ; }; @@ -388,12 +362,6 @@ const Composer: FC = () => { }, []); const electronAPI = useElectronAPI(); - const [filesystemSettings, setFilesystemSettings] = useState( - null - ); - const [localTrustDialogOpen, setLocalTrustDialogOpen] = useState(false); - const [localFoldersOpen, setLocalFoldersOpen] = useState(false); - const [pendingLocalPath, setPendingLocalPath] = useState(null); const [clipboardInitialText, setClipboardInitialText] = useState(); const clipboardLoadedRef = useRef(false); useEffect(() => { @@ -406,116 +374,6 @@ const Composer: FC = () => { }); }, [electronAPI]); - useEffect(() => { - if (!electronAPI?.getAgentFilesystemSettings) return; - let mounted = true; - electronAPI - .getAgentFilesystemSettings() - .then((settings: ComposerFilesystemSettings) => { - if (!mounted) return; - setFilesystemSettings(settings); - }) - .catch(() => { - if (!mounted) return; - setFilesystemSettings({ - mode: "cloud", - localRootPaths: [], - updatedAt: new Date().toISOString(), - }); - }); - return () => { - mounted = false; - }; - }, [electronAPI]); - - const hasLocalFilesystemTrust = useCallback(() => { - try { - return window.localStorage.getItem(LOCAL_FILESYSTEM_TRUST_KEY) === "true"; - } catch { - return false; - } - }, []); - - const localRootPaths = filesystemSettings?.localRootPaths ?? []; - const primaryLocalRootPath = localRootPaths[0] ?? null; - const extraLocalRootCount = Math.max(0, localRootPaths.length - 1); - const canAddMoreLocalRoots = localRootPaths.length < MAX_LOCAL_FILESYSTEM_ROOTS; - - const applyLocalRootPath = useCallback( - async (path: string) => { - if (!electronAPI?.setAgentFilesystemSettings) return; - const nextLocalRootPaths = [...localRootPaths, path] - .filter((rootPath, index, allPaths) => allPaths.indexOf(rootPath) === index) - .slice(0, MAX_LOCAL_FILESYSTEM_ROOTS); - if (nextLocalRootPaths.length === localRootPaths.length) { - return; - } - const updated = await electronAPI.setAgentFilesystemSettings({ - mode: "desktop_local_folder", - localRootPaths: nextLocalRootPaths, - }); - setFilesystemSettings(updated); - }, - [electronAPI, localRootPaths] - ); - - const runSwitchToLocalMode = useCallback(async () => { - if (!electronAPI?.setAgentFilesystemSettings) return; - const updated = await electronAPI.setAgentFilesystemSettings({ mode: "desktop_local_folder" }); - setFilesystemSettings(updated); - }, [electronAPI]); - - const runPickLocalRoot = useCallback(async () => { - if (!electronAPI?.pickAgentFilesystemRoot) return; - const picked = await electronAPI.pickAgentFilesystemRoot(); - if (!picked) return; - await applyLocalRootPath(picked); - }, [applyLocalRootPath, electronAPI]); - - const handleFilesystemModeChange = useCallback( - async (mode: "cloud" | "desktop_local_folder") => { - if (!electronAPI?.setAgentFilesystemSettings) return; - if (mode === "desktop_local_folder") return void runSwitchToLocalMode(); - const updated = await electronAPI.setAgentFilesystemSettings({ mode }); - setFilesystemSettings(updated); - }, - [electronAPI, runSwitchToLocalMode] - ); - - const handlePickFilesystemRoot = useCallback(async () => { - if (!canAddMoreLocalRoots) return; - if (hasLocalFilesystemTrust()) { - await runPickLocalRoot(); - return; - } - if (!electronAPI?.pickAgentFilesystemRoot) return; - const picked = await electronAPI.pickAgentFilesystemRoot(); - if (!picked) return; - setPendingLocalPath(picked); - setLocalTrustDialogOpen(true); - }, [canAddMoreLocalRoots, electronAPI, hasLocalFilesystemTrust, runPickLocalRoot]); - - const handleRemoveFilesystemRoot = useCallback( - async (rootPathToRemove: string) => { - if (!electronAPI?.setAgentFilesystemSettings) return; - const updated = await electronAPI.setAgentFilesystemSettings({ - mode: "desktop_local_folder", - localRootPaths: localRootPaths.filter((rootPath) => rootPath !== rootPathToRemove), - }); - setFilesystemSettings(updated); - }, - [electronAPI, localRootPaths] - ); - - const handleClearFilesystemRoots = useCallback(async () => { - if (!electronAPI?.setAgentFilesystemSettings) return; - const updated = await electronAPI.setAgentFilesystemSettings({ - mode: "desktop_local_folder", - localRootPaths: [], - }); - setFilesystemSettings(updated); - }, [electronAPI]); - const isThreadEmpty = useAuiState(({ thread }) => thread.isEmpty); const isThreadRunning = useAuiState(({ thread }) => thread.isRunning); @@ -810,225 +668,6 @@ const Composer: FC = () => { currentUserId={currentUser?.id ?? null} members={members ?? []} /> - {electronAPI && filesystemSettings ? ( -
- - - - - - handleFilesystemModeChange("cloud")} - className="flex items-center justify-between" - > - - - Cloud - - {filesystemSettings.mode === "cloud" && } - - handleFilesystemModeChange("desktop_local_folder")} - className="flex items-center justify-between" - > - - - Local - - {filesystemSettings.mode === "desktop_local_folder" && ( - - )} - - - - - {filesystemSettings.mode === "desktop_local_folder" && ( - <> -
-
- {primaryLocalRootPath ? ( - <> -
- - - {getFolderDisplayName(primaryLocalRootPath)} - - -
- {extraLocalRootCount > 0 && ( - - - - - -
- {localRootPaths.map((rootPath) => ( -
- - - {getFolderDisplayName(rootPath)} - - -
- ))} -
- -
-
-
-
- )} - - - ) : ( - - )} -
- - )} -
- ) : null} - { - setLocalTrustDialogOpen(open); - if (!open) { - setPendingLocalPath(null); - } - }} - > - - - Trust this workspace? - - Local mode can read and edit files inside the folders you select. Continue only if - you trust this workspace and its contents. - - {(pendingLocalPath || primaryLocalRootPath) && ( - - Folder path: {pendingLocalPath || primaryLocalRootPath} - - )} - - - Cancel - { - try { - window.localStorage.setItem(LOCAL_FILESYSTEM_TRUST_KEY, "true"); - } catch {} - setLocalTrustDialogOpen(false); - const path = pendingLocalPath; - setPendingLocalPath(null); - if (path) { - await applyLocalRootPath(path); - } else { - await runPickLocalRoot(); - } - }} - > - I trust this workspace - - - - {showDocumentPopover && (
Date: Fri, 24 Apr 2026 03:55:24 +0530 Subject: [PATCH 141/299] feat(sidebar): implement local filesystem browser and enhance document sidebar with local folder management features --- .../layout/ui/sidebar/DocumentsSidebar.tsx | 466 +++++++++++++++--- .../ui/sidebar/LocalFilesystemBrowser.tsx | 271 ++++++++++ 2 files changed, 675 insertions(+), 62 deletions(-) create mode 100644 surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index daed8747d..5c955a53e 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -6,9 +6,12 @@ import { ChevronLeft, ChevronRight, FileText, + Folder, FolderClock, + Globe, Lock, Paperclip, + Search, Trash2, Unplug, Upload, @@ -59,7 +62,9 @@ import { import { Avatar, AvatarFallback, AvatarGroup } from "@/components/ui/avatar"; import { Button } from "@/components/ui/button"; import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer"; +import { Input } from "@/components/ui/input"; import { Spinner } from "@/components/ui/spinner"; +import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { useAnonymousMode, useIsAnonymous } from "@/contexts/anonymous-mode"; import { useLoginGate } from "@/contexts/login-gate"; @@ -76,9 +81,31 @@ import { BACKEND_URL } from "@/lib/env-config"; import { uploadFolderScan } from "@/lib/folder-sync-upload"; import { getSupportedExtensionsSet } from "@/lib/supported-extensions"; import { queries } from "@/zero/queries/index"; +import { LocalFilesystemBrowser } from "./LocalFilesystemBrowser"; import { SidebarSlideOutPanel } from "./SidebarSlideOutPanel"; const NON_DELETABLE_DOCUMENT_TYPES: readonly string[] = ["SURFSENSE_DOCS"]; +const LOCAL_FILESYSTEM_TRUST_KEY = "surfsense.local-filesystem-trust.v1"; +const MAX_LOCAL_FILESYSTEM_ROOTS = 5; + +type FilesystemSettings = { + mode: "cloud" | "desktop_local_folder"; + localRootPaths: string[]; + updatedAt: string; +}; + +interface WatchedFolderEntry { + path: string; + name: string; + excludePatterns: string[]; + fileExtensions: string[] | null; + rootFolderId: number | null; + searchSpaceId: number; + active: boolean; +} + +const getFolderDisplayName = (rootPath: string): string => + rootPath.split(/[\\/]/).at(-1) || rootPath; const SHOWCASE_CONNECTORS = [ { type: "GOOGLE_DRIVE_CONNECTOR", label: "Google Drive" }, @@ -133,12 +160,119 @@ function AuthenticatedDocumentsSidebar({ const [search, setSearch] = useState(""); const debouncedSearch = useDebouncedValue(search, 250); + const [localSearch, setLocalSearch] = useState(""); + const debouncedLocalSearch = useDebouncedValue(localSearch, 250); + const localSearchInputRef = useRef(null); const [activeTypes, setActiveTypes] = useState([]); + const [filesystemSettings, setFilesystemSettings] = useState(null); + const [localTrustDialogOpen, setLocalTrustDialogOpen] = useState(false); + const [pendingLocalPath, setPendingLocalPath] = useState(null); const [watchedFolderIds, setWatchedFolderIds] = useState>(new Set()); const [folderWatchOpen, setFolderWatchOpen] = useAtom(folderWatchDialogOpenAtom); const [watchInitialFolder, setWatchInitialFolder] = useAtom(folderWatchInitialFolderAtom); const isElectron = typeof window !== "undefined" && !!window.electronAPI; + useEffect(() => { + if (!electronAPI?.getAgentFilesystemSettings) return; + let mounted = true; + electronAPI + .getAgentFilesystemSettings() + .then((settings: FilesystemSettings) => { + if (!mounted) return; + setFilesystemSettings(settings); + }) + .catch(() => { + if (!mounted) return; + setFilesystemSettings({ + mode: "cloud", + localRootPaths: [], + updatedAt: new Date().toISOString(), + }); + }); + return () => { + mounted = false; + }; + }, [electronAPI]); + + const hasLocalFilesystemTrust = useCallback(() => { + try { + return window.localStorage.getItem(LOCAL_FILESYSTEM_TRUST_KEY) === "true"; + } catch { + return false; + } + }, []); + + const localRootPaths = filesystemSettings?.localRootPaths ?? []; + const canAddMoreLocalRoots = localRootPaths.length < MAX_LOCAL_FILESYSTEM_ROOTS; + + const applyLocalRootPath = useCallback( + async (path: string) => { + if (!electronAPI?.setAgentFilesystemSettings) return; + const nextLocalRootPaths = [...localRootPaths, path] + .filter((rootPath, index, allPaths) => allPaths.indexOf(rootPath) === index) + .slice(0, MAX_LOCAL_FILESYSTEM_ROOTS); + if (nextLocalRootPaths.length === localRootPaths.length) return; + const updated = await electronAPI.setAgentFilesystemSettings({ + mode: "desktop_local_folder", + localRootPaths: nextLocalRootPaths, + }); + setFilesystemSettings(updated); + }, + [electronAPI, localRootPaths] + ); + + const runPickLocalRoot = useCallback(async () => { + if (!electronAPI?.pickAgentFilesystemRoot) return; + const picked = await electronAPI.pickAgentFilesystemRoot(); + if (!picked) return; + await applyLocalRootPath(picked); + }, [applyLocalRootPath, electronAPI]); + + const handlePickFilesystemRoot = useCallback(async () => { + if (!canAddMoreLocalRoots) return; + if (hasLocalFilesystemTrust()) { + await runPickLocalRoot(); + return; + } + if (!electronAPI?.pickAgentFilesystemRoot) return; + const picked = await electronAPI.pickAgentFilesystemRoot(); + if (!picked) return; + setPendingLocalPath(picked); + setLocalTrustDialogOpen(true); + }, [canAddMoreLocalRoots, electronAPI, hasLocalFilesystemTrust, runPickLocalRoot]); + + const handleRemoveFilesystemRoot = useCallback( + async (rootPathToRemove: string) => { + if (!electronAPI?.setAgentFilesystemSettings) return; + const updated = await electronAPI.setAgentFilesystemSettings({ + mode: "desktop_local_folder", + localRootPaths: localRootPaths.filter((rootPath) => rootPath !== rootPathToRemove), + }); + setFilesystemSettings(updated); + }, + [electronAPI, localRootPaths] + ); + + const handleClearFilesystemRoots = useCallback(async () => { + if (!electronAPI?.setAgentFilesystemSettings) return; + const updated = await electronAPI.setAgentFilesystemSettings({ + mode: "desktop_local_folder", + localRootPaths: [], + }); + setFilesystemSettings(updated); + }, [electronAPI]); + + const handleFilesystemTabChange = useCallback( + async (tab: "cloud" | "local") => { + if (!electronAPI?.setAgentFilesystemSettings) return; + const updated = await electronAPI.setAgentFilesystemSettings({ + mode: tab === "cloud" ? "cloud" : "desktop_local_folder", + }); + setFilesystemSettings(updated); + }, + [electronAPI] + ); + // AI File Sort state const { data: searchSpaces, refetch: refetchSearchSpaces } = useAtomValue(searchSpacesAtom); const activeSearchSpace = useMemo( @@ -196,7 +330,7 @@ function AuthenticatedDocumentsSidebar({ if (!electronAPI?.getWatchedFolders) return; const api = electronAPI; - const folders = await api.getWatchedFolders(); + const folders = (await api.getWatchedFolders()) as WatchedFolderEntry[]; if (folders.length === 0) { try { @@ -214,9 +348,11 @@ function AuthenticatedDocumentsSidebar({ active: true, }); } - const recovered = await api.getWatchedFolders(); + const recovered = (await api.getWatchedFolders()) as WatchedFolderEntry[]; const ids = new Set( - recovered.filter((f) => f.rootFolderId != null).map((f) => f.rootFolderId as number) + recovered + .filter((f: WatchedFolderEntry) => f.rootFolderId != null) + .map((f: WatchedFolderEntry) => f.rootFolderId as number) ); setWatchedFolderIds(ids); return; @@ -226,7 +362,9 @@ function AuthenticatedDocumentsSidebar({ } const ids = new Set( - folders.filter((f) => f.rootFolderId != null).map((f) => f.rootFolderId as number) + folders + .filter((f: WatchedFolderEntry) => f.rootFolderId != null) + .map((f: WatchedFolderEntry) => f.rootFolderId as number) ); setWatchedFolderIds(ids); }, [searchSpaceId, electronAPI]); @@ -375,8 +513,8 @@ function AuthenticatedDocumentsSidebar({ async (folder: FolderDisplay) => { if (!electronAPI) return; - const watchedFolders = await electronAPI.getWatchedFolders(); - const matched = watchedFolders.find((wf) => wf.rootFolderId === folder.id); + const watchedFolders = (await electronAPI.getWatchedFolders()) as WatchedFolderEntry[]; + const matched = watchedFolders.find((wf: WatchedFolderEntry) => wf.rootFolderId === folder.id); if (!matched) { toast.error("This folder is not being watched"); return; @@ -405,8 +543,8 @@ function AuthenticatedDocumentsSidebar({ async (folder: FolderDisplay) => { if (!electronAPI) return; - const watchedFolders = await electronAPI.getWatchedFolders(); - const matched = watchedFolders.find((wf) => wf.rootFolderId === folder.id); + const watchedFolders = (await electronAPI.getWatchedFolders()) as WatchedFolderEntry[]; + const matched = watchedFolders.find((wf: WatchedFolderEntry) => wf.rootFolderId === folder.id); if (!matched) { toast.error("This folder is not being watched"); return; @@ -438,8 +576,10 @@ function AuthenticatedDocumentsSidebar({ if (!confirm(`Delete folder "${folder.name}" and all its contents?`)) return; try { if (electronAPI) { - const watchedFolders = await electronAPI.getWatchedFolders(); - const matched = watchedFolders.find((wf) => wf.rootFolderId === folder.id); + const watchedFolders = (await electronAPI.getWatchedFolders()) as WatchedFolderEntry[]; + const matched = watchedFolders.find( + (wf: WatchedFolderEntry) => wf.rootFolderId === folder.id + ); if (matched) { await electronAPI.removeWatchedFolder(matched.path); } @@ -836,59 +976,11 @@ function AuthenticatedDocumentsSidebar({ return () => document.removeEventListener("keydown", handleEscape); }, [open, onOpenChange, isMobile, setRightPanelCollapsed]); - const documentsContent = ( - <> -
-
-
- {isMobile && ( - - )} -

{t("title") || "Documents"}

-
-
- {!isMobile && onDockedChange && ( - - - - - - {isDocked ? "Collapse panel" : "Expand panel"} - - - )} - {headerAction} -
-
-
+ const showFilesystemTabs = !isMobile && !!electronAPI && !!filesystemSettings; + const currentFilesystemTab = filesystemSettings?.mode === "desktop_local_folder" ? "local" : "cloud"; + const cloudContent = ( + <> {/* Connected tools strip */}
+ + ); + + const localContent = ( +
+
+ {localRootPaths.length > 0 ? ( + <> + {localRootPaths.map((rootPath) => ( +
+ + {getFolderDisplayName(rootPath)} + +
+ ))} + + + + ) : ( + + )} +
+
+
+
+
+ setLocalSearch(e.target.value)} + placeholder="Search local files" + type="text" + aria-label="Search local files" + /> + {Boolean(localSearch) && ( + + )} +
+
+ { + openEditorPanel({ + kind: "local_file", + localFilePath, + title: localFilePath.split("/").pop() || localFilePath, + searchSpaceId, + }); + }} + /> +
+ ); + + const documentsContent = ( + <> +
+
+
+ {isMobile && ( + + )} +

{t("title") || "Documents"}

+ {showFilesystemTabs && ( + { + void handleFilesystemTabChange(value === "local" ? "local" : "cloud"); + }} + > + + + + Cloud + + + + Local + + + + )} +
+
+ {!isMobile && onDockedChange && ( + + + + + + {isDocked ? "Collapse panel" : "Expand panel"} + + + )} + {headerAction} +
+
+
+ {showFilesystemTabs ? ( + { + void handleFilesystemTabChange(value === "local" ? "local" : "cloud"); + }} + className="flex min-h-0 flex-1 flex-col" + > + + {cloudContent} + + + {localContent} + + + ) : ( + cloudContent + )} {versionDocId !== null && ( )} + { + setLocalTrustDialogOpen(nextOpen); + if (!nextOpen) setPendingLocalPath(null); + }} + > + + + Trust this workspace? + + Local mode can read and edit files inside the folders you select. Continue only if + you trust this workspace and its contents. + + {pendingLocalPath && ( + + Folder path: {pendingLocalPath} + + )} + + + Cancel + { + try { + window.localStorage.setItem(LOCAL_FILESYSTEM_TRUST_KEY, "true"); + } catch {} + setLocalTrustDialogOpen(false); + const path = pendingLocalPath; + setPendingLocalPath(null); + if (path) { + await applyLocalRootPath(path); + } else { + await runPickLocalRoot(); + } + }} + > + I trust this workspace + + + + void; +} + +interface LocalFolderFileEntry { + relativePath: string; + fullPath: string; + size: number; + mtimeMs: number; +} + +type RootLoadState = { + loading: boolean; + error: string | null; + files: LocalFolderFileEntry[]; +}; + +interface LocalFolderNode { + key: string; + name: string; + folders: Map; + files: LocalFolderFileEntry[]; +} + +const getFolderDisplayName = (rootPath: string): string => + rootPath.split(/[\\/]/).at(-1) || rootPath; + +function createFolderNode(key: string, name: string): LocalFolderNode { + return { + key, + name, + folders: new Map(), + files: [], + }; +} + +function getFileName(pathValue: string): string { + return pathValue.split(/[\\/]/).at(-1) || pathValue; +} + +export function LocalFilesystemBrowser({ + rootPaths, + searchSpaceId, + searchQuery, + onOpenFile, +}: LocalFilesystemBrowserProps) { + const electronAPI = useElectronAPI(); + const [rootStateMap, setRootStateMap] = useState>({}); + const [expandedFolderKeys, setExpandedFolderKeys] = useState>(new Set()); + const supportedExtensions = useMemo(() => Array.from(getSupportedExtensionsSet()), []); + + useEffect(() => { + setExpandedFolderKeys((prev) => { + const next = new Set(prev); + for (const rootPath of rootPaths) { + next.add(rootPath); + } + return next; + }); + }, [rootPaths]); + + useEffect(() => { + if (!electronAPI?.listFolderFiles) return; + let cancelled = false; + + for (const rootPath of rootPaths) { + setRootStateMap((prev) => ({ + ...prev, + [rootPath]: { + loading: true, + error: null, + files: prev[rootPath]?.files ?? [], + }, + })); + } + + void Promise.all( + rootPaths.map(async (rootPath) => { + try { + const files = (await electronAPI.listFolderFiles({ + path: rootPath, + name: getFolderDisplayName(rootPath), + excludePatterns: DEFAULT_EXCLUDE_PATTERNS, + fileExtensions: supportedExtensions, + rootFolderId: null, + searchSpaceId, + active: true, + })) as LocalFolderFileEntry[]; + if (cancelled) return; + setRootStateMap((prev) => ({ + ...prev, + [rootPath]: { + loading: false, + error: null, + files, + }, + })); + } catch (error) { + if (cancelled) return; + setRootStateMap((prev) => ({ + ...prev, + [rootPath]: { + loading: false, + error: error instanceof Error ? error.message : "Failed to read folder", + files: [], + }, + })); + } + }) + ); + + return () => { + cancelled = true; + }; + }, [electronAPI, rootPaths, searchSpaceId, supportedExtensions]); + + const treeByRoot = useMemo(() => { + const query = searchQuery?.trim().toLowerCase() ?? ""; + const hasQuery = query.length > 0; + + return rootPaths.map((rootPath) => { + const rootNode = createFolderNode(rootPath, getFolderDisplayName(rootPath)); + const allFiles = rootStateMap[rootPath]?.files ?? []; + const files = hasQuery + ? allFiles.filter((file) => { + const relativePath = file.relativePath.toLowerCase(); + const fileName = getFileName(file.relativePath).toLowerCase(); + return relativePath.includes(query) || fileName.includes(query); + }) + : allFiles; + for (const file of files) { + const parts = file.relativePath.split(/[\\/]/).filter(Boolean); + let cursor = rootNode; + for (let i = 0; i < parts.length - 1; i++) { + const part = parts[i]; + const folderKey = `${cursor.key}/${part}`; + if (!cursor.folders.has(part)) { + cursor.folders.set(part, createFolderNode(folderKey, part)); + } + cursor = cursor.folders.get(part) as LocalFolderNode; + } + cursor.files.push(file); + } + return { rootPath, rootNode, matchCount: files.length, totalCount: allFiles.length }; + }); + }, [rootPaths, rootStateMap, searchQuery]); + + const toggleFolder = useCallback((folderKey: string) => { + setExpandedFolderKeys((prev) => { + const next = new Set(prev); + if (next.has(folderKey)) { + next.delete(folderKey); + } else { + next.add(folderKey); + } + return next; + }); + }, []); + + const renderFolder = useCallback( + (folder: LocalFolderNode, depth: number) => { + const isExpanded = expandedFolderKeys.has(folder.key); + const childFolders = Array.from(folder.folders.values()).sort((a, b) => + a.name.localeCompare(b.name) + ); + const files = [...folder.files].sort((a, b) => a.relativePath.localeCompare(b.relativePath)); + return ( +
+ + {isExpanded && ( + <> + {childFolders.map((childFolder) => renderFolder(childFolder, depth + 1))} + {files.map((file) => ( + + ))} + + )} +
+ ); + }, + [expandedFolderKeys, onOpenFile, toggleFolder] + ); + + if (rootPaths.length === 0) { + return ( +
+

No local folder selected

+

+ Add a local folder above to browse files in desktop mode. +

+
+ ); + } + + return ( +
+ {treeByRoot.map(({ rootPath, rootNode, matchCount, totalCount }) => { + const state = rootStateMap[rootPath]; + if (!state || state.loading) { + return ( +
+ + Loading {getFolderDisplayName(rootPath)}... +
+ ); + } + if (state.error) { + return ( +
+

Failed to load local folder

+

{state.error}

+
+ ); + } + const isEmpty = totalCount === 0; + return ( +
+ {renderFolder(rootNode, 0)} + {isEmpty && ( +
+ No supported files found in this folder. +
+ )} + {!isEmpty && matchCount === 0 && searchQuery && ( +
+ No matching files in this folder. +
+ )} +
+ ); + })} +
+ ); +} From d1c14160e3ac2b4025357fb14571b344e27025fd Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 24 Apr 2026 04:42:24 +0530 Subject: [PATCH 142/299] feat(sidebar): enhance DocumentsSidebar with dropdown menu for local folder management and improve UI interactions --- .../layout/ui/sidebar/DocumentsSidebar.tsx | 150 +++++++++++------- .../ui/sidebar/LocalFilesystemBrowser.tsx | 10 -- 2 files changed, 89 insertions(+), 71 deletions(-) diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index 5c955a53e..dbe2f16e4 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -7,11 +7,13 @@ import { ChevronRight, FileText, Folder, + FolderPlus, FolderClock, - Globe, + Laptop, Lock, Paperclip, Search, + Server, Trash2, Unplug, Upload, @@ -61,8 +63,17 @@ import { } from "@/components/ui/alert-dialog"; import { Avatar, AvatarFallback, AvatarGroup } from "@/components/ui/avatar"; import { Button } from "@/components/ui/button"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer"; import { Input } from "@/components/ui/input"; +import { Separator } from "@/components/ui/separator"; import { Spinner } from "@/components/ui/spinner"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; @@ -1135,76 +1146,93 @@ function AuthenticatedDocumentsSidebar({ ); const localContent = ( -
-
- {localRootPaths.length > 0 ? ( - <> - {localRootPaths.map((rootPath) => ( -
- - {getFolderDisplayName(rootPath)} +
+
+
+ {localRootPaths.length > 0 ? ( + + -
- ))} - - - - ) : ( -
+ )} + + - )} + + +
-
setLocalSearch(e.target.value)} placeholder="Search local files" @@ -1214,14 +1242,14 @@ function AuthenticatedDocumentsSidebar({ {Boolean(localSearch) && ( )}
@@ -1266,21 +1294,21 @@ function AuthenticatedDocumentsSidebar({ void handleFilesystemTabChange(value === "local" ? "local" : "cloud"); }} > - + - + Cloud - + Local diff --git a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx index 544280116..7aebf4695 100644 --- a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx +++ b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx @@ -61,16 +61,6 @@ export function LocalFilesystemBrowser({ const [expandedFolderKeys, setExpandedFolderKeys] = useState>(new Set()); const supportedExtensions = useMemo(() => Array.from(getSupportedExtensionsSet()), []); - useEffect(() => { - setExpandedFolderKeys((prev) => { - const next = new Set(prev); - for (const rootPath of rootPaths) { - next.add(rootPath); - } - return next; - }); - }, [rootPaths]); - useEffect(() => { if (!electronAPI?.listFolderFiles) return; let cancelled = false; From ce71897286c4f4772928f9155f033715c2690732 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 24 Apr 2026 04:54:48 +0530 Subject: [PATCH 143/299] refactor(hotkeys): simplify hotkey display logic and replace icon representation with text in DesktopShortcutsContent and login page --- .../components/DesktopShortcutsContent.tsx | 45 ++++++------------- surfsense_web/app/desktop/login/page.tsx | 45 ++++++------------- 2 files changed, 26 insertions(+), 64 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx index f4981b8f0..6207457c4 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx @@ -1,10 +1,11 @@ "use client"; -import { ArrowBigUp, BrainCog, Command, Option, Rocket, RotateCcw, Zap } from "lucide-react"; +import { BrainCog, Rocket, RotateCcw, Zap } from "lucide-react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { DEFAULT_SHORTCUTS, keyEventToAccelerator } from "@/components/desktop/shortcut-recorder"; import { Button } from "@/components/ui/button"; +import { ShortcutKbd } from "@/components/ui/shortcut-kbd"; import { Spinner } from "@/components/ui/spinner"; import { useElectronAPI } from "@/hooks/use-platform"; @@ -17,24 +18,20 @@ const HOTKEY_ROWS: Array<{ key: ShortcutKey; label: string; icon: React.ElementT { key: "autocomplete", label: "Extreme Assist", icon: BrainCog }, ]; -type ShortcutToken = - | { kind: "text"; value: string } - | { kind: "icon"; value: "command" | "option" | "shift" }; - -function acceleratorToTokens(accel: string, isMac: boolean): ShortcutToken[] { +function acceleratorToKeys(accel: string, isMac: boolean): string[] { if (!accel) return []; return accel.split("+").map((part) => { if (part === "CommandOrControl") { - return isMac ? { kind: "icon", value: "command" as const } : { kind: "text", value: "Ctrl" }; + return isMac ? "⌘" : "Ctrl"; } if (part === "Alt") { - return isMac ? { kind: "icon", value: "option" as const } : { kind: "text", value: "Alt" }; + return isMac ? "⌥" : "Alt"; } if (part === "Shift") { - return isMac ? { kind: "icon", value: "shift" as const } : { kind: "text", value: "Shift" }; + return isMac ? "⇧" : "Shift"; } - if (part === "Space") return { kind: "text", value: "Space" }; - return { kind: "text", value: part.length === 1 ? part.toUpperCase() : part }; + if (part === "Space") return "Space"; + return part.length === 1 ? part.toUpperCase() : part; }); } @@ -58,7 +55,7 @@ function HotkeyRow({ const [recording, setRecording] = useState(false); const inputRef = useRef(null); const isDefault = value === defaultValue; - const displayTokens = useMemo(() => acceleratorToTokens(value, isMac), [value, isMac]); + const displayKeys = useMemo(() => acceleratorToKeys(value, isMac), [value, isMac]); const handleKeyDown = useCallback( (e: React.KeyboardEvent) => { @@ -103,13 +100,14 @@ function HotkeyRow({
diff --git a/surfsense_web/app/desktop/login/page.tsx b/surfsense_web/app/desktop/login/page.tsx index 6d5e2abd4..451143949 100644 --- a/surfsense_web/app/desktop/login/page.tsx +++ b/surfsense_web/app/desktop/login/page.tsx @@ -2,7 +2,7 @@ import { IconBrandGoogleFilled } from "@tabler/icons-react"; import { useAtom } from "jotai"; -import { ArrowBigUp, BrainCog, Command, Eye, EyeOff, Option, Rocket, RotateCcw, Zap } from "lucide-react"; +import { BrainCog, Eye, EyeOff, Rocket, RotateCcw, Zap } from "lucide-react"; import Image from "next/image"; import { useRouter } from "next/navigation"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; @@ -13,6 +13,7 @@ import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { Separator } from "@/components/ui/separator"; +import { ShortcutKbd } from "@/components/ui/shortcut-kbd"; import { Spinner } from "@/components/ui/spinner"; import { useElectronAPI } from "@/hooks/use-platform"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; @@ -23,10 +24,6 @@ const isGoogleAuth = AUTH_TYPE === "GOOGLE"; type ShortcutKey = "generalAssist" | "quickAsk" | "autocomplete"; type ShortcutMap = typeof DEFAULT_SHORTCUTS; -type ShortcutToken = - | { kind: "text"; value: string } - | { kind: "icon"; value: "command" | "option" | "shift" }; - const HOTKEY_ROWS: Array<{ key: ShortcutKey; label: string; description: string; icon: React.ElementType }> = [ { key: "generalAssist", @@ -48,20 +45,20 @@ const HOTKEY_ROWS: Array<{ key: ShortcutKey; label: string; description: string; }, ]; -function acceleratorToTokens(accel: string, isMac: boolean): ShortcutToken[] { +function acceleratorToKeys(accel: string, isMac: boolean): string[] { if (!accel) return []; return accel.split("+").map((part) => { if (part === "CommandOrControl") { - return isMac ? { kind: "icon", value: "command" as const } : { kind: "text", value: "Ctrl" }; + return isMac ? "⌘" : "Ctrl"; } if (part === "Alt") { - return isMac ? { kind: "icon", value: "option" as const } : { kind: "text", value: "Alt" }; + return isMac ? "⌥" : "Alt"; } if (part === "Shift") { - return isMac ? { kind: "icon", value: "shift" as const } : { kind: "text", value: "Shift" }; + return isMac ? "⇧" : "Shift"; } - if (part === "Space") return { kind: "text", value: "Space" }; - return { kind: "text", value: part.length === 1 ? part.toUpperCase() : part }; + if (part === "Space") return "Space"; + return part.length === 1 ? part.toUpperCase() : part; }); } @@ -87,7 +84,7 @@ function HotkeyRow({ const [recording, setRecording] = useState(false); const inputRef = useRef(null); const isDefault = value === defaultValue; - const displayTokens = useMemo(() => acceleratorToTokens(value, isMac), [value, isMac]); + const displayKeys = useMemo(() => acceleratorToKeys(value, isMac), [value, isMac]); const handleKeyDown = useCallback( (e: React.KeyboardEvent) => { @@ -135,36 +132,20 @@ function HotkeyRow({
From a7a758f26edc04be3e3a6ec3a3cde207f8046bef Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 24 Apr 2026 05:03:23 +0530 Subject: [PATCH 144/299] feat(filesystem): add getAgentFilesystemMounts API and integrate with LocalFilesystemBrowser for improved mount management --- surfsense_desktop/src/ipc/channels.ts | 1 + surfsense_desktop/src/ipc/handlers.ts | 5 ++ .../src/modules/agent-filesystem.ts | 7 ++- surfsense_desktop/src/preload.ts | 2 + .../ui/sidebar/LocalFilesystemBrowser.tsx | 61 +++++++++++++++++-- surfsense_web/types/window.d.ts | 6 ++ 6 files changed, 77 insertions(+), 5 deletions(-) diff --git a/surfsense_desktop/src/ipc/channels.ts b/surfsense_desktop/src/ipc/channels.ts index 5cf6e9001..ccd166899 100644 --- a/surfsense_desktop/src/ipc/channels.ts +++ b/surfsense_desktop/src/ipc/channels.ts @@ -55,6 +55,7 @@ export const IPC_CHANNELS = { ANALYTICS_GET_CONTEXT: 'analytics:get-context', // Agent filesystem mode AGENT_FILESYSTEM_GET_SETTINGS: 'agent-filesystem:get-settings', + AGENT_FILESYSTEM_GET_MOUNTS: 'agent-filesystem:get-mounts', AGENT_FILESYSTEM_SET_SETTINGS: 'agent-filesystem:set-settings', AGENT_FILESYSTEM_PICK_ROOT: 'agent-filesystem:pick-root', } as const; diff --git a/surfsense_desktop/src/ipc/handlers.ts b/surfsense_desktop/src/ipc/handlers.ts index 247d171f5..54882f4ee 100644 --- a/surfsense_desktop/src/ipc/handlers.ts +++ b/surfsense_desktop/src/ipc/handlers.ts @@ -39,6 +39,7 @@ import { import { readAgentLocalFileText, writeAgentLocalFileText, + getAgentFilesystemMounts, getAgentFilesystemSettings, pickAgentFilesystemRoot, setAgentFilesystemSettings, @@ -226,6 +227,10 @@ export function registerIpcHandlers(): void { getAgentFilesystemSettings() ); + ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_GET_MOUNTS, () => + getAgentFilesystemMounts() + ); + ipcMain.handle( IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS, (_event, settings: { mode?: 'cloud' | 'desktop_local_folder'; localRootPaths?: string[] | null }) => diff --git a/surfsense_desktop/src/modules/agent-filesystem.ts b/surfsense_desktop/src/modules/agent-filesystem.ts index 2bf0101d6..f00c185f8 100644 --- a/surfsense_desktop/src/modules/agent-filesystem.ts +++ b/surfsense_desktop/src/modules/agent-filesystem.ts @@ -122,7 +122,7 @@ function toVirtualPath(rootPath: string, absolutePath: string): string { return `/${rel.replace(/\\/g, "/")}`; } -type LocalRootMount = { +export type LocalRootMount = { mount: string; rootPath: string; }; @@ -145,6 +145,11 @@ function buildRootMounts(rootPaths: string[]): LocalRootMount[] { return mounts; } +export async function getAgentFilesystemMounts(): Promise { + const rootPaths = await resolveCurrentRootPaths(); + return buildRootMounts(rootPaths); +} + function parseMountedVirtualPath(virtualPath: string): { mount: string; subPath: string; diff --git a/surfsense_desktop/src/preload.ts b/surfsense_desktop/src/preload.ts index f7aaf9633..9c538f691 100644 --- a/surfsense_desktop/src/preload.ts +++ b/surfsense_desktop/src/preload.ts @@ -108,6 +108,8 @@ contextBridge.exposeInMainWorld('electronAPI', { // Agent filesystem mode getAgentFilesystemSettings: () => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS), + getAgentFilesystemMounts: () => + ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_MOUNTS), setAgentFilesystemSettings: (settings: { mode?: "cloud" | "desktop_local_folder"; localRootPaths?: string[] | null; diff --git a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx index 7aebf4695..5b08f2e37 100644 --- a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx +++ b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx @@ -34,6 +34,11 @@ interface LocalFolderNode { files: LocalFolderFileEntry[]; } +type LocalRootMount = { + mount: string; + rootPath: string; +}; + const getFolderDisplayName = (rootPath: string): string => rootPath.split(/[\\/]/).at(-1) || rootPath; @@ -50,6 +55,20 @@ function getFileName(pathValue: string): string { return pathValue.split(/[\\/]/).at(-1) || pathValue; } +function toVirtualPath(relativePath: string): string { + const normalized = relativePath.replace(/\\/g, "/").replace(/^\/+/, ""); + return `/${normalized}`; +} + +function normalizeRootPathForLookup(rootPath: string, isWindows: boolean): string { + const normalized = rootPath.replace(/\\/g, "/").replace(/\/+$/, ""); + return isWindows ? normalized.toLowerCase() : normalized; +} + +function toMountedVirtualPath(mount: string, relativePath: string): string { + return `/${mount}${toVirtualPath(relativePath)}`; +} + export function LocalFilesystemBrowser({ rootPaths, searchSpaceId, @@ -59,7 +78,9 @@ export function LocalFilesystemBrowser({ const electronAPI = useElectronAPI(); const [rootStateMap, setRootStateMap] = useState>({}); const [expandedFolderKeys, setExpandedFolderKeys] = useState>(new Set()); + const [mountByRootKey, setMountByRootKey] = useState>(new Map()); const supportedExtensions = useMemo(() => Array.from(getSupportedExtensionsSet()), []); + const isWindowsPlatform = electronAPI?.versions.platform === "win32"; useEffect(() => { if (!electronAPI?.listFolderFiles) return; @@ -116,6 +137,31 @@ export function LocalFilesystemBrowser({ }; }, [electronAPI, rootPaths, searchSpaceId, supportedExtensions]); + useEffect(() => { + if (!electronAPI?.getAgentFilesystemMounts) { + setMountByRootKey(new Map()); + return; + } + let cancelled = false; + void electronAPI + .getAgentFilesystemMounts() + .then((mounts: LocalRootMount[]) => { + if (cancelled) return; + const next = new Map(); + for (const entry of mounts) { + next.set(normalizeRootPathForLookup(entry.rootPath, isWindowsPlatform), entry.mount); + } + setMountByRootKey(next); + }) + .catch(() => { + if (cancelled) return; + setMountByRootKey(new Map()); + }); + return () => { + cancelled = true; + }; + }, [electronAPI, isWindowsPlatform, rootPaths]); + const treeByRoot = useMemo(() => { const query = searchQuery?.trim().toLowerCase() ?? ""; const hasQuery = query.length > 0; @@ -160,7 +206,7 @@ export function LocalFilesystemBrowser({ }, []); const renderFolder = useCallback( - (folder: LocalFolderNode, depth: number) => { + (folder: LocalFolderNode, depth: number, mount: string) => { const isExpanded = expandedFolderKeys.has(folder.key); const childFolders = Array.from(folder.folders.values()).sort((a, b) => a.name.localeCompare(b.name) @@ -185,12 +231,12 @@ export function LocalFilesystemBrowser({ {isExpanded && ( <> - {childFolders.map((childFolder) => renderFolder(childFolder, depth + 1))} + {childFolders.map((childFolder) => renderFolder(childFolder, depth + 1, mount))} {files.map((file) => ( - - {Math.round(scale * 100)}% - - + )} diff --git a/surfsense_web/components/report-panel/report-panel.tsx b/surfsense_web/components/report-panel/report-panel.tsx index c7a8509ed..ede63d902 100644 --- a/surfsense_web/components/report-panel/report-panel.tsx +++ b/surfsense_web/components/report-panel/report-panel.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtomValue, useSetAtom } from "jotai"; -import { Check, ChevronDownIcon, Copy, Pencil, XIcon } from "lucide-react"; +import { Check, ChevronDownIcon, Copy, Download, Pencil, XIcon } from "lucide-react"; import dynamic from "next/dynamic"; import { useCallback, useEffect, useRef, useState } from "react"; import { toast } from "sonner"; @@ -309,6 +309,7 @@ export function ReportPanelContent({ const isResume = reportContent?.content_type === "typst"; const showReportEditingTier = !isResume; const hasUnsavedChanges = editedMarkdown !== null; + const showDesktopHeader = !!onClose; const handleCancelEditing = useCallback(() => { setEditedMarkdown(null); @@ -316,153 +317,177 @@ export function ReportPanelContent({ setIsEditing(false); }, []); + const exportButton = !isEditing && ( + <> + {isResume ? ( + + ) : ( + + + + + + + + + )} + + ); + + const versionSwitcher = !isEditing && versions.length > 1 && ( + + + + + + {versions.map((v, i) => ( + setActiveReportId(v.id)} + className={v.id === activeReportId ? "bg-accent font-medium" : ""} + > + Version {i + 1} + + ))} + + + ); + + const copyButton = !isEditing && showReportEditingTier && ( + + ); + + const editingActions = showReportEditingTier && + !isReadOnly && + (isEditing ? ( + <> + + + + ) : ( + + )); + return ( <> - {/* Action bar — always visible; buttons are disabled while loading */} -
-
- {/* Export — plain button for resume (typst), dropdown for others */} - {reportContent?.content_type === "typst" ? ( - - ) : ( - - - - - - - - - )} - - {/* Version switcher — only shown when multiple versions exist */} - {versions.length > 1 && ( - - - - - - {versions.map((v, i) => ( - setActiveReportId(v.id)} - className={v.id === activeReportId ? "bg-accent font-medium" : ""} - > - Version {i + 1} - - ))} - - - )} -
- {onClose && ( - - )} -
- - {showReportEditingTier && ( -
-
-

- {reportContent?.title || title} -

-
-
- {!isEditing && ( - )} - {!isReadOnly && - (isEditing ? ( - <> - - - - ) : ( - - ))}
-
+ + {!isResume && ( +
+
+

+ {reportContent?.title || title} +

+
+
+ {versionSwitcher} + {exportButton} + {copyButton} + {editingActions} +
+
+ )} + + ) : ( + !isResume && ( +
+
+

{reportContent?.title || title}

+
+
+ {versionSwitcher} + {exportButton} + {copyButton} + {editingActions} +
+
+ ) )} {/* Report content — skeleton/error/viewer/editor shown only in this area */} @@ -480,6 +505,12 @@ export function ReportPanelContent({ + {versionSwitcher} + {exportButton} + + } /> ) : reportContent.content ? ( isReadOnly ? ( diff --git a/surfsense_web/components/tool-ui/generate-report.tsx b/surfsense_web/components/tool-ui/generate-report.tsx index 32f97b6a4..912028596 100644 --- a/surfsense_web/components/tool-ui/generate-report.tsx +++ b/surfsense_web/components/tool-ui/generate-report.tsx @@ -137,10 +137,9 @@ function ReportCard({ const autoOpenedRef = useRef(false); const [metadata, setMetadata] = useState<{ title: string; - wordCount: number | null; versionLabel: string | null; content: string | null; - }>({ title, wordCount: wordCount ?? null, versionLabel: null, content: null }); + }>({ title, versionLabel: null, content: null }); const [isLoading, setIsLoading] = useState(true); const [error, setError] = useState(null); @@ -169,10 +168,8 @@ function ReportCard({ } } const resolvedTitle = parsed.data.title || title; - const resolvedWordCount = parsed.data.report_metadata?.word_count ?? wordCount ?? null; setMetadata({ title: resolvedTitle, - wordCount: resolvedWordCount, versionLabel, content: parsed.data.content ?? null, }); @@ -182,7 +179,7 @@ function ReportCard({ openPanel({ reportId, title: resolvedTitle, - wordCount: resolvedWordCount ?? undefined, + wordCount: parsed.data.report_metadata?.word_count ?? wordCount ?? undefined, shareToken, }); } @@ -210,7 +207,6 @@ function ReportCard({ openPanel({ reportId, title: metadata.title, - wordCount: metadata.wordCount ?? undefined, shareToken, }); }; @@ -233,10 +229,8 @@ function ReportCard({ ) : ( <> - {metadata.wordCount != null && `${metadata.wordCount.toLocaleString()} words`} - {metadata.wordCount != null && metadata.versionLabel && ( - - )} + Markdown + {metadata.versionLabel && } {metadata.versionLabel} )} diff --git a/surfsense_web/components/tool-ui/generate-resume.tsx b/surfsense_web/components/tool-ui/generate-resume.tsx index 1290a70ea..4e9d06fbb 100644 --- a/surfsense_web/components/tool-ui/generate-resume.tsx +++ b/surfsense_web/components/tool-ui/generate-resume.tsx @@ -2,6 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useAtomValue, useSetAtom } from "jotai"; +import { Dot } from "lucide-react"; import { useParams, usePathname } from "next/navigation"; import * as pdfjsLib from "pdfjs-dist"; import { useCallback, useEffect, useRef, useState } from "react"; @@ -9,6 +10,7 @@ import { z } from "zod"; import { openReportPanelAtom, reportPanelAtom } from "@/atoms/chat/report-panel.atom"; import { TextShimmerLoader } from "@/components/prompt-kit/loader"; import { useMediaQuery } from "@/hooks/use-media-query"; +import { baseApiService } from "@/lib/apis/base-api.service"; import { getAuthHeaders } from "@/lib/auth-utils"; pdfjsLib.GlobalWorkerOptions.workerSrc = new URL( @@ -32,6 +34,18 @@ const GenerateResumeResultSchema = z.object({ error: z.string().nullish(), }); +const ResumeVersionsResponseSchema = z.object({ + id: z.number(), + versions: z + .array( + z.object({ + id: z.number(), + created_at: z.string().nullish(), + }) + ) + .nullish(), +}); + type GenerateResumeArgs = z.infer; type GenerateResumeResult = z.infer; @@ -201,6 +215,7 @@ function ResumeCard({ const autoOpenedRef = useRef(false); const [pdfUrl, setPdfUrl] = useState(null); const [thumbState, setThumbState] = useState<"loading" | "ready" | "error">("loading"); + const [versionLabel, setVersionLabel] = useState(null); useEffect(() => { const previewPath = shareToken @@ -219,6 +234,35 @@ function ResumeCard({ } }, [reportId, title, shareToken, autoOpen, isDesktop, openPanel]); + useEffect(() => { + let cancelled = false; + const fetchVersions = async () => { + try { + const url = shareToken + ? `/api/v1/public/${shareToken}/reports/${reportId}/content` + : `/api/v1/reports/${reportId}/content`; + const rawData = await baseApiService.get(url); + if (cancelled) return; + const parsed = ResumeVersionsResponseSchema.safeParse(rawData); + if (parsed.success) { + const versions = parsed.data.versions; + if (versions && versions.length > 1) { + const idx = versions.findIndex((v) => v.id === reportId); + if (idx >= 0) { + setVersionLabel(`version ${idx + 1}`); + } + } + } + } catch { + // silently ignore — version label is non-critical + } + }; + fetchVersions(); + return () => { + cancelled = true; + }; + }, [reportId, shareToken]); + const onThumbLoad = useCallback(() => setThumbState("ready"), []); const onThumbError = useCallback(() => setThumbState("error"), []); @@ -243,8 +287,12 @@ function ResumeCard({ className="w-full text-left transition-colors hover:bg-muted/50 focus:outline-none focus-visible:outline-none cursor-pointer select-none" >
-

{title}

-

PDF

+

{title}

+

+ PDF + {versionLabel && } + {versionLabel} +

From 3f97b77ab64976118cb6c8881178d1c6d6baddb1 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 24 Apr 2026 19:17:43 +0200 Subject: [PATCH 155/299] Support multimodal chat with pending screen images on web --- .../[search_space_id]/client-layout.tsx | 10 ++ .../new-chat/[[...chat_id]]/page.tsx | 73 ++++++----- .../atoms/chat/pending-user-images.atom.ts | 3 + .../components/assistant-ui/thread.tsx | 58 ++++++++- .../lib/chat/display-media-capture.ts | 120 ++++++++++++++++++ surfsense_web/lib/chat/user-turn-api-parts.ts | 57 +++++++++ 6 files changed, 285 insertions(+), 36 deletions(-) create mode 100644 surfsense_web/atoms/chat/pending-user-images.atom.ts create mode 100644 surfsense_web/lib/chat/display-media-capture.ts create mode 100644 surfsense_web/lib/chat/user-turn-api-parts.ts diff --git a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx index eceb46231..d95aab6e8 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx @@ -6,6 +6,7 @@ import { useTranslations } from "next-intl"; import type React from "react"; import { useCallback, useEffect, useRef, useState } from "react"; import { toast } from "sonner"; +import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; import { myAccessAtom } from "@/atoms/members/members-query.atoms"; import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; import { @@ -33,6 +34,7 @@ export function DashboardClientLayout({ const pathname = usePathname(); const { search_space_id } = useParams(); const setActiveSearchSpaceIdState = useSetAtom(activeSearchSpaceIdAtom); + const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom); const { data: preferences = {}, @@ -142,6 +144,14 @@ export function DashboardClientLayout({ const electronAPI = useElectronAPI(); + useEffect(() => { + if (!electronAPI?.onChatScreenCapture) return; + return electronAPI.onChatScreenCapture((dataUrl: string) => { + if (typeof dataUrl !== "string" || !dataUrl.startsWith("data:image/")) return; + setPendingUserImageUrls((prev) => [...prev, dataUrl]); + }); + }, [electronAPI, setPendingUserImageUrls]); + useEffect(() => { const activeSeacrhSpaceId = typeof search_space_id === "string" diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 62332d2c4..fe23cb2c7 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -26,6 +26,7 @@ import { messageDocumentsMapAtom, sidebarSelectedDocumentsAtom, } from "@/atoms/chat/mentioned-documents.atom"; +import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; import { clearPlanOwnerRegistry, // extractWriteTodosFromContent, @@ -45,8 +46,8 @@ import { } from "@/components/assistant-ui/token-usage-context"; import { useChatSessionStateSync } from "@/hooks/use-chat-session-state"; import { useMessagesSync } from "@/hooks/use-messages-sync"; -import { documentsApiService } from "@/lib/apis/documents-api.service"; import { getAgentFilesystemSelection } from "@/lib/agent-filesystem"; +import { documentsApiService } from "@/lib/apis/documents-api.service"; import { getBearerToken } from "@/lib/auth-utils"; import { convertToThreadMessage } from "@/lib/chat/message-utils"; import { @@ -76,6 +77,7 @@ import { type ThreadListResponse, type ThreadRecord, } from "@/lib/chat/thread-persistence"; +import { extractUserTurnForNewChatApi } from "@/lib/chat/user-turn-api-parts"; import { NotFoundError } from "@/lib/error"; import { trackChatCreated, @@ -231,6 +233,8 @@ export default function NewChatPage() { const updateChatTabTitle = useSetAtom(updateChatTabTitleAtom); const removeChatTab = useSetAtom(removeChatTabAtom); const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom); + const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom); + const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom); // Get current user for author info in shared chats const { data: currentUser } = useAtomValue(currentUserAtom); @@ -494,18 +498,13 @@ export default function NewChatPage() { abortControllerRef.current = null; } - // Extract user query text from content parts - let userQuery = ""; - for (const part of message.content) { - if (part.type === "text") { - userQuery += part.text; - } - } + const urlsSnapshot = [...pendingUserImageUrls]; + setPendingUserImageUrls([]); + const { userQuery, userImages } = extractUserTurnForNewChatApi(message, urlsSnapshot); - if (!userQuery.trim()) return; + if (!userQuery.trim() && userImages.length === 0) return; - // Check if podcast is already generating - if (isPodcastGenerating() && looksLikePodcastRequest(userQuery)) { + if (userQuery.trim() && isPodcastGenerating() && looksLikePodcastRequest(userQuery)) { toast.warning("A podcast is already being generated."); return; } @@ -560,10 +559,27 @@ export default function NewChatPage() { } : undefined; + const existingImageUrls = new Set( + message.content + .filter( + (p): p is { type: "image"; image: string } => + typeof p === "object" && + p !== null && + "type" in p && + p.type === "image" && + "image" in p + ) + .map((p) => p.image) + ); + const extraImageParts = urlsSnapshot + .filter((u) => !existingImageUrls.has(u)) + .map((image) => ({ type: "image" as const, image })); + const userDisplayContent = [...message.content, ...extraImageParts]; + const userMessage: ThreadMessageLike = { id: userMsgId, role: "user", - content: message.content, + content: userDisplayContent, createdAt: new Date(), metadata: authorMetadata, }; @@ -571,7 +587,7 @@ export default function NewChatPage() { // Track message sent trackChatMessageSent(searchSpaceId, currentThreadId, { - hasAttachments: false, + hasAttachments: userImages.length > 0, hasMentionedDocuments: mentionedDocumentIds.surfsense_doc_ids.length > 0 || mentionedDocumentIds.document_ids.length > 0, @@ -596,7 +612,7 @@ export default function NewChatPage() { })); } - const persistContent: unknown[] = [...message.content]; + const persistContent: unknown[] = [...userDisplayContent]; if (allMentionedDocs.length > 0) { persistContent.push({ @@ -661,8 +677,7 @@ export default function NewChatPage() { const selection = await getAgentFilesystemSelection(); if ( selection.filesystem_mode === "desktop_local_folder" && - (!selection.local_filesystem_mounts || - selection.local_filesystem_mounts.length === 0) + (!selection.local_filesystem_mounts || selection.local_filesystem_mounts.length === 0) ) { toast.error("Select a local folder before using Local Folder mode."); return; @@ -711,6 +726,7 @@ export default function NewChatPage() { ? mentionedDocumentIds.surfsense_doc_ids : undefined, disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, + ...(userImages.length > 0 ? { user_images: userImages } : {}), }), signal: controller.signal, }); @@ -842,14 +858,7 @@ export default function NewChatPage() { }); } else { const tcId = `interrupt-${action.name}`; - addToolCall( - contentPartsState, - toolsWithUI, - tcId, - action.name, - action.args, - true - ); + addToolCall(contentPartsState, toolsWithUI, tcId, action.name, action.args, true); updateToolCall(contentPartsState, tcId, { result: { __interrupt__: true, ...interruptData }, }); @@ -989,6 +998,9 @@ export default function NewChatPage() { disabledTools, updateChatTabTitle, tokenUsageStore, + pendingUserImageUrls, + setPendingUserImageUrls, + toolsWithUI, ] ); @@ -1189,14 +1201,7 @@ export default function NewChatPage() { }); } else { const tcId = `interrupt-${action.name}`; - addToolCall( - contentPartsState, - toolsWithUI, - tcId, - action.name, - action.args, - true - ); + addToolCall(contentPartsState, toolsWithUI, tcId, action.name, action.args, true); updateToolCall(contentPartsState, tcId, { result: { __interrupt__: true, @@ -1261,7 +1266,7 @@ export default function NewChatPage() { abortControllerRef.current = null; } }, - [pendingInterrupt, messages, searchSpaceId, tokenUsageStore] + [pendingInterrupt, messages, searchSpaceId, tokenUsageStore, toolsWithUI] ); useEffect(() => { @@ -1588,7 +1593,7 @@ export default function NewChatPage() { abortControllerRef.current = null; } }, - [threadId, searchSpaceId, messages, disabledTools, tokenUsageStore] + [threadId, searchSpaceId, messages, disabledTools, tokenUsageStore, toolsWithUI] ); // Handle editing a message - truncates history and regenerates with new query diff --git a/surfsense_web/atoms/chat/pending-user-images.atom.ts b/surfsense_web/atoms/chat/pending-user-images.atom.ts new file mode 100644 index 000000000..6898e745d --- /dev/null +++ b/surfsense_web/atoms/chat/pending-user-images.atom.ts @@ -0,0 +1,3 @@ +import { atom } from "jotai"; + +export const pendingUserImageDataUrlsAtom = atom([]); diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 2ec422fbf..6862662f2 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -16,6 +16,7 @@ import { ChevronUp, Clipboard, Dot, + Camera, Globe, Plus, Settings2, @@ -40,6 +41,7 @@ import { mentionedDocumentsAtom, sidebarSelectedDocumentsAtom, } from "@/atoms/chat/mentioned-documents.atom"; +import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms"; import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; import { documentsSidebarOpenAtom } from "@/atoms/documents/ui.atoms"; @@ -89,6 +91,7 @@ import { useBatchCommentsPreload } from "@/hooks/use-comments"; import { useCommentsSync } from "@/hooks/use-comments-sync"; import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI } from "@/hooks/use-platform"; +import { captureDisplayToPngDataUrl } from "@/lib/chat/display-media-capture"; import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/lib/layout-events"; import { cn } from "@/lib/utils"; @@ -295,6 +298,32 @@ const ConnectToolsBanner: FC<{ isThreadEmpty: boolean }> = ({ isThreadEmpty }) = ); }; +const PendingScreenImageStrip: FC = () => { + const [urls, setUrls] = useAtom(pendingUserImageDataUrlsAtom); + if (urls.length === 0) return null; + return ( +
+ {urls.map((url, index) => ( +
+ {/* biome-ignore lint/performance/noImgElement: data URL thumbnails from capture */} + + +
+ ))} +
+ ); +}; + const ClipboardChip: FC<{ text: string; onDismiss: () => void }> = ({ text, onDismiss }) => { const [expanded, setExpanded] = useState(false); const isLong = text.length > 120; @@ -702,6 +731,7 @@ const Composer: FC = () => {
)}
+ {clipboardInitialText && ( = ({ isBlockedByOtherUser = false }, [] ); + const pendingScreenImages = useAtomValue(pendingUserImageDataUrlsAtom); + const setPendingScreenImages = useSetAtom(pendingUserImageDataUrlsAtom); + const electronAPI = useElectronAPI(); + const isComposerTextEmpty = useAuiState(({ composer }) => { const text = composer.text?.trim() || ""; return text.length === 0; }); - const isComposerEmpty = isComposerTextEmpty && mentionedDocuments.length === 0; + const isComposerEmpty = + isComposerTextEmpty && mentionedDocuments.length === 0 && pendingScreenImages.length === 0; + + const handleScreenCapture = useCallback(async () => { + const url = await captureDisplayToPngDataUrl(); + if (url) setPendingScreenImages((prev) => [...prev, url]); + }, [setPendingScreenImages]); const { data: userConfigs } = useAtomValue(newLLMConfigsAtom); const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom); @@ -1201,6 +1241,20 @@ const ComposerAction: FC = ({ isBlockedByOtherUser = false
)}
+ {/* Electron: native shortcut → pending images; skip in-webview getDisplayMedia. */} + {!electronAPI && ( + void handleScreenCapture()} + > + + + )} !thread.isRunning}> = ({ isBlockedByOtherUser = false : !hasModelConfigured ? "Please select a model from the header to start chatting" : isComposerEmpty - ? "Enter a message to send" + ? "Enter a message or add a screenshot to send" : "Send message" } side="bottom" diff --git a/surfsense_web/lib/chat/display-media-capture.ts b/surfsense_web/lib/chat/display-media-capture.ts new file mode 100644 index 000000000..c2fb69aae --- /dev/null +++ b/surfsense_web/lib/chat/display-media-capture.ts @@ -0,0 +1,120 @@ +/** `getDisplayMedia` → single PNG frame (data URL). */ +function getImageCaptureCtor(): + | (new ( + track: MediaStreamTrack + ) => { grabFrame: () => Promise }) + | undefined { + if (typeof window === "undefined") return undefined; + const IC = ( + window as unknown as { + ImageCapture?: new (track: MediaStreamTrack) => { grabFrame: () => Promise }; + } + ).ImageCapture; + return typeof IC === "function" ? IC : undefined; +} + +function stopAllTracks(stream: MediaStream): void { + for (const t of stream.getTracks()) { + t.stop(); + } +} + +async function captureTrackToPngDataUrl( + track: MediaStreamTrack, + stream: MediaStream +): Promise { + const ImageCtor = getImageCaptureCtor(); + if (ImageCtor !== undefined) { + try { + const ic = new ImageCtor(track); + const bitmap = await ic.grabFrame(); + try { + const canvas = document.createElement("canvas"); + canvas.width = bitmap.width; + canvas.height = bitmap.height; + const ctx = canvas.getContext("2d"); + if (!ctx) { + stopAllTracks(stream); + return null; + } + ctx.drawImage(bitmap, 0, 0); + stopAllTracks(stream); + return canvas.toDataURL("image/png"); + } finally { + if ("close" in bitmap && typeof bitmap.close === "function") { + bitmap.close(); + } + } + } catch { + /* fall through to
); } diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx index 6207457c4..0b7f330d9 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx @@ -1,6 +1,6 @@ "use client"; -import { BrainCog, Rocket, RotateCcw, Zap } from "lucide-react"; +import { Rocket, RotateCcw, Zap } from "lucide-react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { DEFAULT_SHORTCUTS, keyEventToAccelerator } from "@/components/desktop/shortcut-recorder"; @@ -9,13 +9,12 @@ import { ShortcutKbd } from "@/components/ui/shortcut-kbd"; import { Spinner } from "@/components/ui/spinner"; import { useElectronAPI } from "@/hooks/use-platform"; -type ShortcutKey = "generalAssist" | "quickAsk" | "autocomplete"; +type ShortcutKey = "generalAssist" | "quickAsk"; type ShortcutMap = typeof DEFAULT_SHORTCUTS; const HOTKEY_ROWS: Array<{ key: ShortcutKey; label: string; icon: React.ElementType }> = [ { key: "generalAssist", label: "General Assist", icon: Rocket }, { key: "quickAsk", label: "Quick Assist", icon: Zap }, - { key: "autocomplete", label: "Extreme Assist", icon: BrainCog }, ]; function acceleratorToKeys(accel: string, isMac: boolean): string[] { @@ -111,9 +110,7 @@ function HotkeyRow({ } > {recording ? ( - - Press hotkeys... - + Press hotkeys... ) : ( )} @@ -155,15 +152,14 @@ export function DesktopShortcutsContent() { if (!api) { return (
-

Hotkeys are only available in the SurfSense desktop app.

+

+ Hotkeys are only available in the SurfSense desktop app. +

); } - const updateShortcut = ( - key: "generalAssist" | "quickAsk" | "autocomplete", - accelerator: string - ) => { + const updateShortcut = (key: ShortcutKey, accelerator: string) => { setShortcuts((prev) => { const updated = { ...prev, [key]: accelerator }; api.setShortcuts?.({ [key]: accelerator }).catch(() => { @@ -178,28 +174,26 @@ export function DesktopShortcutsContent() { updateShortcut(key, DEFAULT_SHORTCUTS[key]); }; - return ( - shortcutsLoaded ? ( -
-
- {HOTKEY_ROWS.map((row) => ( - updateShortcut(row.key, accel)} - onReset={() => resetShortcut(row.key)} - /> - ))} -
+ return shortcutsLoaded ? ( +
+
+ {HOTKEY_ROWS.map((row) => ( + updateShortcut(row.key, accel)} + onReset={() => resetShortcut(row.key)} + /> + ))}
- ) : ( -
- -
- ) +
+ ) : ( +
+ +
); } diff --git a/surfsense_web/app/desktop/login/page.tsx b/surfsense_web/app/desktop/login/page.tsx index 451143949..edb6cffab 100644 --- a/surfsense_web/app/desktop/login/page.tsx +++ b/surfsense_web/app/desktop/login/page.tsx @@ -2,7 +2,7 @@ import { IconBrandGoogleFilled } from "@tabler/icons-react"; import { useAtom } from "jotai"; -import { BrainCog, Eye, EyeOff, Rocket, RotateCcw, Zap } from "lucide-react"; +import { Eye, EyeOff, Rocket, RotateCcw, Zap } from "lucide-react"; import Image from "next/image"; import { useRouter } from "next/navigation"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; @@ -21,10 +21,15 @@ import { setBearerToken } from "@/lib/auth-utils"; import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config"; const isGoogleAuth = AUTH_TYPE === "GOOGLE"; -type ShortcutKey = "generalAssist" | "quickAsk" | "autocomplete"; +type ShortcutKey = "generalAssist" | "quickAsk"; type ShortcutMap = typeof DEFAULT_SHORTCUTS; -const HOTKEY_ROWS: Array<{ key: ShortcutKey; label: string; description: string; icon: React.ElementType }> = [ +const HOTKEY_ROWS: Array<{ + key: ShortcutKey; + label: string; + description: string; + icon: React.ElementType; +}> = [ { key: "generalAssist", label: "General Assist", @@ -37,12 +42,6 @@ const HOTKEY_ROWS: Array<{ key: ShortcutKey; label: string; description: string; description: "Select text anywhere, then ask AI to explain, rewrite, or act on it", icon: Zap, }, - { - key: "autocomplete", - label: "Extreme Assist", - description: "AI drafts text using your screen context and knowledge base", - icon: BrainCog, - }, ]; function acceleratorToKeys(accel: string, isMac: boolean): string[] { @@ -182,7 +181,7 @@ export default function DesktopLoginPage() { }, [api]); const updateShortcut = useCallback( - (key: "generalAssist" | "quickAsk" | "autocomplete", accelerator: string) => { + (key: ShortcutKey, accelerator: string) => { setShortcuts((prev) => { const updated = { ...prev, [key]: accelerator }; api?.setShortcuts?.({ [key]: accelerator }).catch(() => { @@ -196,7 +195,7 @@ export default function DesktopLoginPage() { ); const resetShortcut = useCallback( - (key: "generalAssist" | "quickAsk" | "autocomplete") => { + (key: ShortcutKey) => { updateShortcut(key, DEFAULT_SHORTCUTS[key]); }, [updateShortcut] @@ -369,7 +368,9 @@ export default function DesktopLoginPage() { )} diff --git a/surfsense_web/app/desktop/suggestion/layout.tsx b/surfsense_web/app/desktop/suggestion/layout.tsx deleted file mode 100644 index fd8faf099..000000000 --- a/surfsense_web/app/desktop/suggestion/layout.tsx +++ /dev/null @@ -1,9 +0,0 @@ -import "./suggestion.css"; - -export const metadata = { - title: "SurfSense Suggestion", -}; - -export default function SuggestionLayout({ children }: { children: React.ReactNode }) { - return
{children}
; -} diff --git a/surfsense_web/app/desktop/suggestion/page.tsx b/surfsense_web/app/desktop/suggestion/page.tsx deleted file mode 100644 index d30da65f6..000000000 --- a/surfsense_web/app/desktop/suggestion/page.tsx +++ /dev/null @@ -1,384 +0,0 @@ -"use client"; - -import { useCallback, useEffect, useRef, useState } from "react"; -import { useElectronAPI } from "@/hooks/use-platform"; -import { ensureTokensFromElectron, getBearerToken } from "@/lib/auth-utils"; - -type SSEEvent = - | { type: "text-delta"; id: string; delta: string } - | { type: "text-start"; id: string } - | { type: "text-end"; id: string } - | { type: "start"; messageId: string } - | { type: "finish" } - | { type: "error"; errorText: string } - | { - type: "data-thinking-step"; - data: { id: string; title: string; status: string; items: string[] }; - } - | { - type: "data-suggestions"; - data: { options: string[] }; - }; - -interface AgentStep { - id: string; - title: string; - status: string; - items: string[]; -} - -type FriendlyError = { message: string; isSetup?: boolean }; - -function friendlyError(raw: string | number): FriendlyError { - if (typeof raw === "number") { - if (raw === 401) return { message: "Please sign in to use suggestions." }; - if (raw === 403) return { message: "You don\u2019t have permission for this." }; - if (raw === 404) return { message: "Suggestion service not found. Is the backend running?" }; - if (raw >= 500) return { message: "Something went wrong on the server. Try again." }; - return { message: "Something went wrong. Try again." }; - } - const lower = raw.toLowerCase(); - if (lower.includes("not authenticated") || lower.includes("unauthorized")) - return { message: "Please sign in to use suggestions." }; - if (lower.includes("no vision llm configured") || lower.includes("no llm configured")) - return { - message: "Configure a vision-capable model (e.g. GPT-4o, Gemini) to enable autocomplete.", - isSetup: true, - }; - if (lower.includes("does not support vision")) - return { - message: "The selected model doesn\u2019t support vision. Choose a vision-capable model.", - isSetup: true, - }; - if (lower.includes("fetch") || lower.includes("network") || lower.includes("econnrefused")) - return { message: "Can\u2019t reach the server. Check your connection." }; - return { message: "Something went wrong. Try again." }; -} - -const AUTO_DISMISS_MS = 3000; - -function StepIcon({ status }: { status: string }) { - if (status === "complete") { - return ( - - - - - ); - } - return ; -} - -export default function SuggestionPage() { - const api = useElectronAPI(); - const [options, setOptions] = useState([]); - const [isLoading, setIsLoading] = useState(true); - const [error, setError] = useState(null); - const [steps, setSteps] = useState([]); - const [expandedOption, setExpandedOption] = useState(null); - const abortRef = useRef(null); - - const isDesktop = !!api?.onAutocompleteContext; - - useEffect(() => { - if (!api?.onAutocompleteContext) { - setIsLoading(false); - } - }, [api]); - - useEffect(() => { - if (!error || error.isSetup) return; - const timer = setTimeout(() => { - api?.dismissSuggestion?.(); - }, AUTO_DISMISS_MS); - return () => clearTimeout(timer); - }, [error, api]); - - useEffect(() => { - if (isLoading || error || options.length > 0) return; - const timer = setTimeout(() => { - api?.dismissSuggestion?.(); - }, AUTO_DISMISS_MS); - return () => clearTimeout(timer); - }, [isLoading, error, options, api]); - - const fetchSuggestion = useCallback( - async (screenshot: string, searchSpaceId: string, appName?: string, windowTitle?: string) => { - abortRef.current?.abort(); - const controller = new AbortController(); - abortRef.current = controller; - - setIsLoading(true); - setOptions([]); - setError(null); - setSteps([]); - setExpandedOption(null); - - let token = getBearerToken(); - if (!token) { - await ensureTokensFromElectron(); - token = getBearerToken(); - } - if (!token) { - setError(friendlyError("not authenticated")); - setIsLoading(false); - return; - } - - const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - - try { - const response = await fetch(`${backendUrl}/api/v1/autocomplete/vision/stream`, { - method: "POST", - headers: { - Authorization: `Bearer ${token}`, - "Content-Type": "application/json", - }, - body: JSON.stringify({ - screenshot, - search_space_id: parseInt(searchSpaceId, 10), - app_name: appName || "", - window_title: windowTitle || "", - }), - signal: controller.signal, - }); - - if (!response.ok) { - setError(friendlyError(response.status)); - setIsLoading(false); - return; - } - - if (!response.body) { - setError(friendlyError("network error")); - setIsLoading(false); - return; - } - - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let buffer = ""; - - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - buffer += decoder.decode(value, { stream: true }); - const events = buffer.split(/\r?\n\r?\n/); - buffer = events.pop() || ""; - - for (const event of events) { - const lines = event.split(/\r?\n/); - for (const line of lines) { - if (!line.startsWith("data: ")) continue; - const data = line.slice(6).trim(); - if (!data || data === "[DONE]") continue; - - try { - const parsed: SSEEvent = JSON.parse(data); - if (parsed.type === "data-suggestions") { - setOptions(parsed.data.options); - } else if (parsed.type === "error") { - setError(friendlyError(parsed.errorText)); - } else if (parsed.type === "data-thinking-step") { - const { id, title, status, items } = parsed.data; - setSteps((prev) => { - const existing = prev.findIndex((s) => s.id === id); - if (existing >= 0) { - const updated = [...prev]; - updated[existing] = { id, title, status, items }; - return updated; - } - return [...prev, { id, title, status, items }]; - }); - } - } catch {} - } - } - } - } catch (err) { - if (err instanceof DOMException && err.name === "AbortError") return; - setError(friendlyError("network error")); - } finally { - setIsLoading(false); - } - }, - [] - ); - - useEffect(() => { - if (!api?.onAutocompleteContext) return; - - const cleanup = api.onAutocompleteContext((data) => { - const searchSpaceId = data.searchSpaceId || "1"; - if (data.screenshot) { - fetchSuggestion(data.screenshot, searchSpaceId, data.appName, data.windowTitle); - } - }); - - return cleanup; - }, [fetchSuggestion, api]); - - if (!isDesktop) { - return ( -
- - This page is only available in the SurfSense desktop app. - -
- ); - } - - if (error) { - if (error.isSetup) { - return ( -
-
- -
-
- Vision Model Required - {error.message} - Settings → Vision Models -
- -
- ); - } - return ( -
- {error.message} -
- ); - } - - const showLoading = isLoading && options.length === 0; - - if (showLoading) { - return ( -
-
- {steps.length === 0 && ( -
- - Preparing… -
- )} - {steps.length > 0 && ( -
- {steps.map((step) => ( -
- - - {step.title} - {step.items.length > 0 && ( - · {step.items[0]} - )} - -
- ))} -
- )} -
-
- ); - } - - const handleSelect = (text: string) => { - api?.acceptSuggestion?.(text); - }; - - const handleDismiss = () => { - api?.dismissSuggestion?.(); - }; - - const TRUNCATE_LENGTH = 120; - - if (options.length === 0) { - return ( -
- No suggestions available. -
- ); - } - - return ( -
-
- {options.map((option, index) => { - const isExpanded = expandedOption === index; - const needsTruncation = option.length > TRUNCATE_LENGTH; - const displayText = - needsTruncation && !isExpanded ? option.slice(0, TRUNCATE_LENGTH) + "…" : option; - - return ( - - )} - - ); - })} -
-
- -
-
- ); -} diff --git a/surfsense_web/app/desktop/suggestion/suggestion.css b/surfsense_web/app/desktop/suggestion/suggestion.css deleted file mode 100644 index b27fe7874..000000000 --- a/surfsense_web/app/desktop/suggestion/suggestion.css +++ /dev/null @@ -1,352 +0,0 @@ -html:has(.suggestion-body), -body:has(.suggestion-body) { - margin: 0 !important; - padding: 0 !important; - background: transparent !important; - overflow: hidden !important; - height: auto !important; - width: 100% !important; -} - -.suggestion-body { - margin: 0; - padding: 0; - background: transparent; - font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; - -webkit-font-smoothing: antialiased; - user-select: none; - -webkit-app-region: no-drag; -} - -.suggestion-tooltip { - box-sizing: border-box; - background: #1e1e1e; - border: 1px solid #3c3c3c; - border-radius: 8px; - padding: 8px 12px; - margin: 4px; - max-width: 400px; - /* MAX_HEIGHT in suggestion-window.ts is 400px. Subtract 8px for margin - (4px * 2) so the tooltip + margin fits within the Electron window. - box-sizing: border-box ensures padding + border are included. */ - max-height: 392px; - box-shadow: 0 4px 16px rgba(0, 0, 0, 0.5); - display: flex; - flex-direction: column; - overflow: hidden; -} - -.suggestion-text { - color: #d4d4d4; - font-size: 13px; - line-height: 1.45; - margin: 0 0 6px 0; - word-wrap: break-word; - white-space: pre-wrap; - overflow-y: auto; - flex: 1 1 auto; - min-height: 0; -} - -.suggestion-text::-webkit-scrollbar { - width: 5px; -} - -.suggestion-text::-webkit-scrollbar-track { - background: transparent; -} - -.suggestion-text::-webkit-scrollbar-thumb { - background: #555; - border-radius: 3px; -} - -.suggestion-text::-webkit-scrollbar-thumb:hover { - background: #777; -} - -.suggestion-actions { - display: flex; - justify-content: flex-end; - gap: 4px; - border-top: 1px solid #2a2a2a; - padding-top: 6px; - flex-shrink: 0; -} - -.suggestion-btn { - padding: 2px 8px; - border-radius: 3px; - border: 1px solid #3c3c3c; - font-family: inherit; - font-size: 10px; - font-weight: 500; - cursor: pointer; - line-height: 16px; - transition: - background 0.15s, - border-color 0.15s; -} - -.suggestion-btn-accept { - background: #2563eb; - border-color: #3b82f6; - color: #fff; -} - -.suggestion-btn-accept:hover { - background: #1d4ed8; -} - -.suggestion-btn-dismiss { - background: #2a2a2a; - color: #999; -} - -.suggestion-btn-dismiss:hover { - background: #333; - color: #ccc; -} - -.suggestion-error { - border-color: #5c2626; -} - -.suggestion-error-text { - color: #f48771; - font-size: 12px; -} - -/* --- Setup prompt (vision model not configured) --- */ - -.suggestion-setup { - display: flex; - flex-direction: row; - align-items: flex-start; - gap: 10px; - border-color: #3b2d6b; - padding: 10px 14px; -} - -.setup-icon { - flex-shrink: 0; - margin-top: 1px; -} - -.setup-content { - display: flex; - flex-direction: column; - gap: 3px; - min-width: 0; -} - -.setup-title { - font-size: 13px; - font-weight: 600; - color: #c4b5fd; -} - -.setup-message { - font-size: 11.5px; - color: #a1a1aa; - line-height: 1.4; -} - -.setup-hint { - font-size: 10.5px; - color: #7c6dac; - margin-top: 2px; -} - -.setup-dismiss { - flex-shrink: 0; - align-self: flex-start; - background: none; - border: none; - color: #6b6b7b; - font-size: 14px; - cursor: pointer; - padding: 2px 4px; - line-height: 1; - border-radius: 4px; - transition: - color 0.15s, - background 0.15s; -} - -.setup-dismiss:hover { - color: #c4b5fd; - background: rgba(124, 109, 172, 0.15); -} - -/* --- Agent activity indicator --- */ - -.agent-activity { - display: flex; - flex-direction: column; - gap: 4px; - overflow-y: auto; - max-height: 340px; -} - -.agent-activity::-webkit-scrollbar { - display: none; -} - -.activity-initial { - display: flex; - align-items: center; - gap: 8px; - padding: 2px 0; -} - -.activity-label { - color: #a1a1aa; - font-size: 12px; - white-space: nowrap; - overflow: hidden; - text-overflow: ellipsis; -} - -.activity-steps { - display: flex; - flex-direction: column; - gap: 3px; -} - -.activity-step { - display: flex; - align-items: center; - gap: 6px; - min-height: 18px; -} - -.step-label { - color: #d4d4d4; - font-size: 12px; - white-space: nowrap; - overflow: hidden; - text-overflow: ellipsis; -} - -.step-detail { - color: #71717a; - font-size: 11px; -} - -/* Spinner (in_progress) */ -.step-spinner { - width: 14px; - height: 14px; - flex-shrink: 0; - border: 1.5px solid #3f3f46; - border-top-color: #a78bfa; - border-radius: 50%; - animation: step-spin 0.7s linear infinite; -} - -/* Checkmark icon (complete) */ -.step-icon { - width: 14px; - height: 14px; - flex-shrink: 0; -} - -@keyframes step-spin { - to { - transform: rotate(360deg); - } -} - -/* --- Suggestion option cards --- */ - -.suggestion-options { - display: flex; - flex-direction: column; - gap: 4px; - overflow-y: auto; - flex: 1 1 auto; - min-height: 0; - margin-bottom: 6px; -} - -.suggestion-options::-webkit-scrollbar { - width: 5px; -} - -.suggestion-options::-webkit-scrollbar-track { - background: transparent; -} - -.suggestion-options::-webkit-scrollbar-thumb { - background: #555; - border-radius: 3px; -} - -.suggestion-option { - display: flex; - align-items: flex-start; - gap: 8px; - padding: 6px 8px; - border-radius: 5px; - border: 1px solid #333; - background: #262626; - cursor: pointer; - text-align: left; - font-family: inherit; - transition: - background 0.15s, - border-color 0.15s; - width: 100%; -} - -.suggestion-option:hover { - background: #2a2d3a; - border-color: #3b82f6; -} - -.option-number { - flex-shrink: 0; - width: 18px; - height: 18px; - border-radius: 50%; - background: #3f3f46; - color: #d4d4d4; - font-size: 10px; - font-weight: 600; - display: flex; - align-items: center; - justify-content: center; - margin-top: 1px; -} - -.suggestion-option:hover .option-number { - background: #2563eb; - color: #fff; -} - -.option-text { - color: #d4d4d4; - font-size: 12px; - line-height: 1.45; - word-wrap: break-word; - white-space: pre-wrap; - flex: 1 1 auto; - min-width: 0; -} - -.option-expand { - flex-shrink: 0; - background: none; - border: none; - color: #71717a; - font-size: 10px; - cursor: pointer; - padding: 0 2px; - font-family: inherit; - margin-top: 1px; -} - -.option-expand:hover { - color: #a1a1aa; -} diff --git a/surfsense_web/components/desktop/shortcut-recorder.tsx b/surfsense_web/components/desktop/shortcut-recorder.tsx index c872afaf1..50ced5313 100644 --- a/surfsense_web/components/desktop/shortcut-recorder.tsx +++ b/surfsense_web/components/desktop/shortcut-recorder.tsx @@ -36,9 +36,8 @@ export function acceleratorToDisplay(accel: string): string[] { } export const DEFAULT_SHORTCUTS = { - generalAssist: "CommandOrControl+Shift+S", - quickAsk: "CommandOrControl+Alt+S", - autocomplete: "CommandOrControl+Shift+Space", + generalAssist: "Alt+Shift+G", + quickAsk: "Alt+Shift+Q", }; // --------------------------------------------------------------------------- diff --git a/surfsense_web/types/window.d.ts b/surfsense_web/types/window.d.ts index e9f29a8f3..a8f02fd20 100644 --- a/surfsense_web/types/window.d.ts +++ b/surfsense_web/types/window.d.ts @@ -71,6 +71,7 @@ interface ElectronAPI { openExternal: (url: string) => void; getAppVersion: () => Promise; onDeepLink: (callback: (url: string) => void) => () => void; + onChatScreenCapture: (callback: (dataUrl: string) => void) => () => void; getQuickAskText: () => Promise; setQuickAskMode: (mode: string) => Promise; getQuickAskMode: () => Promise; @@ -83,19 +84,6 @@ interface ElectronAPI { requestAccessibility: () => Promise; requestScreenRecording: () => Promise; restartApp: () => Promise; - // Autocomplete - onAutocompleteContext: ( - callback: (data: { - screenshot: string; - searchSpaceId?: string; - appName?: string; - windowTitle?: string; - }) => void - ) => () => void; - acceptSuggestion: (text: string) => Promise; - dismissSuggestion: () => Promise; - setAutocompleteEnabled: (enabled: boolean) => Promise; - getAutocompleteEnabled: () => Promise; // Folder sync selectFolder: () => Promise; addWatchedFolder: (config: WatchedFolderConfig) => Promise; @@ -115,18 +103,15 @@ interface ElectronAPI { browseFiles: () => Promise; readLocalFiles: (paths: string[]) => Promise; readAgentLocalFileText: (virtualPath: string) => Promise; - writeAgentLocalFileText: ( - virtualPath: string, - content: string - ) => Promise; + writeAgentLocalFileText: (virtualPath: string, content: string) => Promise; // Auth token sync across windows getAuthTokens: () => Promise<{ bearer: string; refresh: string } | null>; setAuthTokens: (bearer: string, refresh: string) => Promise; // Keyboard shortcut configuration - getShortcuts: () => Promise<{ generalAssist: string; quickAsk: string; autocomplete: string }>; + getShortcuts: () => Promise<{ generalAssist: string; quickAsk: string }>; setShortcuts: ( - config: Partial<{ generalAssist: string; quickAsk: string; autocomplete: string }> - ) => Promise<{ generalAssist: string; quickAsk: string; autocomplete: string }>; + config: Partial<{ generalAssist: string; quickAsk: string }> + ) => Promise<{ generalAssist: string; quickAsk: string }>; // Launch on system startup getAutoLaunch: () => Promise<{ enabled: boolean; From ed0bcafe49a946bea1cc6bb02933780f111efbee Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 24 Apr 2026 19:19:04 +0200 Subject: [PATCH 157/299] Align connectors, editors, and layout with desktop context --- .../assistant-ui/connector-popup.tsx | 50 ++++--- .../components/mcp-connect-form.tsx | 14 +- .../components/mcp-config.tsx | 14 +- .../components/teams-config.tsx | 6 +- .../views/connector-edit-view.tsx | 6 +- .../views/indexing-configuration-view.tsx | 5 +- .../tabs/active-connectors-tab.tsx | 6 +- .../views/connector-accounts-list-view.tsx | 129 +++++++++-------- .../components/assistant-ui/markdown-text.tsx | 2 +- .../components/editor-panel/editor-panel.tsx | 134 +++++++++--------- .../editor/plugins/fixed-toolbar-kit.tsx | 3 +- .../components/editor/source-code-editor.tsx | 2 +- .../components/homepage/hero-section.tsx | 4 +- .../layout/ui/right-panel/RightPanel.tsx | 8 +- .../layout/ui/sidebar/DocumentsSidebar.tsx | 19 ++- .../ui/sidebar/LocalFilesystemBrowser.tsx | 10 +- .../settings/user-settings-dialog.tsx | 17 ++- .../tool-ui/generic-hitl-approval.tsx | 4 +- .../tool-ui/google-calendar/create-event.tsx | 9 +- surfsense_web/contracts/enums/toolIcons.tsx | 2 +- 20 files changed, 243 insertions(+), 201 deletions(-) diff --git a/surfsense_web/components/assistant-ui/connector-popup.tsx b/surfsense_web/components/assistant-ui/connector-popup.tsx index 66333a9ef..32943142a 100644 --- a/surfsense_web/components/assistant-ui/connector-popup.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup.tsx @@ -123,9 +123,9 @@ export const ConnectorIndicator = forwardRef ) : viewingMCPList ? ( - handleDisconnectFromList(connector, () => refreshConnectors())} - onAddAccount={handleAddNewMCPFromList} - addButtonText="Add New MCP Server" - /> + + handleDisconnectFromList(connector, () => refreshConnectors()) + } + onAddAccount={handleAddNewMCPFromList} + addButtonText="Add New MCP Server" + /> ) : viewingAccountsType ? ( - handleDisconnectFromList(connector, () => refreshConnectors())} - onAddAccount={() => { + + handleDisconnectFromList(connector, () => refreshConnectors()) + } + onAddAccount={() => { // Check both OAUTH_CONNECTORS and COMPOSIO_CONNECTORS const oauthConnector = OAUTH_CONNECTORS.find( diff --git a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx index fc9812240..d9a740af2 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx @@ -213,13 +213,13 @@ export const MCPConnectForm: FC = ({ onSubmit, isSubmitting }) className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80" > {isTesting ? ( - <> - - Testing Connection... - - ) : ( - "Test Connection" - )} + <> + + Testing Connection... + + ) : ( + "Test Connection" + )}
diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx index d6f60e824..97b5de675 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx @@ -218,13 +218,13 @@ export const MCPConfig: FC = ({ connector, onConfigChange, onNam className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80" > {isTesting ? ( - <> - - Testing Connection... - - ) : ( - "Test Connection" - )} + <> + + Testing Connection... + + ) : ( + "Test Connection" + )}
diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx index e96ddfd29..06ce21dae 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx @@ -18,9 +18,9 @@ export const TeamsConfig: FC = () => {

Microsoft Teams Access

- Your agent can search and read messages from Teams channels you have access to, - and send messages on your behalf. Make sure you're a member of the teams - you want to interact with. + Your agent can search and read messages from Teams channels you have access to, and send + messages on your behalf. Make sure you're a member of the teams you want to interact + with.

diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx index 44461c351..48f42c3b4 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx @@ -15,7 +15,7 @@ import { DateRangeSelector } from "../../components/date-range-selector"; import { PeriodicSyncConfig } from "../../components/periodic-sync-config"; import { SummaryConfig } from "../../components/summary-config"; import { VisionLLMConfig } from "../../components/vision-llm-config"; -import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../../constants/connector-constants"; +import { getReauthEndpoint, LIVE_CONNECTOR_TYPES } from "../../constants/connector-constants"; import { getConnectorDisplayName } from "../../tabs/all-connectors-tab"; import { MCPServiceConfig } from "../components/mcp-service-config"; import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index"; @@ -367,8 +367,8 @@ export const ConnectorEditView: FC = ({ {/* Fixed Footer - Action buttons */}
- {showDisconnectConfirm ? ( -
+ {showDisconnectConfirm ? ( +
{isLive ? "Your agent will lose access to this service." diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/indexing-configuration-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/indexing-configuration-view.tsx index c65367e65..e8dffb3c3 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/indexing-configuration-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/indexing-configuration-view.tsx @@ -11,7 +11,10 @@ import { DateRangeSelector } from "../../components/date-range-selector"; import { PeriodicSyncConfig } from "../../components/periodic-sync-config"; import { SummaryConfig } from "../../components/summary-config"; import { VisionLLMConfig } from "../../components/vision-llm-config"; -import { LIVE_CONNECTOR_TYPES, type IndexingConfigState } from "../../constants/connector-constants"; +import { + type IndexingConfigState, + LIVE_CONNECTOR_TYPES, +} from "../../constants/connector-constants"; import { getConnectorDisplayName } from "../../tabs/all-connectors-tab"; import { getConnectorConfigComponent } from "../index"; diff --git a/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx b/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx index fe9aab14f..755086ba5 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx @@ -9,7 +9,11 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { getDocumentTypeLabel } from "@/lib/documents/document-type-labels"; import { cn } from "@/lib/utils"; -import { COMPOSIO_CONNECTORS, LIVE_CONNECTOR_TYPES, OAUTH_CONNECTORS } from "../constants/connector-constants"; +import { + COMPOSIO_CONNECTORS, + LIVE_CONNECTOR_TYPES, + OAUTH_CONNECTORS, +} from "../constants/connector-constants"; import { getDocumentCountForConnector } from "../utils/connector-document-mapping"; import { getConnectorDisplayName } from "./all-connectors-tab"; diff --git a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx index b3c087599..8aee7e005 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx @@ -13,7 +13,7 @@ import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { authenticatedFetch } from "@/lib/auth-utils"; import { formatRelativeDate } from "@/lib/format-date"; import { cn } from "@/lib/utils"; -import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../constants/connector-constants"; +import { getReauthEndpoint, LIVE_CONNECTOR_TYPES } from "../constants/connector-constants"; import { useConnectorStatus } from "../hooks/use-connector-status"; import { getConnectorDisplayName } from "../tabs/all-connectors-tab"; @@ -182,11 +182,14 @@ export const ConnectorAccountsListView: FC = ({
) : (
- {typeConnectors.map((connector) => { - const isIndexing = indexingConnectorIds.has(connector.id); - const connectorReauthEndpoint = getReauthEndpoint(connector); - const isAuthExpired = !!connectorReauthEndpoint && connector.config?.auth_expired === true; - const isLive = LIVE_CONNECTOR_TYPES.has(connector.connector_type) || Boolean(connector.config?.server_config); + {typeConnectors.map((connector) => { + const isIndexing = indexingConnectorIds.has(connector.id); + const connectorReauthEndpoint = getReauthEndpoint(connector); + const isAuthExpired = + !!connectorReauthEndpoint && connector.config?.auth_expired === true; + const isLive = + LIVE_CONNECTOR_TYPES.has(connector.connector_type) || + Boolean(connector.config?.server_config); return (
= ({

) : null}
- {isAuthExpired ? ( - - ) : isLive && onDisconnect ? ( - confirmDisconnectId === connector.id ? ( -
+ {isAuthExpired ? ( + + ) : isLive && onDisconnect ? ( + confirmDisconnectId === connector.id ? ( +
+ + +
+ ) : ( - -
+ ) ) : ( - ) - ) : ( - - )} + )}
); })} diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index a15ff1cd7..140ddcae7 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -20,7 +20,6 @@ import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/image"; import "katex/dist/katex.min.css"; import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation"; -import { useElectronAPI } from "@/hooks/use-platform"; import { Skeleton } from "@/components/ui/skeleton"; import { Table, @@ -30,6 +29,7 @@ import { TableHeader, TableRow, } from "@/components/ui/table"; +import { useElectronAPI } from "@/hooks/use-platform"; import { cn } from "@/lib/utils"; function MarkdownCodeBlockSkeleton() { diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 1f1b41c3e..49cf99229 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -226,67 +226,68 @@ export function EditorPanelContent({ } }, [editorDoc?.source_markdown]); - const handleSave = useCallback(async (options?: { silent?: boolean }) => { - setSaving(true); - try { - if (isLocalFileMode) { - if (!localFilePath) { - throw new Error("Missing local file path"); + const handleSave = useCallback( + async (options?: { silent?: boolean }) => { + setSaving(true); + try { + if (isLocalFileMode) { + if (!localFilePath) { + throw new Error("Missing local file path"); + } + if (!electronAPI?.writeAgentLocalFileText) { + throw new Error("Local file editor is available only in desktop mode."); + } + const contentToSave = markdownRef.current; + const writeResult = await electronAPI.writeAgentLocalFileText( + localFilePath, + contentToSave + ); + if (!writeResult.ok) { + throw new Error(writeResult.error || "Failed to save local file"); + } + setEditorDoc((prev) => (prev ? { ...prev, source_markdown: contentToSave } : prev)); + setEditedMarkdown(markdownRef.current === contentToSave ? null : markdownRef.current); + return true; } - if (!electronAPI?.writeAgentLocalFileText) { - throw new Error("Local file editor is available only in desktop mode."); + if (!searchSpaceId || !documentId) { + throw new Error("Missing document context"); } - const contentToSave = markdownRef.current; - const writeResult = await electronAPI.writeAgentLocalFileText( - localFilePath, - contentToSave + const token = getBearerToken(); + if (!token) { + toast.error("Please login to save"); + redirectToLogin(); + return; + } + const response = await authenticatedFetch( + `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/save`, + { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ source_markdown: markdownRef.current }), + } ); - if (!writeResult.ok) { - throw new Error(writeResult.error || "Failed to save local file"); + + if (!response.ok) { + const errorData = await response + .json() + .catch(() => ({ detail: "Failed to save document" })); + throw new Error(errorData.detail || "Failed to save document"); } - setEditorDoc((prev) => - prev ? { ...prev, source_markdown: contentToSave } : prev - ); - setEditedMarkdown(markdownRef.current === contentToSave ? null : markdownRef.current); + + setEditorDoc((prev) => (prev ? { ...prev, source_markdown: markdownRef.current } : prev)); + setEditedMarkdown(null); + toast.success("Document saved! Reindexing in background..."); return true; + } catch (err) { + console.error("Error saving document:", err); + toast.error(err instanceof Error ? err.message : "Failed to save document"); + return false; + } finally { + setSaving(false); } - if (!searchSpaceId || !documentId) { - throw new Error("Missing document context"); - } - const token = getBearerToken(); - if (!token) { - toast.error("Please login to save"); - redirectToLogin(); - return; - } - const response = await authenticatedFetch( - `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/save`, - { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ source_markdown: markdownRef.current }), - } - ); - - if (!response.ok) { - const errorData = await response - .json() - .catch(() => ({ detail: "Failed to save document" })); - throw new Error(errorData.detail || "Failed to save document"); - } - - setEditorDoc((prev) => (prev ? { ...prev, source_markdown: markdownRef.current } : prev)); - setEditedMarkdown(null); - toast.success("Document saved! Reindexing in background..."); - return true; - } catch (err) { - console.error("Error saving document:", err); - toast.error(err instanceof Error ? err.message : "Failed to save document"); - return false; - } finally { - setSaving(false); - } - }, [documentId, electronAPI, isLocalFileMode, localFilePath, searchSpaceId]); + }, + [documentId, electronAPI, isLocalFileMode, localFilePath, searchSpaceId] + ); const isEditableType = editorDoc ? (editorRenderMode === "source_code" || @@ -383,9 +384,15 @@ export function EditorPanelContent({ )} )} - {!showEditingActions && !isLocalFileMode && editorDoc?.document_type && documentId && ( - - )} + {!showEditingActions && + !isLocalFileMode && + editorDoc?.document_type && + documentId && ( + + )}
@@ -533,11 +540,7 @@ export function EditorPanelContent({ } }} > - {downloading ? ( - - ) : ( - - )} + {downloading ? : } {downloading ? "Preparing..." : "Download .md"} @@ -564,7 +567,7 @@ export function EditorPanelContent({
) : isEditableType ? ( ; } diff --git a/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx b/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx index bdda0263d..346fe0378 100644 --- a/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx +++ b/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx @@ -1,7 +1,6 @@ "use client"; -import { createPlatePlugin } from "platejs/react"; -import { useEditorReadOnly } from "platejs/react"; +import { createPlatePlugin, useEditorReadOnly } from "platejs/react"; import { useEditorSave } from "@/components/editor/editor-save-context"; import { FixedToolbar } from "@/components/ui/fixed-toolbar"; diff --git a/surfsense_web/components/editor/source-code-editor.tsx b/surfsense_web/components/editor/source-code-editor.tsx index 5cab8e5b1..9a763db27 100644 --- a/surfsense_web/components/editor/source-code-editor.tsx +++ b/surfsense_web/components/editor/source-code-editor.tsx @@ -1,8 +1,8 @@ "use client"; import dynamic from "next/dynamic"; -import { useEffect, useRef } from "react"; import { useTheme } from "next-themes"; +import { useEffect, useRef } from "react"; import { Spinner } from "@/components/ui/spinner"; const MonacoEditor = dynamic(() => import("@monaco-editor/react"), { diff --git a/surfsense_web/components/homepage/hero-section.tsx b/surfsense_web/components/homepage/hero-section.tsx index ce0074042..a29d02882 100644 --- a/surfsense_web/components/homepage/hero-section.tsx +++ b/surfsense_web/components/homepage/hero-section.tsx @@ -63,9 +63,9 @@ const TAB_ITEMS = [ featured: true, }, { - title: "Extreme Assist", + title: "Screen capture in chat", description: - "Get inline writing suggestions powered by your knowledge base as you type in any app.", + "Capture your screen and send it with your message so the AI can see what you see.", src: "/homepage/hero_tutorial/extreme_assist.mp4", featured: true, }, diff --git a/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx b/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx index c26cc9b23..04bae010c 100644 --- a/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx +++ b/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx @@ -72,9 +72,7 @@ export function RightPanelExpandButton() { const reportOpen = reportState.isOpen && !!reportState.reportId; const editorOpen = editorState.isOpen && - (editorState.kind === "document" - ? !!editorState.documentId - : !!editorState.localFilePath); + (editorState.kind === "document" ? !!editorState.documentId : !!editorState.localFilePath); const hitlEditOpen = hitlEditState.isOpen && !!hitlEditState.onSave; const hasContent = documentsOpen || reportOpen || editorOpen || hitlEditOpen; @@ -116,9 +114,7 @@ export function RightPanel({ documentsPanel }: RightPanelProps) { const reportOpen = reportState.isOpen && !!reportState.reportId; const editorOpen = editorState.isOpen && - (editorState.kind === "document" - ? !!editorState.documentId - : !!editorState.localFilePath); + (editorState.kind === "document" ? !!editorState.documentId : !!editorState.localFilePath); const hitlEditOpen = hitlEditState.isOpen && !!hitlEditState.onSave; useEffect(() => { diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index 5819dcef4..e88478259 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -7,8 +7,8 @@ import { ChevronRight, FileText, Folder, - FolderPlus, FolderClock, + FolderPlus, Laptop, Lock, Paperclip, @@ -63,6 +63,7 @@ import { } from "@/components/ui/alert-dialog"; import { Avatar, AvatarFallback, AvatarGroup } from "@/components/ui/avatar"; import { Button } from "@/components/ui/button"; +import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer"; import { DropdownMenu, DropdownMenuContent, @@ -71,7 +72,6 @@ import { DropdownMenuSeparator, DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; -import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer"; import { Input } from "@/components/ui/input"; import { Separator } from "@/components/ui/separator"; import { Spinner } from "@/components/ui/spinner"; @@ -525,7 +525,9 @@ function AuthenticatedDocumentsSidebar({ if (!electronAPI) return; const watchedFolders = (await electronAPI.getWatchedFolders()) as WatchedFolderEntry[]; - const matched = watchedFolders.find((wf: WatchedFolderEntry) => wf.rootFolderId === folder.id); + const matched = watchedFolders.find( + (wf: WatchedFolderEntry) => wf.rootFolderId === folder.id + ); if (!matched) { toast.error("This folder is not being watched"); return; @@ -555,7 +557,9 @@ function AuthenticatedDocumentsSidebar({ if (!electronAPI) return; const watchedFolders = (await electronAPI.getWatchedFolders()) as WatchedFolderEntry[]; - const matched = watchedFolders.find((wf: WatchedFolderEntry) => wf.rootFolderId === folder.id); + const matched = watchedFolders.find( + (wf: WatchedFolderEntry) => wf.rootFolderId === folder.id + ); if (!matched) { toast.error("This folder is not being watched"); return; @@ -988,7 +992,8 @@ function AuthenticatedDocumentsSidebar({ }, [open, onOpenChange, isMobile, setRightPanelCollapsed]); const showFilesystemTabs = !isMobile && !!electronAPI && !!filesystemSettings; - const currentFilesystemTab = filesystemSettings?.mode === "desktop_local_folder" ? "local" : "cloud"; + const currentFilesystemTab = + filesystemSettings?.mode === "desktop_local_folder" ? "local" : "cloud"; const cloudContent = ( <> @@ -1401,8 +1406,8 @@ function AuthenticatedDocumentsSidebar({ Trust this workspace? - Local mode can read and edit files inside the folders you select. Continue only if - you trust this workspace and its contents. + Local mode can read and edit files inside the folders you select. Continue only if you + trust this workspace and its contents. {pendingLocalPath && ( diff --git a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx index 5b08f2e37..93227054b 100644 --- a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx +++ b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx @@ -273,7 +273,10 @@ export function LocalFilesystemBrowser({ const mount = mountByRootKey.get(rootKey); if (!state || state.loading) { return ( -
+
Loading {getFolderDisplayName(rootPath)}...
@@ -281,7 +284,10 @@ export function LocalFilesystemBrowser({ } if (state.error) { return ( -
+

Failed to load local folder

{state.error}

diff --git a/surfsense_web/components/settings/user-settings-dialog.tsx b/surfsense_web/components/settings/user-settings-dialog.tsx index cc36392ae..6740aad92 100644 --- a/surfsense_web/components/settings/user-settings-dialog.tsx +++ b/surfsense_web/components/settings/user-settings-dialog.tsx @@ -1,7 +1,16 @@ "use client"; import { useAtom } from "jotai"; -import { Brain, CircleUser, Globe, Keyboard, KeyRound, Monitor, ReceiptText, Sparkles } from "lucide-react"; +import { + Brain, + CircleUser, + Globe, + Keyboard, + KeyRound, + Monitor, + ReceiptText, + Sparkles, +} from "lucide-react"; import dynamic from "next/dynamic"; import { useTranslations } from "next-intl"; import { useMemo } from "react"; @@ -53,9 +62,9 @@ const DesktopContent = dynamic( ); const DesktopShortcutsContent = dynamic( () => - import("@/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent").then( - (m) => ({ default: m.DesktopShortcutsContent }) - ), + import( + "@/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent" + ).then((m) => ({ default: m.DesktopShortcutsContent })), { ssr: false } ); const MemoryContent = dynamic( diff --git a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx index c83bf55d5..ceb1d0209 100644 --- a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx +++ b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx @@ -118,7 +118,9 @@ function GenericApprovalCard({ setProcessing(); onDecision({ type: "approve" }); connectorsApiService.trustMCPTool(mcpConnectorId, toolName).catch(() => { - toast.error("Failed to save 'Always Allow' preference. The tool will still require approval next time."); + toast.error( + "Failed to save 'Always Allow' preference. The tool will still require approval next time." + ); }); }, [phase, setProcessing, onDecision, isMCPTool, mcpConnectorId, toolName]); diff --git a/surfsense_web/components/tool-ui/google-calendar/create-event.tsx b/surfsense_web/components/tool-ui/google-calendar/create-event.tsx index 9427c989b..523be31f6 100644 --- a/surfsense_web/components/tool-ui/google-calendar/create-event.tsx +++ b/surfsense_web/components/tool-ui/google-calendar/create-event.tsx @@ -2,7 +2,14 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useSetAtom } from "jotai"; -import { ClockIcon, CornerDownLeftIcon, GlobeIcon, MapPinIcon, Pencil, UsersIcon } from "lucide-react"; +import { + ClockIcon, + CornerDownLeftIcon, + GlobeIcon, + MapPinIcon, + Pencil, + UsersIcon, +} from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; import type { ExtraField } from "@/atoms/chat/hitl-edit-panel.atom"; import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; diff --git a/surfsense_web/contracts/enums/toolIcons.tsx b/surfsense_web/contracts/enums/toolIcons.tsx index 3bc639d33..bc63bc1b0 100644 --- a/surfsense_web/contracts/enums/toolIcons.tsx +++ b/surfsense_web/contracts/enums/toolIcons.tsx @@ -1,8 +1,8 @@ import { BookOpen, Brain, - FileUser, FileText, + FileUser, Film, Globe, ImageIcon, From 39f56a86d60a09406fb066f67f664a3d2569d2b9 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 24 Apr 2026 22:56:04 +0530 Subject: [PATCH 158/299] refactor(DocumentsFilters): update input styling and button positioning for improved UI consistency --- surfsense_web/components/documents/DocumentsFilters.tsx | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/surfsense_web/components/documents/DocumentsFilters.tsx b/surfsense_web/components/documents/DocumentsFilters.tsx index a2ebe83b9..57e6479cb 100644 --- a/surfsense_web/components/documents/DocumentsFilters.tsx +++ b/surfsense_web/components/documents/DocumentsFilters.tsx @@ -226,13 +226,13 @@ export function DocumentsFilters({ {/* Search Input */}
-
+
onSearch(e.target.value)} placeholder="Search docs" @@ -242,7 +242,7 @@ export function DocumentsFilters({ {Boolean(searchValue) && ( )} diff --git a/surfsense_web/components/chat-comments/comment-sheet/comment-sheet.tsx b/surfsense_web/components/chat-comments/comment-sheet/comment-sheet.tsx index d483ab261..8db45f764 100644 --- a/surfsense_web/components/chat-comments/comment-sheet/comment-sheet.tsx +++ b/surfsense_web/components/chat-comments/comment-sheet/comment-sheet.tsx @@ -1,6 +1,6 @@ "use client"; -import { MessageSquare } from "lucide-react"; +import { MessageCircleReply } from "lucide-react"; import { Drawer, DrawerContent, @@ -30,7 +30,7 @@ export function CommentSheet({ - + Comments {commentCount > 0 && ( @@ -56,7 +56,7 @@ export function CommentSheet({ > - + Comments {commentCount > 0 && ( diff --git a/surfsense_web/components/chat-comments/comment-thread/comment-thread.tsx b/surfsense_web/components/chat-comments/comment-thread/comment-thread.tsx index e47531129..7929716bb 100644 --- a/surfsense_web/components/chat-comments/comment-thread/comment-thread.tsx +++ b/surfsense_web/components/chat-comments/comment-thread/comment-thread.tsx @@ -1,6 +1,6 @@ "use client"; -import { ChevronDown, ChevronRight, MessageSquare } from "lucide-react"; +import { ChevronDown, ChevronRight, MessageCircleReply } from "lucide-react"; import { useState } from "react"; import { Button } from "@/components/ui/button"; import { CommentComposer } from "../comment-composer/comment-composer"; @@ -143,7 +143,7 @@ export function CommentThread({
) : ( )} @@ -155,7 +155,7 @@ export function CommentThread({ {!hasReplies && !isReplyComposerOpen && (
diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 8fd3e4ce5..6dd645a42 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -321,7 +321,7 @@ export function EditorPanelContent({
-
+

{displayTitle}

@@ -352,6 +352,12 @@ export function EditorPanelContent({ ) : ( <> + {!isLocalFileMode && editorDoc?.document_type && documentId && ( + + )}
) : (
-
+

{displayTitle}

@@ -422,6 +425,12 @@ export function EditorPanelContent({ ) : ( <> + {!isLocalFileMode && editorDoc?.document_type && documentId && ( + + )} )} - {!isLocalFileMode && editorDoc?.document_type && documentId && ( - - )} )}
diff --git a/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx index 65946487e..fa05559d7 100644 --- a/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx @@ -14,7 +14,7 @@ import { Inbox, LayoutGrid, ListFilter, - MessageSquare, + MessageCircleReply, Search, X, } from "lucide-react"; @@ -847,7 +847,7 @@ export function InboxSidebarContent({ - + {t("comments") || "Comments"} {formatInboxCount(comments.unreadCount)} @@ -1032,7 +1032,7 @@ export function InboxSidebarContent({ ) : (
{activeTab === "comments" ? ( - + ) : ( )} diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 385a16aec..130637c96 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -8,7 +8,7 @@ import { ChevronLeft, ChevronRight, ChevronUp, - Edit3, + Pencil, ImageIcon, Layers, Plus, @@ -923,7 +923,7 @@ export function ModelSelector({ className="size-7 rounded-md hover:bg-muted opacity-0 group-hover:opacity-100 transition-opacity" onClick={(e) => handleEditItem(e, item)} > - + )} {isSelected && } diff --git a/surfsense_web/components/public-chat-snapshots/public-chat-snapshot-row.tsx b/surfsense_web/components/public-chat-snapshots/public-chat-snapshot-row.tsx index 55bcc52a9..ce3a83791 100644 --- a/surfsense_web/components/public-chat-snapshots/public-chat-snapshot-row.tsx +++ b/surfsense_web/components/public-chat-snapshots/public-chat-snapshot-row.tsx @@ -79,8 +79,11 @@ export function PublicChatSnapshotRow({ variant="ghost" size="icon" className={cn( - "absolute right-0 h-6 w-6 shrink-0 hover:bg-transparent", - dropdownOpen ? "opacity-100" : "sm:opacity-0 sm:group-hover:opacity-100" + "absolute right-0 h-6 w-6 shrink-0", + "hover:bg-accent", + dropdownOpen + ? "opacity-100 bg-accent hover:bg-accent" + : "sm:opacity-0 sm:group-hover:opacity-100" )} > diff --git a/surfsense_web/components/settings/agent-model-manager.tsx b/surfsense_web/components/settings/agent-model-manager.tsx index f7a2fb824..988befdd0 100644 --- a/surfsense_web/components/settings/agent-model-manager.tsx +++ b/surfsense_web/components/settings/agent-model-manager.tsx @@ -4,10 +4,9 @@ import { useAtomValue } from "jotai"; import { AlertCircle, Dot, - Edit3, FileText, Info, - MessageSquareQuote, + Pencil, RefreshCw, Trash2, } from "lucide-react"; @@ -288,7 +287,7 @@ export function AgentModelManager({ searchSpaceId }: AgentModelManagerProps) { onClick={() => openEditDialog(config)} className="h-7 w-7 rounded-lg text-muted-foreground hover:text-foreground" > - + Edit @@ -323,7 +322,6 @@ export function AgentModelManager({ searchSpaceId }: AgentModelManagerProps) { variant="secondary" className="text-[10px] px-1.5 py-0.5 border-0 text-muted-foreground bg-muted" > - Citations )} diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx index fb28e5b1c..f5f128f80 100644 --- a/surfsense_web/components/settings/image-model-manager.tsx +++ b/surfsense_web/components/settings/image-model-manager.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtomValue } from "jotai"; -import { AlertCircle, Dot, Edit3, Info, RefreshCw, Trash2 } from "lucide-react"; +import { AlertCircle, Dot, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; import { useMemo, useState } from "react"; import { deleteImageGenConfigMutationAtom } from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; import { @@ -116,8 +116,8 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { return (
- {/* Header */} -
+ {/* Header actions */} +
Edit diff --git a/surfsense_web/components/settings/roles-manager.tsx b/surfsense_web/components/settings/roles-manager.tsx index 7f59ecd66..e7dadc20f 100644 --- a/surfsense_web/components/settings/roles-manager.tsx +++ b/surfsense_web/components/settings/roles-manager.tsx @@ -4,21 +4,25 @@ import { useQuery } from "@tanstack/react-query"; import { useAtomValue } from "jotai"; import { Bot, - ChevronDown, - Edit2, + ChevronRight, + ScanEye, + Pencil, FileText, - Globe, + Earth, + Image, Logs, type LucideIcon, - MessageCircle, + MessageCircleReply, MessageSquare, Mic, MoreHorizontal, - Plug, + Unplug, Settings, Shield, + SlidersHorizontal, Trash2, Users, + Video, } from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; import { toast } from "sonner"; @@ -88,7 +92,7 @@ const CATEGORY_CONFIG: Record< }, comments: { label: "Comments", - icon: MessageCircle, + icon: MessageCircleReply, description: "Add annotations to documents", order: 3, }, @@ -98,6 +102,24 @@ const CATEGORY_CONFIG: Record< description: "Configure AI model settings", order: 4, }, + image_generations: { + label: "Image Models", + icon: Image, + description: "Configure image generation model settings", + order: 4.1, + }, + vision_configs: { + label: "Vision Models", + icon: ScanEye, + description: "Configure vision model settings", + order: 4.2, + }, + video_presentations: { + label: "Video Presentations", + icon: Video, + description: "Generate and manage video presentations", + order: 4.3, + }, podcasts: { label: "Podcasts", icon: Mic, @@ -105,8 +127,8 @@ const CATEGORY_CONFIG: Record< order: 5, }, connectors: { - label: "Integrations", - icon: Plug, + label: "Connectors", + icon: Unplug, description: "Connect external data sources", order: 6, }, @@ -136,10 +158,16 @@ const CATEGORY_CONFIG: Record< }, public_sharing: { label: "Public Chat Sharing", - icon: Globe, + icon: Earth, description: "Share chats publicly via links", order: 11, }, + general: { + label: "General", + icon: SlidersHorizontal, + description: "General search space permissions", + order: 12, + }, }; const ACTION_LABELS: Record = { @@ -434,12 +462,11 @@ function RolesContent({ return (
-
- +
{!role.is_system_role && ( -
+
e.stopPropagation()}>
)} - +
{isExpanded && ( @@ -659,52 +682,30 @@ function PermissionsEditor({ return (
-
- +
e.stopPropagation()} onCheckedChange={() => onToggleCategory(category)} aria-label={`Select all ${config.label} permissions`} /> - +
@@ -726,7 +727,7 @@ function PermissionsEditor({ > diff --git a/surfsense_web/components/settings/vision-model-manager.tsx b/surfsense_web/components/settings/vision-model-manager.tsx index 81528c86a..8abfa4774 100644 --- a/surfsense_web/components/settings/vision-model-manager.tsx +++ b/surfsense_web/components/settings/vision-model-manager.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtomValue } from "jotai"; -import { AlertCircle, Dot, Edit3, Info, RefreshCw, Trash2 } from "lucide-react"; +import { AlertCircle, Dot, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; import { useMemo, useState } from "react"; import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; import { deleteVisionLLMConfigMutationAtom } from "@/atoms/vision-llm-config/vision-llm-config-mutation.atoms"; @@ -121,7 +121,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { return (
-
+
Edit From 456dd7417cc4b8a497b60d49347f96f29765831b Mon Sep 17 00:00:00 2001 From: Matt Van Horn <455140+mvanhorn@users.noreply.github.com> Date: Sun, 26 Apr 2026 02:46:43 -0700 Subject: [PATCH 172/299] fix(connectors): refresh Redis heartbeat during long Phase 1 indexing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #1295 The connector indexing route's `_run_indexing_with_notifications` set the Redis heartbeat key once at the start of indexing and relied on `on_heartbeat_callback` (only fired in Phase 2 per-document loops) to refresh it. The GitHub connector's Phase 1 runs `gitingest` as a blocking subprocess via `asyncio.to_thread`, so for any repo larger than the 2-minute TTL, the key expires before Phase 2 starts. The `cleanup_stale_indexing_notifications_task` then marks the document as failed with the misleading "Sync was interrupted unexpectedly. Please retry." message — even though the indexing thread is still running and gitingest's own subprocess timeout is 900 seconds. Add a background asyncio coroutine that refreshes the Redis key every 60 seconds for the duration of the indexing call. Same pattern already in use at app/tasks/celery_tasks/document_tasks.py:_run_heartbeat_loop, just adapted to use the route's get_heartbeat_redis_client() and _get_heartbeat_key() helpers. Cancellation runs in the `finally` block BEFORE the heartbeat-key delete so the loop cannot race and re-create the key after we have deleted it. The new `HEARTBEAT_REFRESH_INTERVAL = 60` constant mirrors the celery task module's value. --- .../routes/search_source_connectors_routes.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index bb20da65d..c10838ed6 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -81,6 +81,38 @@ _heartbeat_redis_client: redis.Redis | None = None # Redis key TTL - notification is stale if no heartbeat in this time HEARTBEAT_TTL_SECONDS = 120 # 2 minutes +# How often the background loop refreshes the Redis key. Must be < TTL so +# the key cannot expire between refreshes when the indexing function is +# doing blocking work (e.g. gitingest in Phase 1) that doesn't trigger +# on_heartbeat_callback. +HEARTBEAT_REFRESH_INTERVAL = 60 + + +async def _run_indexing_heartbeat_loop(notification_id: int) -> None: + """Background coroutine that refreshes the Redis heartbeat every + HEARTBEAT_REFRESH_INTERVAL seconds while a connector indexing task is + running. + + Mirrors `_run_heartbeat_loop` in app/tasks/celery_tasks/document_tasks.py. + Cancelled via heartbeat_task.cancel() when the indexing call returns + (success or failure). If the worker dies, the coroutine dies with it + and the Redis key expires naturally on its TTL. + """ + key = _get_heartbeat_key(notification_id) + try: + while True: + await asyncio.sleep(HEARTBEAT_REFRESH_INTERVAL) + try: + get_heartbeat_redis_client().setex( + key, HEARTBEAT_TTL_SECONDS, "alive" + ) + except Exception as e: + logger.warning( + f"Failed to refresh Redis heartbeat for notification " + f"{notification_id}: {e}" + ) + except asyncio.CancelledError: + pass # Normal cancellation when the indexing task completes def get_heartbeat_redis_client() -> redis.Redis: @@ -1457,6 +1489,7 @@ async def _run_indexing_with_notifications( notification = None connector_lock_acquired = False + heartbeat_task: asyncio.Task | None = None # Track indexed count for retry notifications and heartbeat current_indexed_count = 0 @@ -1502,6 +1535,16 @@ async def _run_indexing_with_notifications( except Exception as e: logger.warning(f"Failed to set initial Redis heartbeat: {e}") + # Start a background coroutine that refreshes the + # heartbeat every HEARTBEAT_REFRESH_INTERVAL seconds. + # Without this the cleanup_stale_indexing_notifications + # task can mark the doc failed when on_heartbeat_callback + # doesn't fire — for example during the GitHub + # connector's Phase 1 gitingest blocking call (#1295). + heartbeat_task = asyncio.create_task( + _run_indexing_heartbeat_loop(notification.id) + ) + # Update notification to fetching stage if notification: await NotificationService.connector_indexing.notify_indexing_progress( @@ -1792,6 +1835,13 @@ async def _run_indexing_with_notifications( except Exception as notif_error: logger.error(f"Failed to update notification: {notif_error!s}") finally: + # Stop the background heartbeat refresher BEFORE deleting the + # Redis key, so the loop cannot race and re-create the key + # after we delete it. + if heartbeat_task is not None: + heartbeat_task.cancel() + with suppress(Exception): + await asyncio.gather(heartbeat_task, return_exceptions=True) # Clean up Redis heartbeat key when task completes (success or failure) if notification: try: From a08066e2f917993d4c4f903d0dc9bc3ab5c86c93 Mon Sep 17 00:00:00 2001 From: Matt Van Horn Date: Sun, 26 Apr 2026 08:59:16 -0700 Subject: [PATCH 173/299] style: ruff format Redis heartbeat refresh in connectors route --- .../app/routes/search_source_connectors_routes.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index c10838ed6..9c93d4e42 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -103,9 +103,7 @@ async def _run_indexing_heartbeat_loop(notification_id: int) -> None: while True: await asyncio.sleep(HEARTBEAT_REFRESH_INTERVAL) try: - get_heartbeat_redis_client().setex( - key, HEARTBEAT_TTL_SECONDS, "alive" - ) + get_heartbeat_redis_client().setex(key, HEARTBEAT_TTL_SECONDS, "alive") except Exception as e: logger.warning( f"Failed to refresh Redis heartbeat for notification " From e95e417cc8b74c97bc8555e4e1b222338b62c317 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 27 Apr 2026 01:19:05 +0530 Subject: [PATCH 174/299] feat: implement smooth scrolling for provider sidebar in ModelSelector --- .../components/new-chat/model-selector.tsx | 92 ++++++++++++++++--- 1 file changed, 80 insertions(+), 12 deletions(-) diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 130637c96..3f5a5fa8c 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -320,6 +320,30 @@ export function ModelSelector({ [isMobile] ); + const scrollProviderSidebar = useCallback( + (direction: "backward" | "forward") => { + const el = providerSidebarRef.current; + if (!el) return; + const delta = isMobile + ? Math.max(56, Math.floor(el.clientWidth * 0.5)) + : Math.max(44, Math.floor(el.clientHeight * 0.4)); + + if (isMobile) { + el.scrollBy({ + left: direction === "backward" ? -delta : delta, + behavior: "smooth", + }); + return; + } + + el.scrollBy({ + top: direction === "backward" ? -delta : delta, + behavior: "smooth", + }); + }, + [isMobile] + ); + // Cmd/Ctrl+M shortcut (desktop only) useEffect(() => { if (isMobile) return; @@ -716,17 +740,40 @@ export function ModelSelector({ return (
- {!isMobile && sidebarScrollPos !== "top" && ( -
- + {!isMobile && ( +
+
)} - {isMobile && sidebarScrollPos !== "top" && ( -
+ {isMobile && ( +
)} @@ -802,13 +849,34 @@ export function ModelSelector({ ); })}
- {!isMobile && sidebarScrollPos !== "bottom" && ( -
- + {!isMobile && ( +
+
)} - {isMobile && sidebarScrollPos !== "bottom" && ( -
+ {isMobile && ( +
)} From f7fa96ccd01bcd2df89087c4e92151af5bdc5b25 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 27 Apr 2026 02:53:26 +0530 Subject: [PATCH 175/299] feat(sidebar): enhance DocumentsSidebar with tooltip support for folder addition, improving user feedback on folder limits --- .../layout/ui/sidebar/DocumentsSidebar.tsx | 48 ++++++++++++++----- surfsense_web/components/ui/tooltip.tsx | 19 ++++---- 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index 5819dcef4..ce9b80d49 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -1211,18 +1211,42 @@ function AuthenticatedDocumentsSidebar({ orientation="vertical" className="data-[orientation=vertical]:h-3 self-center bg-border" /> - + {electronAPI ? ( + + + + + + + + {canAddMoreLocalRoots + ? "Add folder" + : `You can add up to ${MAX_LOCAL_FILESYSTEM_ROOTS} folders`} + + + ) : ( + + )}
diff --git a/surfsense_web/components/ui/tooltip.tsx b/surfsense_web/components/ui/tooltip.tsx index bcf1c72f8..c1469156d 100644 --- a/surfsense_web/components/ui/tooltip.tsx +++ b/surfsense_web/components/ui/tooltip.tsx @@ -6,20 +6,19 @@ import { useEffect, useState } from "react"; import { cn } from "@/lib/utils"; -const MOBILE_BREAKPOINT = 768; - -function useIsTouchDevice() { - const [isTouch, setIsTouch] = useState(false); +function useCanHover() { + const [canHover, setCanHover] = useState(false); useEffect(() => { - const mql = window.matchMedia(`(max-width: ${MOBILE_BREAKPOINT - 1}px)`); - const update = () => setIsTouch(mql.matches); + // Hover-capable pointers are a better cross-platform signal than viewport width. + const mql = window.matchMedia("(hover: hover) and (pointer: fine)"); + const update = () => setCanHover(mql.matches); update(); mql.addEventListener("change", update); return () => mql.removeEventListener("change", update); }, []); - return isTouch; + return canHover; } function TooltipProvider({ @@ -42,14 +41,14 @@ function Tooltip({ onOpenChange, ...props }: React.ComponentProps) { - const isMobile = useIsTouchDevice(); + const canHover = useCanHover(); return ( From dbdeaa1bcffa4dd8ca60df2e1f706c656f76279e Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 27 Apr 2026 04:03:53 +0530 Subject: [PATCH 176/299] feat(sidebar): add loading skeletons to DocumentsSidebar and LocalFilesystemBrowser during data fetching --- .../layout/ui/sidebar/DocumentsSidebar.tsx | 116 +++++++++++------- .../ui/sidebar/LocalFilesystemBrowser.tsx | 94 ++++++++++++-- 2 files changed, 162 insertions(+), 48 deletions(-) diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index ce9b80d49..f4c5c03ac 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -73,6 +73,7 @@ import { } from "@/components/ui/dropdown-menu"; import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer"; import { Input } from "@/components/ui/input"; +import { Skeleton } from "@/components/ui/skeleton"; import { Separator } from "@/components/ui/separator"; import { Spinner } from "@/components/ui/spinner"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; @@ -99,6 +100,32 @@ const NON_DELETABLE_DOCUMENT_TYPES: readonly string[] = ["SURFSENSE_DOCS"]; const LOCAL_FILESYSTEM_TRUST_KEY = "surfsense.local-filesystem-trust.v1"; const MAX_LOCAL_FILESYSTEM_ROOTS = 5; +function CloudDocumentsSkeleton() { + const rows = [ + { id: "row-1", widthClass: "w-44" }, + { id: "row-2", widthClass: "w-32" }, + { id: "row-3", widthClass: "w-32" }, + { id: "row-4", widthClass: "w-44" }, + { id: "row-5", widthClass: "w-32" }, + { id: "row-6", widthClass: "w-32" }, + { id: "row-7", widthClass: "w-44" }, + { id: "row-8", widthClass: "w-32" }, + ]; + + return ( +
+
+ {rows.map((row) => ( +
+ + +
+ ))} +
+
+ ); +} + type FilesystemSettings = { mode: "cloud" | "desktop_local_folder"; localRootPaths: string[]; @@ -407,8 +434,8 @@ function AuthenticatedDocumentsSidebar({ ); // Zero queries for tree data - const [zeroFolders] = useQuery(queries.folders.bySpace({ searchSpaceId })); - const [zeroAllDocs] = useQuery(queries.documents.bySpace({ searchSpaceId })); + const [zeroFolders, zeroFoldersResult] = useQuery(queries.folders.bySpace({ searchSpaceId })); + const [zeroAllDocs, zeroAllDocsResult] = useQuery(queries.documents.bySpace({ searchSpaceId })); const [agentCreatedDocs, setAgentCreatedDocs] = useAtom(agentCreatedDocumentsAtom); const treeFolders: FolderDisplay[] = useMemo( @@ -989,6 +1016,9 @@ function AuthenticatedDocumentsSidebar({ const showFilesystemTabs = !isMobile && !!electronAPI && !!filesystemSettings; const currentFilesystemTab = filesystemSettings?.mode === "desktop_local_folder" ? "local" : "cloud"; + const showCloudSkeleton = + currentFilesystemTab === "cloud" && + (zeroFoldersResult.type !== "complete" || zeroAllDocsResult.type !== "complete"); const cloudContent = ( <> @@ -1101,45 +1131,49 @@ function AuthenticatedDocumentsSidebar({
)} - { - openEditorPanel({ - documentId: doc.id, - searchSpaceId, - title: doc.title, - }); - }} - onEditDocument={(doc) => { - openEditorPanel({ - documentId: doc.id, - searchSpaceId, - title: doc.title, - }); - }} - onDeleteDocument={(doc) => handleDeleteDocument(doc.id)} - onMoveDocument={handleMoveDocument} - onExportDocument={handleExportDocument} - onVersionHistory={(doc) => setVersionDocId(doc.id)} - activeTypes={activeTypes} - onDropIntoFolder={handleDropIntoFolder} - onReorderFolder={handleReorderFolder} - watchedFolderIds={watchedFolderIds} - onRescanFolder={handleRescanFolder} - onStopWatchingFolder={handleStopWatching} - onExportFolder={handleExportFolder} - /> + {showCloudSkeleton ? ( + + ) : ( + { + openEditorPanel({ + documentId: doc.id, + searchSpaceId, + title: doc.title, + }); + }} + onEditDocument={(doc) => { + openEditorPanel({ + documentId: doc.id, + searchSpaceId, + title: doc.title, + }); + }} + onDeleteDocument={(doc) => handleDeleteDocument(doc.id)} + onMoveDocument={handleMoveDocument} + onExportDocument={handleExportDocument} + onVersionHistory={(doc) => setVersionDocId(doc.id)} + activeTypes={activeTypes} + onDropIntoFolder={handleDropIntoFolder} + onReorderFolder={handleReorderFolder} + watchedFolderIds={watchedFolderIds} + onRescanFolder={handleRescanFolder} + onStopWatchingFolder={handleStopWatching} + onExportFolder={handleExportFolder} + /> + )}
diff --git a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx index 5b08f2e37..30e532896 100644 --- a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx +++ b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx @@ -1,8 +1,9 @@ "use client"; import { ChevronDown, ChevronRight, FileText, Folder } from "lucide-react"; -import { useCallback, useEffect, useMemo, useState } from "react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { DEFAULT_EXCLUDE_PATTERNS } from "@/components/sources/FolderWatchDialog"; +import { Skeleton } from "@/components/ui/skeleton"; import { Spinner } from "@/components/ui/spinner"; import { useElectronAPI } from "@/hooks/use-platform"; import { getSupportedExtensionsSet } from "@/lib/supported-extensions"; @@ -39,6 +40,8 @@ type LocalRootMount = { rootPath: string; }; +type MountLoadStatus = "idle" | "loading" | "complete" | "error"; + const getFolderDisplayName = (rootPath: string): string => rootPath.split(/[\\/]/).at(-1) || rootPath; @@ -79,6 +82,10 @@ export function LocalFilesystemBrowser({ const [rootStateMap, setRootStateMap] = useState>({}); const [expandedFolderKeys, setExpandedFolderKeys] = useState>(new Set()); const [mountByRootKey, setMountByRootKey] = useState>(new Map()); + const [mountStatus, setMountStatus] = useState("idle"); + const [mountRefreshInFlight, setMountRefreshInFlight] = useState(false); + const hasLoadedMountsOnceRef = useRef(false); + const hasResolvedAtLeastOneRootRef = useRef(false); const supportedExtensions = useMemo(() => Array.from(getSupportedExtensionsSet()), []); const isWindowsPlatform = electronAPI?.versions.platform === "win32"; @@ -139,23 +146,44 @@ export function LocalFilesystemBrowser({ useEffect(() => { if (!electronAPI?.getAgentFilesystemMounts) { + setMountStatus("error"); setMountByRootKey(new Map()); return; } let cancelled = false; + const isInitialMountLoad = !hasLoadedMountsOnceRef.current; + if (isInitialMountLoad) { + setMountStatus("loading"); + } else { + setMountRefreshInFlight(true); + } void electronAPI .getAgentFilesystemMounts() .then((mounts: LocalRootMount[]) => { if (cancelled) return; + const knownRootKeys = new Set( + rootPaths.map((rootPath) => normalizeRootPathForLookup(rootPath, isWindowsPlatform)) + ); const next = new Map(); for (const entry of mounts) { - next.set(normalizeRootPathForLookup(entry.rootPath, isWindowsPlatform), entry.mount); + const normalizedRootKey = normalizeRootPathForLookup(entry.rootPath, isWindowsPlatform); + if (!knownRootKeys.has(normalizedRootKey)) continue; + next.set(normalizedRootKey, entry.mount); } setMountByRootKey(next); + setMountStatus("complete"); + hasLoadedMountsOnceRef.current = true; }) .catch(() => { if (cancelled) return; - setMountByRootKey(new Map()); + if (isInitialMountLoad) { + setMountByRootKey(new Map()); + setMountStatus("error"); + } + }) + .finally(() => { + if (cancelled) return; + setMountRefreshInFlight(false); }); return () => { cancelled = true; @@ -265,6 +293,43 @@ export function LocalFilesystemBrowser({ ); } + const allRootsLoaded = rootPaths.every((rootPath) => { + const state = rootStateMap[rootPath]; + return !!state && !state.loading; + }); + const mountsSettled = mountStatus === "complete" || mountStatus === "error"; + if (allRootsLoaded && mountsSettled && rootPaths.length > 0) { + hasResolvedAtLeastOneRootRef.current = true; + } + const showInitialLoading = + !hasResolvedAtLeastOneRootRef.current && (!allRootsLoaded || !mountsSettled); + + if (showInitialLoading) { + const rows = [ + { id: "local-row-1", widthClass: "w-44" }, + { id: "local-row-2", widthClass: "w-32" }, + { id: "local-row-3", widthClass: "w-32" }, + { id: "local-row-4", widthClass: "w-44" }, + { id: "local-row-5", widthClass: "w-32" }, + { id: "local-row-6", widthClass: "w-32" }, + { id: "local-row-7", widthClass: "w-44" }, + { id: "local-row-8", widthClass: "w-32" }, + ]; + + return ( +
+
+ {rows.map((row) => ( +
+ + +
+ ))} +
+
+ ); + } + return (
{treeByRoot.map(({ rootPath, rootNode, matchCount, totalCount }) => { @@ -273,9 +338,11 @@ export function LocalFilesystemBrowser({ const mount = mountByRootKey.get(rootKey); if (!state || state.loading) { return ( -
- - Loading {getFolderDisplayName(rootPath)}... +
+
+ + Loading {getFolderDisplayName(rootPath)}... +
); } @@ -291,11 +358,24 @@ export function LocalFilesystemBrowser({ return (
{mount ? renderFolder(rootNode, 0, mount) : null} - {!mount && ( + {!mount && (mountRefreshInFlight || mountStatus === "loading") && ( +
+
+ + Loading {getFolderDisplayName(rootPath)}... +
+
+ )} + {!mount && mountStatus === "complete" && !mountRefreshInFlight && (
Unable to resolve mounted root for this folder.
)} + {!mount && mountStatus === "error" && ( +
+ Failed to resolve local folder mounts. +
+ )} {isEmpty && (
No supported files found in this folder. From 95511f0915afa3cb36af3e23af217416215c075d Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 27 Apr 2026 19:58:12 +0530 Subject: [PATCH 177/299] feat(sidebar): implement canonicalize roots, authoritative mount handling & preserved incremental UX for local folder mode --- .../middleware/test_filesystem_backends.py | 2 + .../test_multi_root_local_folder_backend.py | 9 +++ .../src/modules/agent-filesystem.ts | 26 ++++--- .../layout/ui/sidebar/DocumentsSidebar.tsx | 72 ++++++++++++++++--- .../ui/sidebar/LocalFilesystemBrowser.tsx | 11 +-- 5 files changed, 98 insertions(+), 22 deletions(-) diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py b/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py index 9600b7e05..98996d6bc 100644 --- a/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py @@ -30,6 +30,7 @@ def test_backend_resolver_returns_multi_root_backend_for_single_root(tmp_path: P backend = resolver(_RuntimeStub()) assert isinstance(backend, MultiRootLocalFolderBackend) + assert backend.list_mounts() == ("tmp",) def test_backend_resolver_uses_cloud_mode_by_default(): @@ -57,3 +58,4 @@ def test_backend_resolver_returns_multi_root_backend_for_multiple_roots(tmp_path backend = resolver(_RuntimeStub()) assert isinstance(backend, MultiRootLocalFolderBackend) + assert backend.list_mounts() == ("resume", "notes") diff --git a/surfsense_backend/tests/unit/middleware/test_multi_root_local_folder_backend.py b/surfsense_backend/tests/unit/middleware/test_multi_root_local_folder_backend.py index 7afb47e26..43a671178 100644 --- a/surfsense_backend/tests/unit/middleware/test_multi_root_local_folder_backend.py +++ b/surfsense_backend/tests/unit/middleware/test_multi_root_local_folder_backend.py @@ -26,3 +26,12 @@ def test_mount_ids_preserve_client_mapping_order(tmp_path: Path) -> None: ) assert backend.list_mounts() == ("pc_backups", "pc_backups_2", "notes_2026") + + +def test_mount_id_is_authoritative_not_folder_name(tmp_path: Path) -> None: + root = tmp_path / "Resume Folder" + root.mkdir() + + backend = MultiRootLocalFolderBackend((("custom_resume_mount", str(root)),)) + + assert backend.list_mounts() == ("custom_resume_mount",) diff --git a/surfsense_desktop/src/modules/agent-filesystem.ts b/surfsense_desktop/src/modules/agent-filesystem.ts index 6db5fd6f7..eb8d385b5 100644 --- a/surfsense_desktop/src/modules/agent-filesystem.ts +++ b/surfsense_desktop/src/modules/agent-filesystem.ts @@ -1,5 +1,5 @@ import { app, dialog } from "electron"; -import { access, mkdir, readFile, writeFile } from "node:fs/promises"; +import { access, mkdir, readFile, realpath, writeFile } from "node:fs/promises"; import { dirname, isAbsolute, join, relative, resolve } from "node:path"; export type AgentFilesystemMode = "cloud" | "desktop_local_folder"; @@ -25,16 +25,26 @@ function getDefaultSettings(): AgentFilesystemSettings { }; } -function normalizeLocalRootPaths(paths: unknown): string[] { +async function canonicalizeRootPath(pathValue: string): Promise { + const resolvedPath = resolve(pathValue); + try { + return await realpath(resolvedPath); + } catch { + return resolvedPath; + } +} + +async function normalizeLocalRootPaths(paths: unknown): Promise { if (!Array.isArray(paths)) { return []; } const uniquePaths = new Set(); - for (const path of paths) { - if (typeof path !== "string") continue; - const trimmed = path.trim(); + for (const rawPath of paths) { + if (typeof rawPath !== "string") continue; + const trimmed = rawPath.trim(); if (!trimmed) continue; - uniquePaths.add(trimmed); + const canonicalRootPath = await canonicalizeRootPath(trimmed); + uniquePaths.add(canonicalRootPath); if (uniquePaths.size >= MAX_LOCAL_ROOTS) { break; } @@ -51,7 +61,7 @@ export async function getAgentFilesystemSettings(): Promise(null); const [localTrustDialogOpen, setLocalTrustDialogOpen] = useState(false); const [pendingLocalPath, setPendingLocalPath] = useState(null); + const [draggedLocalRootPath, setDraggedLocalRootPath] = useState(null); const [watchedFolderIds, setWatchedFolderIds] = useState>(new Set()); const [folderWatchOpen, setFolderWatchOpen] = useAtom(folderWatchDialogOpenAtom); const [watchInitialFolder, setWatchInitialFolder] = useAtom(folderWatchInitialFolderAtom); @@ -246,7 +247,7 @@ function AuthenticatedDocumentsSidebar({ const applyLocalRootPath = useCallback( async (path: string) => { if (!electronAPI?.setAgentFilesystemSettings) return; - const nextLocalRootPaths = [...localRootPaths, path] + const nextLocalRootPaths = [path, ...localRootPaths] .filter((rootPath, index, allPaths) => allPaths.indexOf(rootPath) === index) .slice(0, MAX_LOCAL_FILESYSTEM_ROOTS); if (nextLocalRootPaths.length === localRootPaths.length) return; @@ -259,6 +260,26 @@ function AuthenticatedDocumentsSidebar({ [electronAPI, localRootPaths] ); + const handleReorderFilesystemRoots = useCallback( + async (draggedPath: string, targetPath: string) => { + if (!electronAPI?.setAgentFilesystemSettings) return; + if (draggedPath === targetPath) return; + const fromIndex = localRootPaths.indexOf(draggedPath); + const toIndex = localRootPaths.indexOf(targetPath); + if (fromIndex < 0 || toIndex < 0) return; + const nextLocalRootPaths = [...localRootPaths]; + const [movedPath] = nextLocalRootPaths.splice(fromIndex, 1); + if (!movedPath) return; + nextLocalRootPaths.splice(toIndex, 0, movedPath); + const updated = await electronAPI.setAgentFilesystemSettings({ + mode: "desktop_local_folder", + localRootPaths: nextLocalRootPaths, + }); + setFilesystemSettings(updated); + }, + [electronAPI, localRootPaths] + ); + const runPickLocalRoot = useCallback(async () => { if (!electronAPI?.pickAgentFilesystemRoot) return; const picked = await electronAPI.pickAgentFilesystemRoot(); @@ -1208,16 +1229,47 @@ function AuthenticatedDocumentsSidebar({ {localRootPaths.map((rootPath) => ( { - void handleRemoveFilesystemRoot(rootPath); + onSelect={(event) => event.preventDefault()} + draggable + onDragStart={(event) => { + event.dataTransfer.setData("text/plain", rootPath); + event.dataTransfer.effectAllowed = "move"; + setDraggedLocalRootPath(rootPath); }} - className="group h-8 gap-1.5 px-1.5 text-sm text-foreground" + onDragOver={(event) => { + event.preventDefault(); + event.dataTransfer.dropEffect = "move"; + }} + onDrop={(event) => { + event.preventDefault(); + const sourcePath = + event.dataTransfer.getData("text/plain") || draggedLocalRootPath; + if (!sourcePath) return; + void handleReorderFilesystemRoots(sourcePath, rootPath); + setDraggedLocalRootPath(null); + }} + onDragEnd={() => { + setDraggedLocalRootPath(null); + }} + className={`group h-8 gap-1.5 px-1.5 text-sm text-foreground ${ + draggedLocalRootPath === rootPath ? "bg-muted/60" : "" + }`} > {getFolderDisplayName(rootPath)} - + ))} @@ -1358,16 +1410,16 @@ function AuthenticatedDocumentsSidebar({ className="h-5 gap-1 px-1.5 text-[11px] select-none focus-visible:ring-0 focus-visible:ring-offset-0 data-[state=active]:bg-muted-foreground/25 data-[state=active]:text-foreground data-[state=active]:shadow-none" title="Cloud" > - - Cloud + + Cloud - - Local + + Local diff --git a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx index 30e532896..39b8ee769 100644 --- a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx +++ b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx @@ -150,6 +150,13 @@ export function LocalFilesystemBrowser({ setMountByRootKey(new Map()); return; } + if (rootPaths.length === 0) { + setMountByRootKey(new Map()); + setMountStatus("complete"); + setMountRefreshInFlight(false); + hasLoadedMountsOnceRef.current = true; + return; + } let cancelled = false; const isInitialMountLoad = !hasLoadedMountsOnceRef.current; if (isInitialMountLoad) { @@ -161,13 +168,9 @@ export function LocalFilesystemBrowser({ .getAgentFilesystemMounts() .then((mounts: LocalRootMount[]) => { if (cancelled) return; - const knownRootKeys = new Set( - rootPaths.map((rootPath) => normalizeRootPathForLookup(rootPath, isWindowsPlatform)) - ); const next = new Map(); for (const entry of mounts) { const normalizedRootKey = normalizeRootPathForLookup(entry.rootPath, isWindowsPlatform); - if (!knownRootKeys.has(normalizedRootKey)) continue; next.set(normalizedRootKey, entry.mount); } setMountByRootKey(next); From 6aa172a7306adfdd892285e30a2ca949fb41eec4 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 27 Apr 2026 20:07:02 +0530 Subject: [PATCH 178/299] feat(filesystem): increase max local roots to 10, optimize path normalization, and implement caching for filesystem settings --- .../src/modules/agent-filesystem.ts | 41 +++++++++++++++---- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/surfsense_desktop/src/modules/agent-filesystem.ts b/surfsense_desktop/src/modules/agent-filesystem.ts index eb8d385b5..a62b84f70 100644 --- a/surfsense_desktop/src/modules/agent-filesystem.ts +++ b/surfsense_desktop/src/modules/agent-filesystem.ts @@ -11,7 +11,8 @@ export interface AgentFilesystemSettings { } const SETTINGS_FILENAME = "agent-filesystem-settings.json"; -const MAX_LOCAL_ROOTS = 5; +const MAX_LOCAL_ROOTS = 10; +let cachedSettings: AgentFilesystemSettings | null = null; function getSettingsPath(): string { return join(app.getPath("userData"), SETTINGS_FILENAME); @@ -34,7 +35,7 @@ async function canonicalizeRootPath(pathValue: string): Promise { } } -async function normalizeLocalRootPaths(paths: unknown): Promise { +function normalizeLocalRootPaths(paths: unknown): string[] { if (!Array.isArray(paths)) { return []; } @@ -43,8 +44,22 @@ async function normalizeLocalRootPaths(paths: unknown): Promise { if (typeof rawPath !== "string") continue; const trimmed = rawPath.trim(); if (!trimmed) continue; - const canonicalRootPath = await canonicalizeRootPath(trimmed); - uniquePaths.add(canonicalRootPath); + uniquePaths.add(trimmed); + if (uniquePaths.size >= MAX_LOCAL_ROOTS) { + break; + } + } + return [...uniquePaths]; +} + +async function normalizeLocalRootPathsCanonical(paths: unknown): Promise { + const normalizedPaths = normalizeLocalRootPaths(paths); + const canonicalizedPaths = await Promise.all( + normalizedPaths.map((pathValue) => canonicalizeRootPath(pathValue)) + ); + const uniquePaths = new Set(); + for (const canonicalPath of canonicalizedPaths) { + uniquePaths.add(canonicalPath); if (uniquePaths.size >= MAX_LOCAL_ROOTS) { break; } @@ -53,19 +68,26 @@ async function normalizeLocalRootPaths(paths: unknown): Promise { } export async function getAgentFilesystemSettings(): Promise { + if (cachedSettings) { + return cachedSettings; + } try { const raw = await readFile(getSettingsPath(), "utf8"); const parsed = JSON.parse(raw) as Partial; if (parsed.mode !== "cloud" && parsed.mode !== "desktop_local_folder") { - return getDefaultSettings(); + cachedSettings = getDefaultSettings(); + return cachedSettings; } - return { + cachedSettings = { mode: parsed.mode, - localRootPaths: await normalizeLocalRootPaths(parsed.localRootPaths), + // Avoid filesystem I/O during reads; canonicalize paths on write. + localRootPaths: normalizeLocalRootPaths(parsed.localRootPaths), updatedAt: parsed.updatedAt ?? new Date().toISOString(), }; + return cachedSettings; } catch { - return getDefaultSettings(); + cachedSettings = getDefaultSettings(); + return cachedSettings; } } @@ -85,13 +107,14 @@ export async function setAgentFilesystemSettings( localRootPaths: settings.localRootPaths === undefined ? current.localRootPaths - : await normalizeLocalRootPaths(settings.localRootPaths ?? []), + : await normalizeLocalRootPathsCanonical(settings.localRootPaths ?? []), updatedAt: new Date().toISOString(), }; const settingsPath = getSettingsPath(); await mkdir(dirname(settingsPath), { recursive: true }); await writeFile(settingsPath, JSON.stringify(next, null, 2), "utf8"); + cachedSettings = next; return next; } From 86e2dc8a5dc7cde161e748c6df2ee42c315caf9e Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 27 Apr 2026 20:20:14 +0530 Subject: [PATCH 179/299] refactor(filesystem): remove unused drag-and-drop functionality in DocumentsSidebar --- .../layout/ui/sidebar/DocumentsSidebar.tsx | 46 +------------------ 1 file changed, 1 insertion(+), 45 deletions(-) diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index fc8db0386..990a4eb99 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -205,7 +205,6 @@ function AuthenticatedDocumentsSidebar({ const [filesystemSettings, setFilesystemSettings] = useState(null); const [localTrustDialogOpen, setLocalTrustDialogOpen] = useState(false); const [pendingLocalPath, setPendingLocalPath] = useState(null); - const [draggedLocalRootPath, setDraggedLocalRootPath] = useState(null); const [watchedFolderIds, setWatchedFolderIds] = useState>(new Set()); const [folderWatchOpen, setFolderWatchOpen] = useAtom(folderWatchDialogOpenAtom); const [watchInitialFolder, setWatchInitialFolder] = useAtom(folderWatchInitialFolderAtom); @@ -260,26 +259,6 @@ function AuthenticatedDocumentsSidebar({ [electronAPI, localRootPaths] ); - const handleReorderFilesystemRoots = useCallback( - async (draggedPath: string, targetPath: string) => { - if (!electronAPI?.setAgentFilesystemSettings) return; - if (draggedPath === targetPath) return; - const fromIndex = localRootPaths.indexOf(draggedPath); - const toIndex = localRootPaths.indexOf(targetPath); - if (fromIndex < 0 || toIndex < 0) return; - const nextLocalRootPaths = [...localRootPaths]; - const [movedPath] = nextLocalRootPaths.splice(fromIndex, 1); - if (!movedPath) return; - nextLocalRootPaths.splice(toIndex, 0, movedPath); - const updated = await electronAPI.setAgentFilesystemSettings({ - mode: "desktop_local_folder", - localRootPaths: nextLocalRootPaths, - }); - setFilesystemSettings(updated); - }, - [electronAPI, localRootPaths] - ); - const runPickLocalRoot = useCallback(async () => { if (!electronAPI?.pickAgentFilesystemRoot) return; const picked = await electronAPI.pickAgentFilesystemRoot(); @@ -1230,30 +1209,7 @@ function AuthenticatedDocumentsSidebar({ event.preventDefault()} - draggable - onDragStart={(event) => { - event.dataTransfer.setData("text/plain", rootPath); - event.dataTransfer.effectAllowed = "move"; - setDraggedLocalRootPath(rootPath); - }} - onDragOver={(event) => { - event.preventDefault(); - event.dataTransfer.dropEffect = "move"; - }} - onDrop={(event) => { - event.preventDefault(); - const sourcePath = - event.dataTransfer.getData("text/plain") || draggedLocalRootPath; - if (!sourcePath) return; - void handleReorderFilesystemRoots(sourcePath, rootPath); - setDraggedLocalRootPath(null); - }} - onDragEnd={() => { - setDraggedLocalRootPath(null); - }} - className={`group h-8 gap-1.5 px-1.5 text-sm text-foreground ${ - draggedLocalRootPath === rootPath ? "bg-muted/60" : "" - }`} + className="group h-8 gap-1.5 px-1.5 text-sm text-foreground" > From 27e16231c1a548471624724c22c6001d8d91fc2d Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 27 Apr 2026 21:00:40 +0530 Subject: [PATCH 180/299] feat(filesystem): enhance agent filesystem API with searchSpaceId support for improved context handling --- surfsense_desktop/src/ipc/channels.ts | 1 + surfsense_desktop/src/ipc/handlers.ts | 45 +++- .../src/modules/agent-filesystem.ts | 252 +++++++++++++++--- surfsense_desktop/src/preload.ts | 25 +- .../new-chat/[[...chat_id]]/page.tsx | 6 +- .../components/editor-panel/editor-panel.tsx | 10 +- .../layout/ui/sidebar/DocumentsSidebar.tsx | 21 +- .../ui/sidebar/LocalFilesystemBrowser.tsx | 42 ++- surfsense_web/lib/agent-filesystem.ts | 8 +- surfsense_web/types/window.d.ts | 24 +- 10 files changed, 349 insertions(+), 85 deletions(-) diff --git a/surfsense_desktop/src/ipc/channels.ts b/surfsense_desktop/src/ipc/channels.ts index ccd166899..ec676fba8 100644 --- a/surfsense_desktop/src/ipc/channels.ts +++ b/surfsense_desktop/src/ipc/channels.ts @@ -56,6 +56,7 @@ export const IPC_CHANNELS = { // Agent filesystem mode AGENT_FILESYSTEM_GET_SETTINGS: 'agent-filesystem:get-settings', AGENT_FILESYSTEM_GET_MOUNTS: 'agent-filesystem:get-mounts', + AGENT_FILESYSTEM_LIST_FILES: 'agent-filesystem:list-files', AGENT_FILESYSTEM_SET_SETTINGS: 'agent-filesystem:set-settings', AGENT_FILESYSTEM_PICK_ROOT: 'agent-filesystem:pick-root', } as const; diff --git a/surfsense_desktop/src/ipc/handlers.ts b/surfsense_desktop/src/ipc/handlers.ts index 54882f4ee..4054255f4 100644 --- a/surfsense_desktop/src/ipc/handlers.ts +++ b/surfsense_desktop/src/ipc/handlers.ts @@ -37,6 +37,7 @@ import { trackEvent, } from '../modules/analytics'; import { + listAgentFilesystemFiles, readAgentLocalFileText, writeAgentLocalFileText, getAgentFilesystemMounts, @@ -126,21 +127,24 @@ export function registerIpcHandlers(): void { readLocalFiles(paths) ); - ipcMain.handle(IPC_CHANNELS.READ_AGENT_LOCAL_FILE_TEXT, async (_event, virtualPath: string) => { + ipcMain.handle( + IPC_CHANNELS.READ_AGENT_LOCAL_FILE_TEXT, + async (_event, virtualPath: string, searchSpaceId?: number | null) => { try { - const result = await readAgentLocalFileText(virtualPath); + const result = await readAgentLocalFileText(virtualPath, searchSpaceId); return { ok: true, path: result.path, content: result.content }; } catch (error) { const message = error instanceof Error ? error.message : 'Failed to read local file'; return { ok: false, path: virtualPath, error: message }; } - }); + } + ); ipcMain.handle( IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT, - async (_event, virtualPath: string, content: string) => { + async (_event, virtualPath: string, content: string, searchSpaceId?: number | null) => { try { - const result = await writeAgentLocalFileText(virtualPath, content); + const result = await writeAgentLocalFileText(virtualPath, content, searchSpaceId); return { ok: true, path: result.path }; } catch (error) { const message = error instanceof Error ? error.message : 'Failed to write local file'; @@ -223,18 +227,37 @@ export function registerIpcHandlers(): void { }; }); - ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS, () => - getAgentFilesystemSettings() + ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS, (_event, searchSpaceId?: number | null) => + getAgentFilesystemSettings(searchSpaceId) ); - ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_GET_MOUNTS, () => - getAgentFilesystemMounts() + ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_GET_MOUNTS, (_event, searchSpaceId?: number | null) => + getAgentFilesystemMounts(searchSpaceId) + ); + + ipcMain.handle( + IPC_CHANNELS.AGENT_FILESYSTEM_LIST_FILES, + ( + _event, + options: { + rootPath: string; + searchSpaceId?: number | null; + excludePatterns?: string[] | null; + fileExtensions?: string[] | null; + } + ) => + listAgentFilesystemFiles(options) ); ipcMain.handle( IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS, - (_event, settings: { mode?: 'cloud' | 'desktop_local_folder'; localRootPaths?: string[] | null }) => - setAgentFilesystemSettings(settings) + ( + _event, + payload: { + searchSpaceId?: number | null; + settings: { mode?: 'cloud' | 'desktop_local_folder'; localRootPaths?: string[] | null }; + } + ) => setAgentFilesystemSettings(payload?.searchSpaceId, payload?.settings ?? {}) ); ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_PICK_ROOT, () => diff --git a/surfsense_desktop/src/modules/agent-filesystem.ts b/surfsense_desktop/src/modules/agent-filesystem.ts index a62b84f70..d8c64b79a 100644 --- a/surfsense_desktop/src/modules/agent-filesystem.ts +++ b/surfsense_desktop/src/modules/agent-filesystem.ts @@ -1,6 +1,7 @@ import { app, dialog } from "electron"; -import { access, mkdir, readFile, realpath, writeFile } from "node:fs/promises"; -import { dirname, isAbsolute, join, relative, resolve } from "node:path"; +import type { Dirent } from "node:fs"; +import { access, mkdir, readdir, readFile, realpath, stat, writeFile } from "node:fs/promises"; +import { dirname, extname, isAbsolute, join, relative, resolve } from "node:path"; export type AgentFilesystemMode = "cloud" | "desktop_local_folder"; @@ -10,9 +11,15 @@ export interface AgentFilesystemSettings { updatedAt: string; } +type AgentFilesystemSettingsStore = { + version: 2; + spaces: Record; +}; + const SETTINGS_FILENAME = "agent-filesystem-settings.json"; const MAX_LOCAL_ROOTS = 10; -let cachedSettings: AgentFilesystemSettings | null = null; +const DEFAULT_SPACE_KEY = "default"; +let cachedSettingsStore: AgentFilesystemSettingsStore | null = null; function getSettingsPath(): string { return join(app.getPath("userData"), SETTINGS_FILENAME); @@ -67,37 +74,97 @@ async function normalizeLocalRootPathsCanonical(paths: unknown): Promise { - if (cachedSettings) { - return cachedSettings; +function normalizeSearchSpaceKey(searchSpaceId?: number | null): string { + if (typeof searchSpaceId === "number" && Number.isFinite(searchSpaceId) && searchSpaceId > 0) { + return String(searchSpaceId); } + return DEFAULT_SPACE_KEY; +} + +function toSettingsFromUnknown(value: unknown): AgentFilesystemSettings | null { + if (!value || typeof value !== "object") { + return null; + } + const parsed = value as Partial; + if (parsed.mode !== "cloud" && parsed.mode !== "desktop_local_folder") { + return null; + } + return { + mode: parsed.mode, + localRootPaths: normalizeLocalRootPaths(parsed.localRootPaths), + updatedAt: parsed.updatedAt ?? new Date().toISOString(), + }; +} + +function getDefaultStore(): AgentFilesystemSettingsStore { + return { version: 2, spaces: {} }; +} + +function getSettingsFromStore( + store: AgentFilesystemSettingsStore, + searchSpaceId?: number | null +): AgentFilesystemSettings { + const key = normalizeSearchSpaceKey(searchSpaceId); + return store.spaces[key] ?? getDefaultSettings(); +} + +async function loadAgentFilesystemSettingsStore(): Promise { + if (cachedSettingsStore) { + return cachedSettingsStore; + } + const settingsPath = getSettingsPath(); try { - const raw = await readFile(getSettingsPath(), "utf8"); - const parsed = JSON.parse(raw) as Partial; - if (parsed.mode !== "cloud" && parsed.mode !== "desktop_local_folder") { - cachedSettings = getDefaultSettings(); - return cachedSettings; + const raw = await readFile(settingsPath, "utf8"); + const parsed = JSON.parse(raw) as unknown; + const nextStore = getDefaultStore(); + if ( + parsed && + typeof parsed === "object" && + "version" in parsed && + "spaces" in parsed && + (parsed as { version?: unknown }).version === 2 + ) { + const parsedStore = parsed as { spaces?: Record; version: 2 }; + if (parsedStore.spaces && typeof parsedStore.spaces === "object") { + for (const [spaceKey, rawSettings] of Object.entries(parsedStore.spaces)) { + const normalizedSettings = toSettingsFromUnknown(rawSettings); + if (normalizedSettings) { + nextStore.spaces[String(spaceKey)] = normalizedSettings; + } + } + } + } else { + // Strict migration: reject legacy/non-scoped settings and reset. + await mkdir(dirname(settingsPath), { recursive: true }); + await writeFile(settingsPath, JSON.stringify(nextStore, null, 2), "utf8"); } - cachedSettings = { - mode: parsed.mode, - // Avoid filesystem I/O during reads; canonicalize paths on write. - localRootPaths: normalizeLocalRootPaths(parsed.localRootPaths), - updatedAt: parsed.updatedAt ?? new Date().toISOString(), - }; - return cachedSettings; + cachedSettingsStore = nextStore; + return nextStore; } catch { - cachedSettings = getDefaultSettings(); - return cachedSettings; + cachedSettingsStore = getDefaultStore(); + await mkdir(dirname(settingsPath), { recursive: true }); + await writeFile(settingsPath, JSON.stringify(cachedSettingsStore, null, 2), "utf8"); + return cachedSettingsStore; } } +export async function getAgentFilesystemSettings( + searchSpaceId?: number | null +): Promise { + const store = await loadAgentFilesystemSettingsStore(); + return getSettingsFromStore(store, searchSpaceId); +} + export async function setAgentFilesystemSettings( + searchSpaceId: number | null | undefined, settings: { mode?: AgentFilesystemMode; localRootPaths?: string[] | null; } ): Promise { - const current = await getAgentFilesystemSettings(); + const store = await loadAgentFilesystemSettingsStore(); + const key = normalizeSearchSpaceKey(searchSpaceId); + const current = getSettingsFromStore(store, searchSpaceId); const nextMode = settings.mode === "cloud" || settings.mode === "desktop_local_folder" ? settings.mode @@ -113,8 +180,15 @@ export async function setAgentFilesystemSettings( const settingsPath = getSettingsPath(); await mkdir(dirname(settingsPath), { recursive: true }); - await writeFile(settingsPath, JSON.stringify(next, null, 2), "utf8"); - cachedSettings = next; + const nextStore: AgentFilesystemSettingsStore = { + version: 2, + spaces: { + ...store.spaces, + [key]: next, + }, + }; + await writeFile(settingsPath, JSON.stringify(nextStore, null, 2), "utf8"); + cachedSettingsStore = nextStore; return next; } @@ -160,6 +234,20 @@ export type LocalRootMount = { rootPath: string; }; +export type AgentFilesystemListOptions = { + rootPath: string; + searchSpaceId?: number | null; + excludePatterns?: string[] | null; + fileExtensions?: string[] | null; +}; + +export type AgentFilesystemFileEntry = { + relativePath: string; + fullPath: string; + size: number; + mtimeMs: number; +}; + function sanitizeMountName(rawMount: string): string { const normalized = rawMount .trim() @@ -188,11 +276,111 @@ function buildRootMounts(rootPaths: string[]): LocalRootMount[] { return mounts; } -export async function getAgentFilesystemMounts(): Promise { - const rootPaths = await resolveCurrentRootPaths(); +export async function getAgentFilesystemMounts( + searchSpaceId?: number | null +): Promise { + const rootPaths = await resolveCurrentRootPaths(searchSpaceId); return buildRootMounts(rootPaths); } +function normalizeComparablePath(pathValue: string): string { + const normalized = resolve(pathValue); + return process.platform === "win32" ? normalized.toLowerCase() : normalized; +} + +function normalizeExtensionSet(fileExtensions: string[] | null | undefined): Set | null { + if (!fileExtensions || fileExtensions.length === 0) { + return null; + } + const set = new Set(); + for (const extension of fileExtensions) { + if (typeof extension !== "string") continue; + const trimmed = extension.trim().toLowerCase(); + if (!trimmed) continue; + set.add(trimmed.startsWith(".") ? trimmed : `.${trimmed}`); + } + return set.size > 0 ? set : null; +} + +function normalizeExcludeSet(excludePatterns: string[] | null | undefined): Set { + const set = new Set(); + for (const pattern of excludePatterns ?? []) { + if (typeof pattern !== "string") continue; + const trimmed = pattern.trim(); + if (!trimmed) continue; + set.add(trimmed); + } + return set; +} + +export async function listAgentFilesystemFiles( + options: AgentFilesystemListOptions +): Promise { + const allowedRootPaths = await resolveCurrentRootPaths(options.searchSpaceId); + const requestedRootPath = await canonicalizeRootPath(options.rootPath); + const normalizedRequestedRoot = normalizeComparablePath(requestedRootPath); + const allowedRoots = new Set( + ( + await Promise.all(allowedRootPaths.map((rootPath) => canonicalizeRootPath(rootPath))) + ).map((rootPath) => normalizeComparablePath(rootPath)) + ); + if (!allowedRoots.has(normalizedRequestedRoot)) { + throw new Error("Selected path is not an allowed local root"); + } + + const excludePatterns = normalizeExcludeSet(options.excludePatterns); + const extensionSet = normalizeExtensionSet(options.fileExtensions); + const files: AgentFilesystemFileEntry[] = []; + const stack: string[] = [requestedRootPath]; + + while (stack.length > 0) { + const currentDir = stack.pop(); + if (!currentDir) continue; + let entries: Dirent[]; + try { + entries = await readdir(currentDir, { withFileTypes: true }); + } catch { + continue; + } + + for (const entry of entries) { + if (entry.name.startsWith(".") || excludePatterns.has(entry.name)) { + continue; + } + const absolutePath = join(currentDir, entry.name); + if (entry.isDirectory()) { + stack.push(absolutePath); + continue; + } + if (!entry.isFile()) { + continue; + } + if (extensionSet) { + const extension = extname(entry.name).toLowerCase(); + if (!extensionSet.has(extension)) { + continue; + } + } + try { + const fileStat = await stat(absolutePath); + if (!fileStat.isFile()) { + continue; + } + files.push({ + relativePath: relative(requestedRootPath, absolutePath).replace(/\\/g, "/"), + fullPath: absolutePath, + size: fileStat.size, + mtimeMs: fileStat.mtimeMs, + }); + } catch { + // Files can disappear while scanning. + } + } + } + + return files; +} + function parseMountedVirtualPath( virtualPath: string, mounts: LocalRootMount[] @@ -231,8 +419,8 @@ function toMountedVirtualPath(mount: string, rootPath: string, absolutePath: str return `/${mount}${relativePath}`; } -async function resolveCurrentRootPaths(): Promise { - const settings = await getAgentFilesystemSettings(); +async function resolveCurrentRootPaths(searchSpaceId?: number | null): Promise { + const settings = await getAgentFilesystemSettings(searchSpaceId); if (settings.localRootPaths.length === 0) { throw new Error("No local filesystem roots selected"); } @@ -240,9 +428,10 @@ async function resolveCurrentRootPaths(): Promise { } export async function readAgentLocalFileText( - virtualPath: string + virtualPath: string, + searchSpaceId?: number | null ): Promise<{ path: string; content: string }> { - const rootPaths = await resolveCurrentRootPaths(); + const rootPaths = await resolveCurrentRootPaths(searchSpaceId); const mounts = buildRootMounts(rootPaths); const { mount, subPath } = parseMountedVirtualPath(virtualPath, mounts); const rootMount = findMountByName(mounts, mount); @@ -261,9 +450,10 @@ export async function readAgentLocalFileText( export async function writeAgentLocalFileText( virtualPath: string, - content: string + content: string, + searchSpaceId?: number | null ): Promise<{ path: string }> { - const rootPaths = await resolveCurrentRootPaths(); + const rootPaths = await resolveCurrentRootPaths(searchSpaceId); const mounts = buildRootMounts(rootPaths); const { mount, subPath } = parseMountedVirtualPath(virtualPath, mounts); const rootMount = findMountByName(mounts, mount); diff --git a/surfsense_desktop/src/preload.ts b/surfsense_desktop/src/preload.ts index 9c538f691..8e5c2f56b 100644 --- a/surfsense_desktop/src/preload.ts +++ b/surfsense_desktop/src/preload.ts @@ -71,10 +71,10 @@ contextBridge.exposeInMainWorld('electronAPI', { // Browse files via native dialog browseFiles: () => ipcRenderer.invoke(IPC_CHANNELS.BROWSE_FILES), readLocalFiles: (paths: string[]) => ipcRenderer.invoke(IPC_CHANNELS.READ_LOCAL_FILES, paths), - readAgentLocalFileText: (virtualPath: string) => - ipcRenderer.invoke(IPC_CHANNELS.READ_AGENT_LOCAL_FILE_TEXT, virtualPath), - writeAgentLocalFileText: (virtualPath: string, content: string) => - ipcRenderer.invoke(IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT, virtualPath, content), + readAgentLocalFileText: (virtualPath: string, searchSpaceId?: number | null) => + ipcRenderer.invoke(IPC_CHANNELS.READ_AGENT_LOCAL_FILE_TEXT, virtualPath, searchSpaceId), + writeAgentLocalFileText: (virtualPath: string, content: string, searchSpaceId?: number | null) => + ipcRenderer.invoke(IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT, virtualPath, content, searchSpaceId), // Auth token sync across windows getAuthTokens: () => ipcRenderer.invoke(IPC_CHANNELS.GET_AUTH_TOKENS), @@ -106,13 +106,20 @@ contextBridge.exposeInMainWorld('electronAPI', { ipcRenderer.invoke(IPC_CHANNELS.ANALYTICS_CAPTURE, { event, properties }), getAnalyticsContext: () => ipcRenderer.invoke(IPC_CHANNELS.ANALYTICS_GET_CONTEXT), // Agent filesystem mode - getAgentFilesystemSettings: () => - ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS), - getAgentFilesystemMounts: () => - ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_MOUNTS), + getAgentFilesystemSettings: (searchSpaceId?: number | null) => + ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_SETTINGS, searchSpaceId), + getAgentFilesystemMounts: (searchSpaceId?: number | null) => + ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_GET_MOUNTS, searchSpaceId), + listAgentFilesystemFiles: (options: { + rootPath: string; + searchSpaceId?: number | null; + excludePatterns?: string[] | null; + fileExtensions?: string[] | null; + }) => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_LIST_FILES, options), setAgentFilesystemSettings: (settings: { mode?: "cloud" | "desktop_local_folder"; localRootPaths?: string[] | null; - }) => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS, settings), + }, searchSpaceId?: number | null) => + ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_SET_SETTINGS, { searchSpaceId, settings }), pickAgentFilesystemRoot: () => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_PICK_ROOT), }); diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 62332d2c4..06f3bf79f 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -658,7 +658,7 @@ export default function NewChatPage() { try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - const selection = await getAgentFilesystemSelection(); + const selection = await getAgentFilesystemSelection(searchSpaceId); if ( selection.filesystem_mode === "desktop_local_folder" && (!selection.local_filesystem_mounts || @@ -1088,7 +1088,7 @@ export default function NewChatPage() { try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - const selection = await getAgentFilesystemSelection(); + const selection = await getAgentFilesystemSelection(searchSpaceId); const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { method: "POST", headers: { @@ -1424,7 +1424,7 @@ export default function NewChatPage() { ]); try { - const selection = await getAgentFilesystemSelection(); + const selection = await getAgentFilesystemSelection(searchSpaceId); const response = await fetch(getRegenerateUrl(threadId), { method: "POST", headers: { diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 1f1b41c3e..a9fe886e1 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -124,7 +124,10 @@ export function EditorPanelContent({ if (!electronAPI?.readAgentLocalFileText) { throw new Error("Local file editor is available only in desktop mode."); } - const readResult = await electronAPI.readAgentLocalFileText(localFilePath); + const readResult = await electronAPI.readAgentLocalFileText( + localFilePath, + searchSpaceId + ); if (!readResult.ok) { throw new Error(readResult.error || "Failed to read local file"); } @@ -226,7 +229,7 @@ export function EditorPanelContent({ } }, [editorDoc?.source_markdown]); - const handleSave = useCallback(async (options?: { silent?: boolean }) => { + const handleSave = useCallback(async (_options?: { silent?: boolean }) => { setSaving(true); try { if (isLocalFileMode) { @@ -239,7 +242,8 @@ export function EditorPanelContent({ const contentToSave = markdownRef.current; const writeResult = await electronAPI.writeAgentLocalFileText( localFilePath, - contentToSave + contentToSave, + searchSpaceId ); if (!writeResult.ok) { throw new Error(writeResult.error || "Failed to save local file"); diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index 990a4eb99..3b747b15a 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -214,7 +214,7 @@ function AuthenticatedDocumentsSidebar({ if (!electronAPI?.getAgentFilesystemSettings) return; let mounted = true; electronAPI - .getAgentFilesystemSettings() + .getAgentFilesystemSettings(searchSpaceId) .then((settings: FilesystemSettings) => { if (!mounted) return; setFilesystemSettings(settings); @@ -230,7 +230,7 @@ function AuthenticatedDocumentsSidebar({ return () => { mounted = false; }; - }, [electronAPI]); + }, [electronAPI, searchSpaceId]); const hasLocalFilesystemTrust = useCallback(() => { try { @@ -253,10 +253,10 @@ function AuthenticatedDocumentsSidebar({ const updated = await electronAPI.setAgentFilesystemSettings({ mode: "desktop_local_folder", localRootPaths: nextLocalRootPaths, - }); + }, searchSpaceId); setFilesystemSettings(updated); }, - [electronAPI, localRootPaths] + [electronAPI, localRootPaths, searchSpaceId] ); const runPickLocalRoot = useCallback(async () => { @@ -285,10 +285,10 @@ function AuthenticatedDocumentsSidebar({ const updated = await electronAPI.setAgentFilesystemSettings({ mode: "desktop_local_folder", localRootPaths: localRootPaths.filter((rootPath) => rootPath !== rootPathToRemove), - }); + }, searchSpaceId); setFilesystemSettings(updated); }, - [electronAPI, localRootPaths] + [electronAPI, localRootPaths, searchSpaceId] ); const handleClearFilesystemRoots = useCallback(async () => { @@ -296,19 +296,19 @@ function AuthenticatedDocumentsSidebar({ const updated = await electronAPI.setAgentFilesystemSettings({ mode: "desktop_local_folder", localRootPaths: [], - }); + }, searchSpaceId); setFilesystemSettings(updated); - }, [electronAPI]); + }, [electronAPI, searchSpaceId]); const handleFilesystemTabChange = useCallback( async (tab: "cloud" | "local") => { if (!electronAPI?.setAgentFilesystemSettings) return; const updated = await electronAPI.setAgentFilesystemSettings({ mode: tab === "cloud" ? "cloud" : "desktop_local_folder", - }); + }, searchSpaceId); setFilesystemSettings(updated); }, - [electronAPI] + [electronAPI, searchSpaceId] ); // AI File Sort state @@ -1323,6 +1323,7 @@ function AuthenticatedDocumentsSidebar({ { openEditorPanel({ diff --git a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx index 39b8ee769..add7cd2d9 100644 --- a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx +++ b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx @@ -11,6 +11,7 @@ import { getSupportedExtensionsSet } from "@/lib/supported-extensions"; interface LocalFilesystemBrowserProps { rootPaths: string[]; searchSpaceId: number; + active?: boolean; searchQuery?: string; onOpenFile: (fullPath: string) => void; } @@ -75,6 +76,7 @@ function toMountedVirtualPath(mount: string, relativePath: string): string { export function LocalFilesystemBrowser({ rootPaths, searchSpaceId, + active = true, searchQuery, onOpenFile, }: LocalFilesystemBrowserProps) { @@ -84,13 +86,36 @@ export function LocalFilesystemBrowser({ const [mountByRootKey, setMountByRootKey] = useState>(new Map()); const [mountStatus, setMountStatus] = useState("idle"); const [mountRefreshInFlight, setMountRefreshInFlight] = useState(false); + const lastLoadedRootsSignatureRef = useRef(""); const hasLoadedMountsOnceRef = useRef(false); const hasResolvedAtLeastOneRootRef = useRef(false); const supportedExtensions = useMemo(() => Array.from(getSupportedExtensionsSet()), []); const isWindowsPlatform = electronAPI?.versions.platform === "win32"; useEffect(() => { - if (!electronAPI?.listFolderFiles) return; + if (!active) return; + if (!electronAPI?.listAgentFilesystemFiles) { + for (const rootPath of rootPaths) { + setRootStateMap((prev) => ({ + ...prev, + [rootPath]: { + loading: false, + error: "Desktop app update required for local mode browsing.", + files: [], + }, + })); + } + return; + } + const rootsSignature = rootPaths + .map((rootPath) => normalizeRootPathForLookup(rootPath, isWindowsPlatform)) + .sort() + .join("|"); + const settingsSignature = `${searchSpaceId}:${rootsSignature}`; + if (settingsSignature === lastLoadedRootsSignatureRef.current) { + return; + } + lastLoadedRootsSignatureRef.current = settingsSignature; let cancelled = false; for (const rootPath of rootPaths) { @@ -107,14 +132,11 @@ export function LocalFilesystemBrowser({ void Promise.all( rootPaths.map(async (rootPath) => { try { - const files = (await electronAPI.listFolderFiles({ - path: rootPath, - name: getFolderDisplayName(rootPath), + const files = (await electronAPI.listAgentFilesystemFiles({ + rootPath, + searchSpaceId, excludePatterns: DEFAULT_EXCLUDE_PATTERNS, fileExtensions: supportedExtensions, - rootFolderId: null, - searchSpaceId, - active: true, })) as LocalFolderFileEntry[]; if (cancelled) return; setRootStateMap((prev) => ({ @@ -142,7 +164,7 @@ export function LocalFilesystemBrowser({ return () => { cancelled = true; }; - }, [electronAPI, rootPaths, searchSpaceId, supportedExtensions]); + }, [active, electronAPI, isWindowsPlatform, rootPaths, searchSpaceId, supportedExtensions]); useEffect(() => { if (!electronAPI?.getAgentFilesystemMounts) { @@ -165,7 +187,7 @@ export function LocalFilesystemBrowser({ setMountRefreshInFlight(true); } void electronAPI - .getAgentFilesystemMounts() + .getAgentFilesystemMounts(searchSpaceId) .then((mounts: LocalRootMount[]) => { if (cancelled) return; const next = new Map(); @@ -191,7 +213,7 @@ export function LocalFilesystemBrowser({ return () => { cancelled = true; }; - }, [electronAPI, isWindowsPlatform, rootPaths]); + }, [electronAPI, isWindowsPlatform, rootPaths, searchSpaceId]); const treeByRoot = useMemo(() => { const query = searchQuery?.trim().toLowerCase() ?? ""; diff --git a/surfsense_web/lib/agent-filesystem.ts b/surfsense_web/lib/agent-filesystem.ts index 91c366d43..da5fc1b1d 100644 --- a/surfsense_web/lib/agent-filesystem.ts +++ b/surfsense_web/lib/agent-filesystem.ts @@ -22,15 +22,17 @@ export function getClientPlatform(): ClientPlatform { return window.electronAPI ? "desktop" : "web"; } -export async function getAgentFilesystemSelection(): Promise { +export async function getAgentFilesystemSelection( + searchSpaceId?: number | null +): Promise { const platform = getClientPlatform(); if (platform !== "desktop" || !window.electronAPI?.getAgentFilesystemSettings) { return { ...DEFAULT_SELECTION, client_platform: platform }; } try { - const settings = await window.electronAPI.getAgentFilesystemSettings(); + const settings = await window.electronAPI.getAgentFilesystemSettings(searchSpaceId); if (settings.mode === "desktop_local_folder") { - const mounts = await window.electronAPI.getAgentFilesystemMounts?.(); + const mounts = await window.electronAPI.getAgentFilesystemMounts?.(searchSpaceId); const localFilesystemMounts = mounts?.map((entry) => ({ mount_id: entry.mount, diff --git a/surfsense_web/types/window.d.ts b/surfsense_web/types/window.d.ts index e9f29a8f3..d3356d4d1 100644 --- a/surfsense_web/types/window.d.ts +++ b/surfsense_web/types/window.d.ts @@ -54,6 +54,13 @@ interface AgentFilesystemMount { rootPath: string; } +interface AgentFilesystemListOptions { + rootPath: string; + searchSpaceId?: number | null; + excludePatterns?: string[] | null; + fileExtensions?: string[] | null; +} + interface LocalTextFileResult { ok: boolean; path: string; @@ -114,10 +121,14 @@ interface ElectronAPI { // Browse files/folders via native dialogs browseFiles: () => Promise; readLocalFiles: (paths: string[]) => Promise; - readAgentLocalFileText: (virtualPath: string) => Promise; + readAgentLocalFileText: ( + virtualPath: string, + searchSpaceId?: number | null + ) => Promise; writeAgentLocalFileText: ( virtualPath: string, - content: string + content: string, + searchSpaceId?: number | null ) => Promise; // Auth token sync across windows getAuthTokens: () => Promise<{ bearer: string; refresh: string } | null>; @@ -151,12 +162,15 @@ interface ElectronAPI { platform: string; }>; // Agent filesystem mode - getAgentFilesystemSettings: () => Promise; - getAgentFilesystemMounts: () => Promise; + getAgentFilesystemSettings: (searchSpaceId?: number | null) => Promise; + getAgentFilesystemMounts: (searchSpaceId?: number | null) => Promise; + listAgentFilesystemFiles: ( + options: AgentFilesystemListOptions + ) => Promise; setAgentFilesystemSettings: (settings: { mode?: AgentFilesystemMode; localRootPaths?: string[] | null; - }) => Promise; + }, searchSpaceId?: number | null) => Promise; pickAgentFilesystemRoot: () => Promise; } From 1190ee9449626f921af250ce2d88969cca7f6a9a Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 27 Apr 2026 21:17:47 +0530 Subject: [PATCH 181/299] feat(sidebar): separate DesktopLocalTabContent component for reducing bundle size in web --- .../ui/sidebar/DesktopLocalTabContent.tsx | 187 ++++++++++++++ .../layout/ui/sidebar/DocumentsSidebar.tsx | 229 ++++-------------- 2 files changed, 233 insertions(+), 183 deletions(-) create mode 100644 surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx diff --git a/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx b/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx new file mode 100644 index 000000000..6fd4e48f8 --- /dev/null +++ b/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx @@ -0,0 +1,187 @@ +"use client"; + +import { Folder, FolderPlus, Search, X } from "lucide-react"; +import { useRef, useState } from "react"; +import { Input } from "@/components/ui/input"; +import { Separator } from "@/components/ui/separator"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { useDebouncedValue } from "@/hooks/use-debounced-value"; +import { LocalFilesystemBrowser } from "./LocalFilesystemBrowser"; + +const getFolderDisplayName = (rootPath: string): string => + rootPath.split(/[\\/]/).at(-1) || rootPath; + +interface DesktopLocalTabContentProps { + localRootPaths: string[]; + canAddMoreLocalRoots: boolean; + maxLocalFilesystemRoots: number; + searchSpaceId: number; + onPickFilesystemRoot: () => Promise | void; + onRemoveFilesystemRoot: (rootPath: string) => Promise | void; + onClearFilesystemRoots: () => Promise | void; + onOpenLocalFile: (localFilePath: string) => void; + electronAvailable: boolean; +} + +export function DesktopLocalTabContent({ + localRootPaths, + canAddMoreLocalRoots, + maxLocalFilesystemRoots, + searchSpaceId, + onPickFilesystemRoot, + onRemoveFilesystemRoot, + onClearFilesystemRoots, + onOpenLocalFile, + electronAvailable, +}: DesktopLocalTabContentProps) { + const [localSearch, setLocalSearch] = useState(""); + const debouncedLocalSearch = useDebouncedValue(localSearch, 250); + const localSearchInputRef = useRef(null); + + return ( +
+
+
+ {localRootPaths.length > 0 ? ( + + + + + + + Selected folders + + + {localRootPaths.map((rootPath) => ( + event.preventDefault()} + className="group h-8 gap-1.5 px-1.5 text-sm text-foreground" + > + + + {getFolderDisplayName(rootPath)} + + + + ))} + + { + void onClearFilesystemRoots(); + }} + > + Clear all folders + + + + ) : ( +
+ + No local folders selected +
+ )} + + {electronAvailable ? ( + + + + + + + + {canAddMoreLocalRoots + ? "Add folder" + : `You can add up to ${maxLocalFilesystemRoots} folders`} + + + ) : null} +
+
+
+
+
+
+ setLocalSearch(e.target.value)} + placeholder="Search local files" + type="text" + aria-label="Search local files" + /> + {Boolean(localSearch) && ( + + )} +
+
+ +
+ ); +} diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index 3b747b15a..5b9157b28 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -6,19 +6,17 @@ import { ChevronLeft, ChevronRight, FileText, - Folder, - FolderPlus, FolderClock, Laptop, Lock, Paperclip, - Search, Server, Trash2, Unplug, Upload, X, } from "lucide-react"; +import dynamic from "next/dynamic"; import Link from "next/link"; import { useParams } from "next/navigation"; import { useTranslations } from "next-intl"; @@ -49,7 +47,6 @@ import { EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems"; import { DEFAULT_EXCLUDE_PATTERNS, FolderWatchDialog, - type SelectedFolder, } from "@/components/sources/FolderWatchDialog"; import { AlertDialog, @@ -63,18 +60,8 @@ import { } from "@/components/ui/alert-dialog"; import { Avatar, AvatarFallback, AvatarGroup } from "@/components/ui/avatar"; import { Button } from "@/components/ui/button"; -import { - DropdownMenu, - DropdownMenuContent, - DropdownMenuItem, - DropdownMenuLabel, - DropdownMenuSeparator, - DropdownMenuTrigger, -} from "@/components/ui/dropdown-menu"; import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer"; -import { Input } from "@/components/ui/input"; import { Skeleton } from "@/components/ui/skeleton"; -import { Separator } from "@/components/ui/separator"; import { Spinner } from "@/components/ui/spinner"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; @@ -84,7 +71,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { DocumentTypeEnum } from "@/contracts/types/document.types"; import { useDebouncedValue } from "@/hooks/use-debounced-value"; import { useMediaQuery } from "@/hooks/use-media-query"; -import { useElectronAPI } from "@/hooks/use-platform"; +import { usePlatform, useElectronAPI } from "@/hooks/use-platform"; import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { foldersApiService } from "@/lib/apis/folders-api.service"; @@ -93,9 +80,13 @@ import { authenticatedFetch } from "@/lib/auth-utils"; import { uploadFolderScan } from "@/lib/folder-sync-upload"; import { getSupportedExtensionsSet } from "@/lib/supported-extensions"; import { queries } from "@/zero/queries/index"; -import { LocalFilesystemBrowser } from "./LocalFilesystemBrowser"; import { SidebarSlideOutPanel } from "./SidebarSlideOutPanel"; +const DesktopLocalTabContent = dynamic( + () => import("./DesktopLocalTabContent").then((mod) => mod.DesktopLocalTabContent), + { ssr: false } +); + const NON_DELETABLE_DOCUMENT_TYPES: readonly string[] = ["SURFSENSE_DOCS"]; const LOCAL_FILESYSTEM_TRUST_KEY = "surfsense.local-filesystem-trust.v1"; const MAX_LOCAL_FILESYSTEM_ROOTS = 10; @@ -142,9 +133,6 @@ interface WatchedFolderEntry { active: boolean; } -const getFolderDisplayName = (rootPath: string): string => - rootPath.split(/[\\/]/).at(-1) || rootPath; - const SHOWCASE_CONNECTORS = [ { type: "GOOGLE_DRIVE_CONNECTOR", label: "Google Drive" }, { type: "GOOGLE_GMAIL_CONNECTOR", label: "Gmail" }, @@ -170,25 +158,40 @@ interface DocumentsSidebarProps { export function DocumentsSidebar(props: DocumentsSidebarProps) { const isAnonymous = useIsAnonymous(); + const { isDesktop } = usePlatform(); if (isAnonymous) { return ; } - return ; + return isDesktop ? ( + + ) : ( + + ); } -function AuthenticatedDocumentsSidebar({ +function AuthenticatedDesktopDocumentsSidebar(props: DocumentsSidebarProps) { + return ; +} + +function AuthenticatedWebDocumentsSidebar(props: DocumentsSidebarProps) { + return ; +} + +function AuthenticatedDocumentsSidebarBase({ open, onOpenChange, isDocked = false, onDockedChange, embedded = false, headerAction, -}: DocumentsSidebarProps) { + desktopFeaturesEnabled, +}: DocumentsSidebarProps & { desktopFeaturesEnabled: boolean }) { const t = useTranslations("documents"); const tSidebar = useTranslations("sidebar"); const params = useParams(); const isMobile = !useMediaQuery("(min-width: 640px)"); - const electronAPI = useElectronAPI(); + const platformElectronAPI = useElectronAPI(); + const electronAPI = desktopFeaturesEnabled ? platformElectronAPI : null; const searchSpaceId = Number(params.search_space_id); const setConnectorDialogOpen = useSetAtom(connectorDialogOpenAtom); const setRightPanelCollapsed = useSetAtom(rightPanelCollapsedAtom); @@ -198,9 +201,6 @@ function AuthenticatedDocumentsSidebar({ const [search, setSearch] = useState(""); const debouncedSearch = useDebouncedValue(search, 250); - const [localSearch, setLocalSearch] = useState(""); - const debouncedLocalSearch = useDebouncedValue(localSearch, 250); - const localSearchInputRef = useRef(null); const [activeTypes, setActiveTypes] = useState([]); const [filesystemSettings, setFilesystemSettings] = useState(null); const [localTrustDialogOpen, setLocalTrustDialogOpen] = useState(false); @@ -208,7 +208,7 @@ function AuthenticatedDocumentsSidebar({ const [watchedFolderIds, setWatchedFolderIds] = useState>(new Set()); const [folderWatchOpen, setFolderWatchOpen] = useAtom(folderWatchDialogOpenAtom); const [watchInitialFolder, setWatchInitialFolder] = useAtom(folderWatchInitialFolderAtom); - const isElectron = typeof window !== "undefined" && !!window.electronAPI; + const isElectron = desktopFeaturesEnabled && typeof window !== "undefined" && !!window.electronAPI; useEffect(() => { if (!electronAPI?.getAgentFilesystemSettings) return; @@ -1180,161 +1180,24 @@ function AuthenticatedDocumentsSidebar({ ); const localContent = ( -
-
-
- {localRootPaths.length > 0 ? ( - - - - - - - Selected folders - - - {localRootPaths.map((rootPath) => ( - event.preventDefault()} - className="group h-8 gap-1.5 px-1.5 text-sm text-foreground" - > - - - {getFolderDisplayName(rootPath)} - - - - ))} - - { - void handleClearFilesystemRoots(); - }} - > - Clear all folders - - - - ) : ( -
- - No local folders selected -
- )} - - {electronAPI ? ( - - - - - - - - {canAddMoreLocalRoots - ? "Add folder" - : `You can add up to ${MAX_LOCAL_FILESYSTEM_ROOTS} folders`} - - - ) : ( - - )} -
-
-
-
-
-
- setLocalSearch(e.target.value)} - placeholder="Search local files" - type="text" - aria-label="Search local files" - /> - {Boolean(localSearch) && ( - - )} -
-
- { - openEditorPanel({ - kind: "local_file", - localFilePath, - title: localFilePath.split("/").pop() || localFilePath, - searchSpaceId, - }); - }} - /> -
+ { + openEditorPanel({ + kind: "local_file", + localFilePath, + title: localFilePath.split("/").pop() || localFilePath, + searchSpaceId, + }); + }} + electronAvailable={!!electronAPI} + /> ); const documentsContent = ( @@ -1428,7 +1291,7 @@ function AuthenticatedDocumentsSidebar({ {cloudContent} - {localContent} + {currentFilesystemTab === "local" ? localContent : null} ) : ( From 62b9e328b4b87fe4122eb3d179c314d9a5d964fe Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 18:49:15 +0200 Subject: [PATCH 182/299] Add desktop IPC, preload, and window types for chat screen capture and full-screen capture. --- surfsense_desktop/src/ipc/channels.ts | 1 + surfsense_desktop/src/preload.ts | 1 + surfsense_web/types/window.d.ts | 15 ++++++++++++--- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/surfsense_desktop/src/ipc/channels.ts b/surfsense_desktop/src/ipc/channels.ts index 69fb89419..9f084af85 100644 --- a/surfsense_desktop/src/ipc/channels.ts +++ b/surfsense_desktop/src/ipc/channels.ts @@ -11,6 +11,7 @@ export const IPC_CHANNELS = { REQUEST_ACCESSIBILITY: 'request-accessibility', REQUEST_SCREEN_RECORDING: 'request-screen-recording', RESTART_APP: 'restart-app', + CAPTURE_FULL_SCREEN: 'capture-full-screen', SCREEN_REGION_SUBMIT: 'screen-region:submit', SCREEN_REGION_CANCEL: 'screen-region:cancel', CHAT_SCREEN_CAPTURE: 'chat:screen-capture', diff --git a/surfsense_desktop/src/preload.ts b/surfsense_desktop/src/preload.ts index 087cabd75..7ce2cbcf8 100644 --- a/surfsense_desktop/src/preload.ts +++ b/surfsense_desktop/src/preload.ts @@ -32,6 +32,7 @@ contextBridge.exposeInMainWorld('electronAPI', { getPermissionsStatus: () => ipcRenderer.invoke(IPC_CHANNELS.GET_PERMISSIONS_STATUS), requestAccessibility: () => ipcRenderer.invoke(IPC_CHANNELS.REQUEST_ACCESSIBILITY), requestScreenRecording: () => ipcRenderer.invoke(IPC_CHANNELS.REQUEST_SCREEN_RECORDING), + captureFullScreen: () => ipcRenderer.invoke(IPC_CHANNELS.CAPTURE_FULL_SCREEN), restartApp: () => ipcRenderer.invoke(IPC_CHANNELS.RESTART_APP), // Folder sync selectFolder: () => ipcRenderer.invoke(IPC_CHANNELS.FOLDER_SYNC_SELECT_FOLDER), diff --git a/surfsense_web/types/window.d.ts b/surfsense_web/types/window.d.ts index a8f02fd20..ea55743db 100644 --- a/surfsense_web/types/window.d.ts +++ b/surfsense_web/types/window.d.ts @@ -83,6 +83,7 @@ interface ElectronAPI { }>; requestAccessibility: () => Promise; requestScreenRecording: () => Promise; + captureFullScreen: () => Promise; restartApp: () => Promise; // Folder sync selectFolder: () => Promise; @@ -108,10 +109,18 @@ interface ElectronAPI { getAuthTokens: () => Promise<{ bearer: string; refresh: string } | null>; setAuthTokens: (bearer: string, refresh: string) => Promise; // Keyboard shortcut configuration - getShortcuts: () => Promise<{ generalAssist: string; quickAsk: string }>; + getShortcuts: () => Promise<{ + generalAssist: string; + quickAsk: string; + screenshotAssist: string; + }>; setShortcuts: ( - config: Partial<{ generalAssist: string; quickAsk: string }> - ) => Promise<{ generalAssist: string; quickAsk: string }>; + config: Partial<{ generalAssist: string; quickAsk: string; screenshotAssist: string }> + ) => Promise<{ + generalAssist: string; + quickAsk: string; + screenshotAssist: string; + }>; // Launch on system startup getAutoLaunch: () => Promise<{ enabled: boolean; From d212422bf5bea6dcbc7bf3a8c399b6503462b38b Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 18:49:20 +0200 Subject: [PATCH 183/299] Add full-screen display capture alongside the region picker for desktop chat. --- .../src/modules/screen-region-picker.ts | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/surfsense_desktop/src/modules/screen-region-picker.ts b/surfsense_desktop/src/modules/screen-region-picker.ts index 0a924eec9..cc9303040 100644 --- a/surfsense_desktop/src/modules/screen-region-picker.ts +++ b/surfsense_desktop/src/modules/screen-region-picker.ts @@ -6,11 +6,13 @@ import { IPC_CHANNELS } from '../ipc/channels'; let pickInProgress = false; -async function captureDisplayDataUrl(display: Electron.Display): Promise<{ +type DisplayCaptureSnapshot = { dataUrl: string; width: number; height: number; -} | null> { +}; + +async function captureDisplaySnapshot(display: Electron.Display): Promise { try { const sf = display.scaleFactor || 1; const tw = Math.max(1, Math.round(display.size.width * sf)); @@ -37,6 +39,12 @@ async function captureDisplayDataUrl(display: Electron.Display): Promise<{ } } +export async function captureCurrentDisplayDataUrl(): Promise { + const display = screen.getDisplayNearestPoint(screen.getCursorScreenPoint()); + const snapshot = await captureDisplaySnapshot(display); + return snapshot?.dataUrl ?? null; +} + function buildInjectScript(dataUrl: string, iw: number, ih: number): string { return `(() => { const api = window.surfsenseScreenRegion; @@ -166,7 +174,7 @@ export function pickScreenRegion(): Promise { resolve(result); }; - let snapshot: { dataUrl: string; width: number; height: number } | null = null; + let snapshot: DisplayCaptureSnapshot | null = null; const onSubmit = ( _event: Electron.IpcMainEvent, @@ -206,7 +214,7 @@ export function pickScreenRegion(): Promise { } }; - void captureDisplayDataUrl(display) + void captureDisplaySnapshot(display) .then((cap) => { if (!cap) { finish(null); From 7145a15149130b16890b1c31e43d3a8372f5f5c8 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 18:49:24 +0200 Subject: [PATCH 184/299] Add Screenshot Assist shortcut flow: show window, pick region, send data URL to chat. --- .../src/modules/screenshot-assist.ts | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 surfsense_desktop/src/modules/screenshot-assist.ts diff --git a/surfsense_desktop/src/modules/screenshot-assist.ts b/surfsense_desktop/src/modules/screenshot-assist.ts new file mode 100644 index 000000000..2500bf1d5 --- /dev/null +++ b/surfsense_desktop/src/modules/screenshot-assist.ts @@ -0,0 +1,20 @@ +import { IPC_CHANNELS } from '../ipc/channels'; +import { trackEvent } from './analytics'; +import { pickScreenRegion } from './screen-region-picker'; +import { getMainWindow, showMainWindow } from './window'; +import { hasScreenRecordingPermission, requestScreenRecording } from './permissions'; + +export async function runScreenshotAssistShortcut(): Promise { + showMainWindow('shortcut'); + await new Promise((r) => setTimeout(r, 400)); + if (!hasScreenRecordingPermission()) { + requestScreenRecording(); + return; + } + const url = await pickScreenRegion(); + const mw = getMainWindow(); + if (url && mw && !mw.isDestroyed()) { + mw.webContents.send(IPC_CHANNELS.CHAT_SCREEN_CAPTURE, url); + trackEvent('desktop_screenshot_assist_region_to_chat', {}); + } +} From 24a5a06f215719278e3fde85f9082ec0271ab20c Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 18:49:27 +0200 Subject: [PATCH 185/299] Make General Assist only focus the main window, without region capture. --- .../src/modules/general-assist.ts | 20 ++----------------- 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/surfsense_desktop/src/modules/general-assist.ts b/surfsense_desktop/src/modules/general-assist.ts index 9d39f068a..7d202caa2 100644 --- a/surfsense_desktop/src/modules/general-assist.ts +++ b/surfsense_desktop/src/modules/general-assist.ts @@ -1,21 +1,5 @@ -import { IPC_CHANNELS } from '../ipc/channels'; -import { trackEvent } from './analytics'; -import { pickScreenRegion } from './screen-region-picker'; -import { getMainWindow, showMainWindow } from './window'; -import { hasScreenRecordingPermission, requestScreenRecording } from './permissions'; +import { showMainWindow } from './window'; -export async function runGeneralAssistShortcut(): Promise { - console.log('[general-assist] Shortcut triggered'); +export function runGeneralAssistShortcut(): void { showMainWindow('shortcut'); - await new Promise((r) => setTimeout(r, 400)); - if (!hasScreenRecordingPermission()) { - requestScreenRecording(); - return; - } - const url = await pickScreenRegion(); - const mw = getMainWindow(); - if (url && mw && !mw.isDestroyed()) { - mw.webContents.send(IPC_CHANNELS.CHAT_SCREEN_CAPTURE, url); - trackEvent('desktop_screen_region_to_chat', {}); - } } From f489fee2e8b91abedb6d9a30d08dfd754fa99a30 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 18:49:30 +0200 Subject: [PATCH 186/299] Register General Assist and Screenshot Assist as two independent global shortcuts. --- surfsense_desktop/src/modules/tray.ts | 70 ++++++++++++++++++++------- 1 file changed, 53 insertions(+), 17 deletions(-) diff --git a/surfsense_desktop/src/modules/tray.ts b/surfsense_desktop/src/modules/tray.ts index 97d6146e8..07b53bafb 100644 --- a/surfsense_desktop/src/modules/tray.ts +++ b/surfsense_desktop/src/modules/tray.ts @@ -1,12 +1,14 @@ import { app, globalShortcut, Menu, nativeImage, Tray, type NativeImage } from 'electron'; import path from 'path'; import { runGeneralAssistShortcut } from './general-assist'; +import { runScreenshotAssistShortcut } from './screenshot-assist'; import { showMainWindow } from './window'; import { getShortcuts } from './shortcuts'; import { trackEvent } from './analytics'; let tray: Tray | null = null; -let currentShortcut: string | null = null; +let registeredGeneralAssist: string | null = null; +let registeredScreenshotAssist: string | null = null; function getTrayIcon(): NativeImage { const iconName = process.platform === 'win32' ? 'icon.ico' : 'icon.png'; @@ -17,25 +19,29 @@ function getTrayIcon(): NativeImage { return img.resize({ width: 16, height: 16 }); } -function registerShortcut(accelerator: string): void { - if (currentShortcut) { - globalShortcut.unregister(currentShortcut); - currentShortcut = null; +function registerOne( + previous: string | null, + accelerator: string, + onFire: () => void | Promise, + label: string +): string | null { + if (previous) { + globalShortcut.unregister(previous); } - if (!accelerator) return; + if (!accelerator) return null; try { const ok = globalShortcut.register(accelerator, () => { - void runGeneralAssistShortcut(); + void Promise.resolve(onFire()); }); if (ok) { - currentShortcut = accelerator; - console.log(`[general-assist] Register ${accelerator}: OK`); - } else { - console.warn(`[general-assist] Register ${accelerator}: FAILED (OS or another app may own this chord)`); + console.log(`[hotkeys] Register ${label} ${accelerator}: OK`); + return accelerator; } + console.warn(`[hotkeys] Register ${label} ${accelerator}: FAILED (OS or another app may own this chord)`); } catch (err) { - console.error(`[tray] Error registering General Assist shortcut:`, err); + console.error(`[tray] Error registering ${label} shortcut:`, err); } + return null; } export async function createTray(): Promise { @@ -60,18 +66,48 @@ export async function createTray(): Promise { tray.on('double-click', () => showMainWindow('tray_click')); const shortcuts = await getShortcuts(); - registerShortcut(shortcuts.generalAssist); + registeredGeneralAssist = registerOne( + null, + shortcuts.generalAssist, + runGeneralAssistShortcut, + 'General Assist' + ); + registeredScreenshotAssist = registerOne( + null, + shortcuts.screenshotAssist, + runScreenshotAssistShortcut, + 'Screenshot Assist' + ); } export async function reregisterGeneralAssist(): Promise { const shortcuts = await getShortcuts(); - registerShortcut(shortcuts.generalAssist); + registeredGeneralAssist = registerOne( + registeredGeneralAssist, + shortcuts.generalAssist, + runGeneralAssistShortcut, + 'General Assist' + ); +} + +export async function reregisterScreenshotAssist(): Promise { + const shortcuts = await getShortcuts(); + registeredScreenshotAssist = registerOne( + registeredScreenshotAssist, + shortcuts.screenshotAssist, + runScreenshotAssistShortcut, + 'Screenshot Assist' + ); } export function destroyTray(): void { - if (currentShortcut) { - globalShortcut.unregister(currentShortcut); - currentShortcut = null; + if (registeredGeneralAssist) { + globalShortcut.unregister(registeredGeneralAssist); + registeredGeneralAssist = null; + } + if (registeredScreenshotAssist) { + globalShortcut.unregister(registeredScreenshotAssist); + registeredScreenshotAssist = null; } tray?.destroy(); tray = null; From 1c7362d9c680bd4d1c91e264ee77c1f9dc5b7421 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 18:49:33 +0200 Subject: [PATCH 187/299] Handle full-screen capture IPC and reregister Screenshot Assist when its shortcut changes. --- surfsense_desktop/src/ipc/handlers.ts | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/surfsense_desktop/src/ipc/handlers.ts b/surfsense_desktop/src/ipc/handlers.ts index 5f55dccf6..8361b9a38 100644 --- a/surfsense_desktop/src/ipc/handlers.ts +++ b/surfsense_desktop/src/ipc/handlers.ts @@ -2,10 +2,12 @@ import { app, ipcMain, shell } from 'electron'; import { IPC_CHANNELS } from './channels'; import { getPermissionsStatus, + hasScreenRecordingPermission, requestAccessibility, requestScreenRecording, restartApp, } from '../modules/permissions'; +import { captureCurrentDisplayDataUrl } from '../modules/screen-region-picker'; import { selectFolder, addWatchedFolder, @@ -27,7 +29,7 @@ import { getShortcuts, setShortcuts, type ShortcutConfig } from '../modules/shor import { getAutoLaunchState, setAutoLaunch } from '../modules/auto-launch'; import { getActiveSearchSpaceId, setActiveSearchSpaceId } from '../modules/active-search-space'; import { reregisterQuickAsk } from '../modules/quick-ask'; -import { reregisterGeneralAssist } from '../modules/tray'; +import { reregisterGeneralAssist, reregisterScreenshotAssist } from '../modules/tray'; import { getDistinctId, getMachineId, @@ -78,6 +80,14 @@ export function registerIpcHandlers(): void { restartApp(); }); + ipcMain.handle(IPC_CHANNELS.CAPTURE_FULL_SCREEN, async () => { + if (!hasScreenRecordingPermission()) { + requestScreenRecording(); + return null; + } + return captureCurrentDisplayDataUrl(); + }); + // Folder sync handlers ipcMain.handle(IPC_CHANNELS.FOLDER_SYNC_SELECT_FOLDER, () => selectFolder()); @@ -182,6 +192,7 @@ export function registerIpcHandlers(): void { ipcMain.handle(IPC_CHANNELS.SET_SHORTCUTS, async (_event, config: Partial) => { const updated = await setShortcuts(config); if (config.generalAssist) await reregisterGeneralAssist(); + if (config.screenshotAssist) await reregisterScreenshotAssist(); if (config.quickAsk) await reregisterQuickAsk(); trackEvent('desktop_shortcut_updated', { keys: Object.keys(config), From df952ffa2812dbfd66a2207af0e3c81d691a7e9a Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 18:49:36 +0200 Subject: [PATCH 188/299] Add Screenshot Assist to stored shortcuts, default to Shift+Space, and migrate legacy autocomplete. --- surfsense_desktop/src/modules/shortcuts.ts | 26 +++++++++++++--------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/surfsense_desktop/src/modules/shortcuts.ts b/surfsense_desktop/src/modules/shortcuts.ts index 3eb3ca5c9..64687f7db 100644 --- a/surfsense_desktop/src/modules/shortcuts.ts +++ b/surfsense_desktop/src/modules/shortcuts.ts @@ -1,11 +1,13 @@ export interface ShortcutConfig { generalAssist: string; quickAsk: string; + screenshotAssist: string; } const DEFAULTS: ShortcutConfig = { generalAssist: 'CommandOrControl+Shift+S', quickAsk: 'CommandOrControl+Alt+S', + screenshotAssist: 'CommandOrControl+Shift+Space', }; const STORE_KEY = 'shortcuts'; @@ -23,21 +25,25 @@ async function getStore() { return store; } -/** One-time fix if both shortcuts match the mistaken Alt+Shift pair. */ -function wasRegressionAltPair(rest: Record): boolean { - return rest.generalAssist === 'Alt+Shift+G' && rest.quickAsk === 'Alt+Shift+Q'; -} - export async function getShortcuts(): Promise { const s = await getStore(); const raw = (s.get(STORE_KEY) as Record | undefined) ?? {}; + const legacyAutocomplete = raw.autocomplete; const { autocomplete: _drop, ...rest } = raw; - if (wasRegressionAltPair(rest)) { - const fixed = { ...DEFAULTS }; - s.set(STORE_KEY, { ...fixed }); - return fixed; + let merged: ShortcutConfig = { ...DEFAULTS, ...rest }; + if ( + typeof legacyAutocomplete === 'string' && + legacyAutocomplete.length > 0 && + !('screenshotAssist' in raw) + ) { + merged = { ...merged, screenshotAssist: legacyAutocomplete }; + s.set(STORE_KEY, { + generalAssist: merged.generalAssist, + quickAsk: merged.quickAsk, + screenshotAssist: merged.screenshotAssist, + }); } - return { ...DEFAULTS, ...rest }; + return merged; } export async function setShortcuts(config: Partial): Promise { From 97488218db75fba1e6aaf1b4c28b5fd7af3dfdf3 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 18:49:40 +0200 Subject: [PATCH 189/299] Expose Screenshot Assist in desktop login and settings with shared shortcut defaults. --- .../user-settings/components/DesktopContent.tsx | 3 ++- .../components/DesktopShortcutsContent.tsx | 5 +++-- surfsense_web/app/desktop/login/page.tsx | 10 ++++++++-- surfsense_web/components/desktop/shortcut-recorder.tsx | 1 + 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx index 4ce6f386c..3368066c1 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopContent.tsx @@ -129,7 +129,8 @@ export function DesktopContent() { Default Search Space - Choose which search space General Assist and Quick Assist use by default. + Choose which search space General Assist, Screenshot Assist, and Quick Assist use by + default. diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx index 0b7f330d9..f1679cb15 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx @@ -1,6 +1,6 @@ "use client"; -import { Rocket, RotateCcw, Zap } from "lucide-react"; +import { Crop, Rocket, RotateCcw, Zap } from "lucide-react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { DEFAULT_SHORTCUTS, keyEventToAccelerator } from "@/components/desktop/shortcut-recorder"; @@ -9,11 +9,12 @@ import { ShortcutKbd } from "@/components/ui/shortcut-kbd"; import { Spinner } from "@/components/ui/spinner"; import { useElectronAPI } from "@/hooks/use-platform"; -type ShortcutKey = "generalAssist" | "quickAsk"; +type ShortcutKey = "generalAssist" | "quickAsk" | "screenshotAssist"; type ShortcutMap = typeof DEFAULT_SHORTCUTS; const HOTKEY_ROWS: Array<{ key: ShortcutKey; label: string; icon: React.ElementType }> = [ { key: "generalAssist", label: "General Assist", icon: Rocket }, + { key: "screenshotAssist", label: "Screenshot Assist", icon: Crop }, { key: "quickAsk", label: "Quick Assist", icon: Zap }, ]; diff --git a/surfsense_web/app/desktop/login/page.tsx b/surfsense_web/app/desktop/login/page.tsx index edb6cffab..c8ec4dfce 100644 --- a/surfsense_web/app/desktop/login/page.tsx +++ b/surfsense_web/app/desktop/login/page.tsx @@ -2,7 +2,7 @@ import { IconBrandGoogleFilled } from "@tabler/icons-react"; import { useAtom } from "jotai"; -import { Eye, EyeOff, Rocket, RotateCcw, Zap } from "lucide-react"; +import { Crop, Eye, EyeOff, Rocket, RotateCcw, Zap } from "lucide-react"; import Image from "next/image"; import { useRouter } from "next/navigation"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; @@ -21,7 +21,7 @@ import { setBearerToken } from "@/lib/auth-utils"; import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config"; const isGoogleAuth = AUTH_TYPE === "GOOGLE"; -type ShortcutKey = "generalAssist" | "quickAsk"; +type ShortcutKey = "generalAssist" | "quickAsk" | "screenshotAssist"; type ShortcutMap = typeof DEFAULT_SHORTCUTS; const HOTKEY_ROWS: Array<{ @@ -36,6 +36,12 @@ const HOTKEY_ROWS: Array<{ description: "Launch SurfSense instantly from any application", icon: Rocket, }, + { + key: "screenshotAssist", + label: "Screenshot Assist", + description: "Draw a region on screen to attach that capture to chat", + icon: Crop, + }, { key: "quickAsk", label: "Quick Assist", diff --git a/surfsense_web/components/desktop/shortcut-recorder.tsx b/surfsense_web/components/desktop/shortcut-recorder.tsx index 119cd298f..388bb1bf8 100644 --- a/surfsense_web/components/desktop/shortcut-recorder.tsx +++ b/surfsense_web/components/desktop/shortcut-recorder.tsx @@ -38,6 +38,7 @@ export function acceleratorToDisplay(accel: string): string[] { export const DEFAULT_SHORTCUTS = { generalAssist: "CommandOrControl+Shift+S", quickAsk: "CommandOrControl+Alt+S", + screenshotAssist: "CommandOrControl+Shift+Space", }; // --------------------------------------------------------------------------- From d310663993e35748b21c290c4fa6f49118c64cb6 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 18:49:43 +0200 Subject: [PATCH 190/299] Wire Electron chat thread to screen capture events and full-screen capture from the composer. --- .../components/assistant-ui/thread.tsx | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 6862662f2..e7ae2f471 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -12,11 +12,11 @@ import { AlertCircle, ArrowDownIcon, ArrowUpIcon, + Camera, ChevronDown, ChevronUp, Clipboard, Dot, - Camera, Globe, Plus, Settings2, @@ -803,9 +803,11 @@ const ComposerAction: FC = ({ isBlockedByOtherUser = false isComposerTextEmpty && mentionedDocuments.length === 0 && pendingScreenImages.length === 0; const handleScreenCapture = useCallback(async () => { - const url = await captureDisplayToPngDataUrl(); + const url = electronAPI?.captureFullScreen + ? await electronAPI.captureFullScreen() + : await captureDisplayToPngDataUrl(); if (url) setPendingScreenImages((prev) => [...prev, url]); - }, [setPendingScreenImages]); + }, [electronAPI, setPendingScreenImages]); const { data: userConfigs } = useAtomValue(newLLMConfigsAtom); const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom); @@ -1241,20 +1243,17 @@ const ComposerAction: FC = ({ isBlockedByOtherUser = false
)}
- {/* Electron: native shortcut → pending images; skip in-webview getDisplayMedia. */} - {!electronAPI && ( - void handleScreenCapture()} - > - - - )} + void handleScreenCapture()} + > + + !thread.isRunning}> Date: Mon, 27 Apr 2026 18:49:50 +0200 Subject: [PATCH 191/299] Rename homepage tutorial media to screenshot assist and point the hero tab at the new asset. --- surfsense_web/components/homepage/hero-section.tsx | 6 +++--- .../{extreme_assist.gif => screenshot_assist.gif} | Bin .../{extreme_assist.mp4 => screenshot_assist.mp4} | Bin 3 files changed, 3 insertions(+), 3 deletions(-) rename surfsense_web/public/homepage/hero_tutorial/{extreme_assist.gif => screenshot_assist.gif} (100%) rename surfsense_web/public/homepage/hero_tutorial/{extreme_assist.mp4 => screenshot_assist.mp4} (100%) diff --git a/surfsense_web/components/homepage/hero-section.tsx b/surfsense_web/components/homepage/hero-section.tsx index a29d02882..ec09fa34d 100644 --- a/surfsense_web/components/homepage/hero-section.tsx +++ b/surfsense_web/components/homepage/hero-section.tsx @@ -63,10 +63,10 @@ const TAB_ITEMS = [ featured: true, }, { - title: "Screen capture in chat", + title: "Screenshot Assist", description: - "Capture your screen and send it with your message so the AI can see what you see.", - src: "/homepage/hero_tutorial/extreme_assist.mp4", + "Use a global shortcut to select a region on your screen and attach it to your chat message.", + src: "/homepage/hero_tutorial/screenshot_assist.mp4", featured: true, }, { diff --git a/surfsense_web/public/homepage/hero_tutorial/extreme_assist.gif b/surfsense_web/public/homepage/hero_tutorial/screenshot_assist.gif similarity index 100% rename from surfsense_web/public/homepage/hero_tutorial/extreme_assist.gif rename to surfsense_web/public/homepage/hero_tutorial/screenshot_assist.gif diff --git a/surfsense_web/public/homepage/hero_tutorial/extreme_assist.mp4 b/surfsense_web/public/homepage/hero_tutorial/screenshot_assist.mp4 similarity index 100% rename from surfsense_web/public/homepage/hero_tutorial/extreme_assist.mp4 rename to surfsense_web/public/homepage/hero_tutorial/screenshot_assist.mp4 From 36d891d41316c98a3beec847006a6195a171e91b Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 18:49:54 +0200 Subject: [PATCH 192/299] Update READMEs in all languages to describe Screenshot Assist instead of Extreme Assist. --- README.es.md | 10 +++++----- README.hi.md | 10 +++++----- README.md | 10 +++++----- README.pt-BR.md | 10 +++++----- README.zh-CN.md | 10 +++++----- 5 files changed, 25 insertions(+), 25 deletions(-) diff --git a/README.es.md b/README.es.md index 299c6e95c..4e16af936 100644 --- a/README.es.md +++ b/README.es.md @@ -41,7 +41,7 @@ NotebookLM es una de las mejores y más útiles plataformas de IA que existen, p - **Sin Dependencia de Proveedores** - Configura cualquier modelo LLM, de imagen, TTS y STT. - **25+ Fuentes de Datos Externas** - Agrega tus fuentes desde Google Drive, OneDrive, Dropbox, Notion y muchos otros servicios externos. - **Soporte Multijugador en Tiempo Real** - Trabaja fácilmente con los miembros de tu equipo en un notebook compartido. -- **Aplicación de Escritorio** - Obtén asistencia de IA en cualquier aplicación con Quick Assist, General Assist, Extreme Assist y sincronización de carpetas locales. +- **Aplicación de Escritorio** - Obtén asistencia de IA en cualquier aplicación con Quick Assist, General Assist, Screenshot Assist y sincronización de carpetas locales. ...y más por venir. @@ -84,9 +84,9 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7

Quick Assist

- - Aplicación de Escritorio — Extreme Assist + - Aplicación de Escritorio — Screenshot Assist -

Extreme Assist

+

Screenshot Assist

- Aplicación de Escritorio — Watch Local Folder @@ -150,7 +150,7 @@ La aplicación de escritorio incluye estas potentes funciones: - **General Assist** — Lanza SurfSense al instante desde cualquier aplicación con un atajo global. - **Quick Assist** — Selecciona texto en cualquier lugar, luego pide a la IA que lo explique, reescriba o actúe sobre él. -- **Extreme Assist** — Obtén sugerencias de escritura en línea impulsadas por tu base de conocimiento mientras escribes en cualquier aplicación. +- **Screenshot Assist** — Selecciona una región de tu pantalla y adjúntala al chat para que las respuestas se basen en tu base de conocimiento. - **Watch Local Folder** — Vigila una carpeta local y sincroniza automáticamente los cambios de archivos con tu base de conocimiento. **Pro tip:** Apúntalo a tu bóveda de Obsidian para mantener tus notas buscables en SurfSense. Todas las funciones operan contra tu espacio de búsqueda elegido, por lo que tus respuestas siempre están basadas en tus propios datos. @@ -199,7 +199,7 @@ Todas las funciones operan contra tu espacio de búsqueda elegido, por lo que tu | **Generación de Videos** | Resúmenes en video cinemáticos vía Veo 3 (solo Ultra) | Disponible (NotebookLM es mejor aquí, mejorando activamente) | | **Generación de Presentaciones** | Diapositivas más atractivas pero no editables | Crea presentaciones editables basadas en diapositivas | | **Generación de Podcasts** | Resúmenes de audio con hosts e idiomas personalizables | Disponible con múltiples proveedores TTS (NotebookLM es mejor aquí, mejorando activamente) | -| **Aplicación de Escritorio** | No | Aplicación nativa con General Assist, Quick Assist, Extreme Assist y sincronización de carpetas locales | +| **Aplicación de Escritorio** | No | Aplicación nativa con General Assist, Quick Assist, Screenshot Assist y sincronización de carpetas locales | | **Extensión de Navegador** | No | Extensión multi-navegador para guardar cualquier página web, incluyendo páginas protegidas por autenticación |
diff --git a/README.hi.md b/README.hi.md index 11a25ee0d..96f4d0da6 100644 --- a/README.hi.md +++ b/README.hi.md @@ -41,7 +41,7 @@ NotebookLM वहाँ उपलब्ध सबसे अच्छे और - **कोई विक्रेता लॉक-इन नहीं** - किसी भी LLM, इमेज, TTS और STT मॉडल को कॉन्फ़िगर करें। - **25+ बाहरी डेटा स्रोत** - Google Drive, OneDrive, Dropbox, Notion और कई अन्य बाहरी सेवाओं से अपने स्रोत जोड़ें। - **रीयल-टाइम मल्टीप्लेयर सपोर्ट** - एक साझा notebook में अपनी टीम के सदस्यों के साथ आसानी से काम करें। -- **डेस्कटॉप ऐप** - Quick Assist, General Assist, Extreme Assist और लोकल फ़ोल्डर सिंक के साथ किसी भी एप्लिकेशन में AI सहायता प्राप्त करें। +- **डेस्कटॉप ऐप** - Quick Assist, General Assist, Screenshot Assist और लोकल फ़ोल्डर सिंक के साथ किसी भी एप्लिकेशन में AI सहायता प्राप्त करें। ...और भी बहुत कुछ आने वाला है। @@ -84,9 +84,9 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7

Quick Assist

- - डेस्कटॉप ऐप — Extreme Assist + - डेस्कटॉप ऐप — Screenshot Assist -

Extreme Assist

+

Screenshot Assist

- डेस्कटॉप ऐप — Watch Local Folder @@ -150,7 +150,7 @@ SurfSense एक डेस्कटॉप ऐप भी प्रदान क - **General Assist** — एक ग्लोबल शॉर्टकट से किसी भी एप्लिकेशन से तुरंत SurfSense लॉन्च करें। - **Quick Assist** — कहीं भी टेक्स्ट चुनें, फिर AI से समझाने, फिर से लिखने या उस पर कार्रवाई करने को कहें। -- **Extreme Assist** — किसी भी ऐप में टाइप करते समय अपनी नॉलेज बेस से संचालित इनलाइन लेखन सुझाव प्राप्त करें। +- **Screenshot Assist** — स्क्रीन पर एक क्षेत्र चुनें और उसे चैट में जोड़ें, ताकि उत्तर आपकी नॉलेज बेस पर आधारित रहें। - **Watch Local Folder** — एक लोकल फ़ोल्डर को वॉच करें और फ़ाइल परिवर्तनों को स्वचालित रूप से अपनी नॉलेज बेस में सिंक करें। **Pro tip:** इसे अपने Obsidian vault पर पॉइंट करें ताकि आपके नोट्स SurfSense में सर्च करने योग्य रहें। सभी सुविधाएं आपके चुने हुए सर्च स्पेस पर काम करती हैं, ताकि आपके उत्तर हमेशा आपके अपने डेटा पर आधारित हों। @@ -199,7 +199,7 @@ SurfSense एक डेस्कटॉप ऐप भी प्रदान क | **वीडियो जनरेशन** | Veo 3 के माध्यम से सिनेमैटिक वीडियो ओवरव्यू (केवल Ultra) | उपलब्ध (NotebookLM यहाँ बेहतर है, सक्रिय रूप से सुधार हो रहा है) | | **प्रेजेंटेशन जनरेशन** | बेहतर दिखने वाली स्लाइड्स लेकिन संपादन योग्य नहीं | संपादन योग्य, स्लाइड आधारित प्रेजेंटेशन बनाएं | | **पॉडकास्ट जनरेशन** | कस्टमाइज़ेबल होस्ट और भाषाओं के साथ ऑडियो ओवरव्यू | कई TTS प्रदाताओं के साथ उपलब्ध (NotebookLM यहाँ बेहतर है, सक्रिय रूप से सुधार हो रहा है) | -| **डेस्कटॉप ऐप** | नहीं | General Assist, Quick Assist, Extreme Assist और लोकल फ़ोल्डर सिंक के साथ नेटिव ऐप | +| **डेस्कटॉप ऐप** | नहीं | General Assist, Quick Assist, Screenshot Assist और लोकल फ़ोल्डर सिंक के साथ नेटिव ऐप | | **ब्राउज़र एक्सटेंशन** | नहीं | किसी भी वेबपेज को सहेजने के लिए क्रॉस-ब्राउज़र एक्सटेंशन, प्रमाणीकरण सुरक्षित पेज सहित |
diff --git a/README.md b/README.md index 9714b9e65..4dc9433ea 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ NotebookLM is one of the best and most useful AI platforms out there, but once y - **25+ External Data Sources** - Add your sources from Google Drive, OneDrive, Dropbox, Notion, and many other external services. - **Real-Time Multiplayer Support** - Work easily with your team members in a shared notebook. - **AI File Sorting** - Automatically organize your documents into a smart folder hierarchy using AI-powered categorization by source, date, and topic. -- **Desktop App** - Get AI assistance in any application with Quick Assist, General Assist, Extreme Assist, and local folder sync. +- **Desktop App** - Get AI assistance in any application with Quick Assist, General Assist, Screenshot Assist, and local folder sync. ...and more to come. @@ -85,9 +85,9 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7

Quick Assist

- - Desktop App — Extreme Assist + - Desktop App — Screenshot Assist -

Extreme Assist

+

Screenshot Assist

- Desktop App — Watch Local Folder @@ -151,7 +151,7 @@ The desktop app includes these powerful features: - **General Assist** — Launch SurfSense instantly from any application with a global shortcut. - **Quick Assist** — Select text anywhere, then ask AI to explain, rewrite, or act on it. -- **Extreme Assist** — Get inline writing suggestions powered by your knowledge base as you type in any app. +- **Screenshot Assist** — Select a region on your screen and attach it to chat so answers stay grounded in your knowledge base. - **Watch Local Folder** — Watch a local folder and automatically sync file changes to your knowledge base. **Pro tip:** Point it at your Obsidian vault to keep your notes searchable in SurfSense. All features operate against your chosen search space, so your answers are always grounded in your own data. @@ -201,7 +201,7 @@ All features operate against your chosen search space, so your answers are alway | **Presentation Generation** | Better looking slides but not editable | Create editable, slide-based presentations | | **Podcast Generation** | Audio Overviews with customizable hosts and languages | Available with multiple TTS providers (NotebookLM is better here, actively improving) | | **AI File Sorting** | No | LLM-powered auto-categorization into source, date, category, and subcategory folders | -| **Desktop App** | No | Native app with General Assist, Quick Assist, Extreme Assist, and local folder sync | +| **Desktop App** | No | Native app with General Assist, Quick Assist, Screenshot Assist, and local folder sync | | **Browser Extension** | No | Cross-browser extension to save any webpage, including auth-protected pages |
diff --git a/README.pt-BR.md b/README.pt-BR.md index 9323b2bce..d3cb36ad0 100644 --- a/README.pt-BR.md +++ b/README.pt-BR.md @@ -41,7 +41,7 @@ O NotebookLM é uma das melhores e mais úteis plataformas de IA disponíveis, m - **Sem Dependência de Fornecedor** - Configure qualquer modelo LLM, de imagem, TTS e STT. - **25+ Fontes de Dados Externas** - Adicione suas fontes do Google Drive, OneDrive, Dropbox, Notion e muitos outros serviços externos. - **Suporte Multiplayer em Tempo Real** - Trabalhe facilmente com os membros da sua equipe em um notebook compartilhado. -- **Aplicativo Desktop** - Obtenha assistência de IA em qualquer aplicativo com Quick Assist, General Assist, Extreme Assist e sincronização de pastas locais. +- **Aplicativo Desktop** - Obtenha assistência de IA em qualquer aplicativo com Quick Assist, General Assist, Screenshot Assist e sincronização de pastas locais. ...e mais por vir. @@ -84,9 +84,9 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7

Quick Assist

- - Aplicativo Desktop — Extreme Assist + - Aplicativo Desktop — Screenshot Assist -

Extreme Assist

+

Screenshot Assist

- Aplicativo Desktop — Watch Local Folder @@ -150,7 +150,7 @@ O aplicativo desktop inclui estes recursos poderosos: - **General Assist** — Abra o SurfSense instantaneamente de qualquer aplicativo com um atalho global. - **Quick Assist** — Selecione texto em qualquer lugar, depois peça à IA para explicar, reescrever ou agir sobre ele. -- **Extreme Assist** — Receba sugestões de escrita em linha alimentadas pela sua base de conhecimento enquanto digita em qualquer aplicativo. +- **Screenshot Assist** — Selecione uma região da tela e anexe ao chat para respostas fundamentadas na sua base de conhecimento. - **Watch Local Folder** — Monitore uma pasta local e sincronize automaticamente as alterações de arquivos com sua base de conhecimento. **Pro tip:** Aponte para seu cofre do Obsidian para manter suas notas pesquisáveis no SurfSense. Todos os recursos operam no espaço de busca escolhido, para que suas respostas sejam sempre baseadas nos seus próprios dados. @@ -199,7 +199,7 @@ Todos os recursos operam no espaço de busca escolhido, para que suas respostas | **Geração de Vídeos** | Visões gerais cinemáticas via Veo 3 (apenas Ultra) | Disponível (NotebookLM é melhor aqui, melhorando ativamente) | | **Geração de Apresentações** | Slides mais bonitos mas não editáveis | Cria apresentações editáveis baseadas em slides | | **Geração de Podcasts** | Visões gerais em áudio com hosts e idiomas personalizáveis | Disponível com múltiplos provedores TTS (NotebookLM é melhor aqui, melhorando ativamente) | -| **Aplicativo Desktop** | Não | Aplicativo nativo com General Assist, Quick Assist, Extreme Assist e sincronização de pastas locais | +| **Aplicativo Desktop** | Não | Aplicativo nativo com General Assist, Quick Assist, Screenshot Assist e sincronização de pastas locais | | **Extensão de Navegador** | Não | Extensão multi-navegador para salvar qualquer página web, incluindo páginas protegidas por autenticação |
diff --git a/README.zh-CN.md b/README.zh-CN.md index 29200243b..3e2bd095d 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -41,7 +41,7 @@ NotebookLM 是目前最好、最实用的 AI 平台之一,但当你开始经 - **无供应商锁定** - 配置任何 LLM、图像、TTS 和 STT 模型。 - **25+ 外部数据源** - 从 Google Drive、OneDrive、Dropbox、Notion 和许多其他外部服务添加你的来源。 - **实时多人协作支持** - 在共享笔记本中轻松与团队成员协作。 -- **桌面应用** - 通过 Quick Assist、General Assist、Extreme Assist 和本地文件夹同步在任何应用程序中获得 AI 助手。 +- **桌面应用** - 通过 Quick Assist、General Assist、Screenshot Assist 和本地文件夹同步在任何应用程序中获得 AI 助手。 ...更多功能即将推出。 @@ -84,9 +84,9 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7

Quick Assist

- - 桌面应用 — Extreme Assist + - 桌面应用 — Screenshot Assist -

Extreme Assist

+

Screenshot Assist

- 桌面应用 — Watch Local Folder @@ -150,7 +150,7 @@ SurfSense 还提供桌面应用,将 AI 助手带到您计算机上的每个应 - **General Assist** — 通过全局快捷键从任何应用程序即时启动 SurfSense。 - **Quick Assist** — 在任何位置选中文本,然后让 AI 解释、改写或对其执行操作。 -- **Extreme Assist** — 在任何应用中输入时,获得基于您知识库的内联写作建议。 +- **Screenshot Assist** — 在屏幕上框选区域并附加到聊天,让回复基于您的知识库。 - **Watch Local Folder** — 监视本地文件夹,自动将文件更改同步到您的知识库。**Pro tip:** 将其指向您的 Obsidian vault,让笔记在 SurfSense 中随时可搜索。 所有功能均基于您选择的搜索空间运行,确保回答始终以您自己的数据为依据。 @@ -199,7 +199,7 @@ SurfSense 还提供桌面应用,将 AI 助手带到您计算机上的每个应 | **视频生成** | 通过 Veo 3 的电影级视频概览(仅 Ultra) | 可用(NotebookLM 在此方面更好,正在积极改进) | | **演示文稿生成** | 更美观的幻灯片但不可编辑 | 创建可编辑的幻灯片式演示文稿 | | **播客生成** | 可自定义主持人和语言的音频概览 | 可用,支持多种 TTS 提供商(NotebookLM 在此方面更好,正在积极改进) | -| **桌面应用** | 否 | 原生应用,包含 General Assist、Quick Assist、Extreme Assist 和本地文件夹同步 | +| **桌面应用** | 否 | 原生应用,包含 General Assist、Quick Assist、Screenshot Assist 和本地文件夹同步 | | **浏览器扩展** | 否 | 跨浏览器扩展,保存任何网页,包括需要身份验证的页面 |
From 9f5b6205e1cf3fe5a5e681ac8886129257f42709 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 18:50:01 +0200 Subject: [PATCH 193/299] Align macOS accessibility and screen capture usage strings with Screenshot Assist and chat. --- surfsense_desktop/electron-builder.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/surfsense_desktop/electron-builder.yml b/surfsense_desktop/electron-builder.yml index 2c46c827a..360519516 100644 --- a/surfsense_desktop/electron-builder.yml +++ b/surfsense_desktop/electron-builder.yml @@ -49,8 +49,8 @@ mac: hardenedRuntime: false gatekeeperAssess: false extendInfo: - NSAccessibilityUsageDescription: "SurfSense uses accessibility features to insert suggestions into the active application." - NSScreenCaptureUsageDescription: "SurfSense uses screen capture to analyze your screen and provide context-aware writing suggestions." + NSAccessibilityUsageDescription: "SurfSense uses accessibility features to bring the app to the foreground and interact with the active application when you use desktop assists." + NSScreenCaptureUsageDescription: "SurfSense uses screen capture so you can attach a selected region to chat (Screenshot Assist) or capture the full screen from the composer." NSAppleEventsUsageDescription: "SurfSense uses Apple Events to interact with the active application." target: - target: dmg From 3fa8c790f5cf7432a7413614b6bd7b92ddf482b8 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 27 Apr 2026 22:32:37 +0530 Subject: [PATCH 194/299] feat(filesystem): add move and list_tree functionalities to enhance local folder operations --- .../agents/new_chat/middleware/filesystem.py | 314 ++++++++++++++++++ .../middleware/local_folder_backend.py | 281 ++++++++++++++++ .../multi_root_local_folder_backend.py | 163 +++++++++ 3 files changed, 758 insertions(+) diff --git a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py index 1706e3705..d7bb339bd 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py +++ b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py @@ -7,6 +7,7 @@ This middleware customizes prompts and persists write/edit operations for from __future__ import annotations import asyncio +import json import logging import re import secrets @@ -141,6 +142,33 @@ IMPORTANT: content. """ +SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION = """Moves or renames a file or folder. + +Use absolute paths for both source and destination. + +Notes: +- In local-folder mode, paths should use mount prefixes (e.g., //foo.txt). +- Rename is a special case of move (same folder, different filename). +- Cross-mount moves are not supported. +""" + +SURFSENSE_LIST_TREE_TOOL_DESCRIPTION = """Lists files/folders recursively with cursor pagination. + +Use this in desktop local-folder mode to discover nested files at scale. + +Args: +- path: absolute mount-prefixed path (e.g., //src) or "/" for mount roots. +- max_depth: recursion depth limit (default 8). +- page_size: number of entries to return per page (max 1000). +- cursor: opaque continuation token from a previous call. +- include_files/include_dirs: filter returned entry types. + +Returns JSON with: +- entries: [{path, is_dir, size, modified_at, depth}] +- next_cursor: continuation token or null +- has_more: whether additional pages exist +""" + SURFSENSE_GLOB_TOOL_DESCRIPTION = """Find files matching a glob pattern. Supports standard glob patterns: `*`, `**`, `?`. @@ -222,11 +250,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): ) if filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: system_prompt += ( + "\n- move_file: move or rename files/folders in local-folder mode." + "\n- list_tree: recursively list nested local paths with cursor pagination." "\n\n## Local Folder Mode" "\n\nThis chat is running in desktop local-folder mode." " Keep all file operations local. Do not use save_document." " Always use mount-prefixed absolute paths like //file.ext." " If you are unsure which mounts are available, call ls('/') first." + " For big trees: use list_tree pages, then grep, then read_file." ) super().__init__( @@ -237,6 +268,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): "read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION, "write_file": SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION, "edit_file": SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION, + "move_file": SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION, + "list_tree": SURFSENSE_LIST_TREE_TOOL_DESCRIPTION, "glob": SURFSENSE_GLOB_TOOL_DESCRIPTION, "grep": SURFSENSE_GREP_TOOL_DESCRIPTION, }, @@ -244,6 +277,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): max_execute_timeout=self._MAX_EXECUTE_TIMEOUT, ) self.tools = [t for t in self.tools if t.name != "execute"] + if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: + self.tools.append(self._create_move_file_tool()) + self.tools.append(self._create_list_tree_tool()) if self._should_persist_documents(): self.tools.append(self._create_save_document_tool()) if self._sandbox_available: @@ -836,6 +872,34 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): return f"/{candidate.lstrip('/')}" return candidate + def _resolve_move_target_path( + self, + file_path: str, + runtime: ToolRuntime[None, FilesystemState], + ) -> str: + candidate = file_path.strip() + if not candidate: + return "" + if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: + return self._normalize_local_mount_path(candidate, runtime) + if not candidate.startswith("/"): + return f"/{candidate.lstrip('/')}" + return candidate + + def _resolve_list_target_path( + self, + path: str, + runtime: ToolRuntime[None, FilesystemState], + ) -> str: + candidate = path.strip() or "/" + if candidate == "/": + return "/" + if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: + return self._normalize_local_mount_path(candidate, runtime) + if not candidate.startswith("/"): + return f"/{candidate.lstrip('/')}" + return candidate + @staticmethod def _is_error_text(value: str) -> bool: return value.startswith("Error:") @@ -930,6 +994,256 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): ) return None, updated_content + def _create_move_file_tool(self) -> BaseTool: + """Create move_file for desktop local-folder mode.""" + tool_description = ( + self._custom_tool_descriptions.get("move_file") + or SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION + ) + + def sync_move_file( + source_path: Annotated[ + str, + "Absolute source path to move from.", + ], + destination_path: Annotated[ + str, + "Absolute destination path to move to.", + ], + runtime: ToolRuntime[None, FilesystemState], + *, + overwrite: Annotated[ + bool, + "If True, replace an existing destination file. Defaults to False.", + ] = False, + ) -> Command | str: + if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: + return "Error: move_file is only available in desktop local-folder mode." + + if not source_path.strip() or not destination_path.strip(): + return "Error: source_path and destination_path are required." + + resolved_backend = self._get_backend(runtime) + source_target = self._resolve_move_target_path(source_path, runtime) + destination_target = self._resolve_move_target_path(destination_path, runtime) + try: + validated_source = validate_path(source_target) + validated_destination = validate_path(destination_target) + except ValueError as exc: + return f"Error: {exc}" + res: WriteResult = resolved_backend.move( + validated_source, + validated_destination, + overwrite=overwrite, + ) + if res.error: + return res.error + if res.files_update is not None: + return Command( + update={ + "files": res.files_update, + "messages": [ + ToolMessage( + content=( + f"Moved '{validated_source}' to " + f"'{res.path or validated_destination}'" + ), + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + return f"Moved '{validated_source}' to '{res.path or validated_destination}'" + + async def async_move_file( + source_path: Annotated[ + str, + "Absolute source path to move from.", + ], + destination_path: Annotated[ + str, + "Absolute destination path to move to.", + ], + runtime: ToolRuntime[None, FilesystemState], + *, + overwrite: Annotated[ + bool, + "If True, replace an existing destination file. Defaults to False.", + ] = False, + ) -> Command | str: + if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: + return "Error: move_file is only available in desktop local-folder mode." + + if not source_path.strip() or not destination_path.strip(): + return "Error: source_path and destination_path are required." + + resolved_backend = self._get_backend(runtime) + source_target = self._resolve_move_target_path(source_path, runtime) + destination_target = self._resolve_move_target_path(destination_path, runtime) + try: + validated_source = validate_path(source_target) + validated_destination = validate_path(destination_target) + except ValueError as exc: + return f"Error: {exc}" + res: WriteResult = await resolved_backend.amove( + validated_source, + validated_destination, + overwrite=overwrite, + ) + if res.error: + return res.error + if res.files_update is not None: + return Command( + update={ + "files": res.files_update, + "messages": [ + ToolMessage( + content=( + f"Moved '{validated_source}' to " + f"'{res.path or validated_destination}'" + ), + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + return f"Moved '{validated_source}' to '{res.path or validated_destination}'" + + return StructuredTool.from_function( + name="move_file", + description=tool_description, + func=sync_move_file, + coroutine=async_move_file, + ) + + def _create_list_tree_tool(self) -> BaseTool: + """Create list_tree for desktop local-folder mode.""" + tool_description = ( + self._custom_tool_descriptions.get("list_tree") + or SURFSENSE_LIST_TREE_TOOL_DESCRIPTION + ) + + def sync_list_tree( + runtime: ToolRuntime[None, FilesystemState], + *, + path: Annotated[ + str, + "Absolute path to list from. Use '/' for mount roots.", + ] = "/", + max_depth: Annotated[ + int, + "Maximum recursion depth to traverse. Defaults to 8.", + ] = 8, + page_size: Annotated[ + int, + "Number of entries to return per page. Defaults to 500 (max 1000).", + ] = 500, + cursor: Annotated[ + str | None, + "Opaque cursor from a previous list_tree call.", + ] = None, + include_files: Annotated[ + bool, + "Whether file entries should be included.", + ] = True, + include_dirs: Annotated[ + bool, + "Whether directory entries should be included.", + ] = True, + ) -> str: + if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: + return "Error: list_tree is only available in desktop local-folder mode." + if max_depth < 0: + return "Error: max_depth must be >= 0." + if page_size < 1: + return "Error: page_size must be >= 1." + if not include_files and not include_dirs: + return "Error: include_files and include_dirs cannot both be false." + + resolved_backend = self._get_backend(runtime) + target_path = self._resolve_list_target_path(path, runtime) + try: + validated_path = validate_path(target_path) + except ValueError as exc: + return f"Error: {exc}" + + result = resolved_backend.list_tree( + validated_path, + max_depth=max_depth, + page_size=page_size, + cursor=cursor, + include_files=include_files, + include_dirs=include_dirs, + ) + error = result.get("error") if isinstance(result, dict) else None + if isinstance(error, str) and error: + return error + return json.dumps(result, ensure_ascii=True) + + async def async_list_tree( + runtime: ToolRuntime[None, FilesystemState], + *, + path: Annotated[ + str, + "Absolute path to list from. Use '/' for mount roots.", + ] = "/", + max_depth: Annotated[ + int, + "Maximum recursion depth to traverse. Defaults to 8.", + ] = 8, + page_size: Annotated[ + int, + "Number of entries to return per page. Defaults to 500 (max 1000).", + ] = 500, + cursor: Annotated[ + str | None, + "Opaque cursor from a previous list_tree call.", + ] = None, + include_files: Annotated[ + bool, + "Whether file entries should be included.", + ] = True, + include_dirs: Annotated[ + bool, + "Whether directory entries should be included.", + ] = True, + ) -> str: + if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: + return "Error: list_tree is only available in desktop local-folder mode." + if max_depth < 0: + return "Error: max_depth must be >= 0." + if page_size < 1: + return "Error: page_size must be >= 1." + if not include_files and not include_dirs: + return "Error: include_files and include_dirs cannot both be false." + + resolved_backend = self._get_backend(runtime) + target_path = self._resolve_list_target_path(path, runtime) + try: + validated_path = validate_path(target_path) + except ValueError as exc: + return f"Error: {exc}" + + result = await resolved_backend.alist_tree( + validated_path, + max_depth=max_depth, + page_size=page_size, + cursor=cursor, + include_files=include_files, + include_dirs=include_dirs, + ) + error = result.get("error") if isinstance(result, dict) else None + if isinstance(error, str) and error: + return error + return json.dumps(result, ensure_ascii=True) + + return StructuredTool.from_function( + name="list_tree", + description=tool_description, + func=sync_list_tree, + coroutine=async_list_tree, + ) + def _create_edit_file_tool(self) -> BaseTool: """Create edit_file with DB persistence (skipped for KB documents).""" tool_description = ( diff --git a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py index 60d967053..ef6a1657d 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py @@ -6,7 +6,12 @@ import asyncio import fnmatch import os import threading +from collections import deque +from contextlib import ExitStack from pathlib import Path +from time import time +from typing import Any +from uuid import uuid4 from deepagents.backends.protocol import ( EditResult, @@ -38,6 +43,8 @@ class LocalFolderBackend: self._root = root self._locks: dict[str, threading.Lock] = {} self._locks_mu = threading.Lock() + self._tree_sessions: dict[str, dict[str, Any]] = {} + self._tree_sessions_ttl_s = 900 def _lock_for(self, path: str) -> threading.Lock: with self._locks_mu: @@ -71,6 +78,54 @@ class LocalFolderBackend: temp_path.write_text(content, encoding="utf-8") os.replace(temp_path, path) + def _acquire_path_locks(self, *paths: str) -> ExitStack: + ordered_paths = sorted(set(paths)) + stack = ExitStack() + for path in ordered_paths: + stack.enter_context(self._lock_for(path)) + return stack + + @staticmethod + def _clamp_page_size(page_size: int) -> int: + return max(1, min(page_size, 1000)) + + def _prune_expired_tree_sessions(self) -> None: + now = time() + expired = [ + cursor + for cursor, session in self._tree_sessions.items() + if now - float(session.get("last_accessed_at", now)) > self._tree_sessions_ttl_s + ] + for cursor in expired: + self._tree_sessions.pop(cursor, None) + + def _read_dir_entries(self, directory_path: str) -> list[dict[str, Any]]: + directory = Path(directory_path) + try: + children = sorted( + directory.iterdir(), + key=lambda p: (not p.is_dir(), p.name.lower()), + ) + except OSError: + return [] + + entries: list[dict[str, Any]] = [] + for child in children: + try: + stat_result = child.stat() + except OSError: + continue + entries.append( + { + "path": self._to_virtual(child, self._root), + "is_dir": child.is_dir(), + "size": stat_result.st_size if child.is_file() else 0, + "modified_at": str(stat_result.st_mtime), + "absolute_path": str(child), + } + ) + return entries + def ls_info(self, path: str) -> list[FileInfo]: try: target = self._resolve_virtual(path, allow_root=True) @@ -145,6 +200,232 @@ class LocalFolderBackend: async def awrite(self, file_path: str, content: str) -> WriteResult: return await asyncio.to_thread(self.write, file_path, content) + def list_tree( + self, + path: str = "/", + *, + max_depth: int | None = 8, + page_size: int = 500, + cursor: str | None = None, + include_files: bool = True, + include_dirs: bool = True, + ) -> dict[str, Any]: + self._prune_expired_tree_sessions() + if not include_files and not include_dirs: + return { + "entries": [], + "next_cursor": None, + "has_more": False, + "truncated": False, + } + + normalized_depth = None if max_depth is None else max(0, int(max_depth)) + page_limit = self._clamp_page_size(int(page_size)) + now = time() + + if cursor: + session = self._tree_sessions.get(cursor) + if not session: + return {"error": "Invalid or expired cursor"} + if ( + session.get("path") != path + or session.get("max_depth") != normalized_depth + or session.get("include_files") != include_files + or session.get("include_dirs") != include_dirs + ): + return {"error": "Cursor options do not match request options"} + state = session + else: + try: + start = self._resolve_virtual(path, allow_root=True) + except ValueError: + return {"error": f"Error: invalid path '{path}'"} + if not start.exists(): + return {"error": f"Error: path '{path}' not found"} + if start.is_file(): + stat_result = start.stat() + if include_files: + return { + "entries": [ + { + "path": self._to_virtual(start, self._root), + "is_dir": False, + "size": stat_result.st_size, + "modified_at": str(stat_result.st_mtime), + "depth": 0, + } + ], + "next_cursor": None, + "has_more": False, + "truncated": False, + } + return { + "entries": [], + "next_cursor": None, + "has_more": False, + "truncated": False, + } + state = { + "path": path, + "max_depth": normalized_depth, + "include_files": include_files, + "include_dirs": include_dirs, + "pending_dirs": deque([(str(start), 0)]), + "active_dir": None, + "active_depth": 0, + "active_entries": [], + "active_index": 0, + } + + entries: list[dict[str, Any]] = [] + truncated = False + while len(entries) < page_limit: + active_entries = state.get("active_entries", []) + active_index = int(state.get("active_index", 0)) + if active_index >= len(active_entries): + pending_dirs = state.get("pending_dirs", []) + if not pending_dirs: + state["active_entries"] = [] + state["active_index"] = 0 + break + next_dir_path, next_depth = pending_dirs.popleft() + state["active_dir"] = next_dir_path + state["active_depth"] = next_depth + state["active_entries"] = self._read_dir_entries(next_dir_path) + state["active_index"] = 0 + active_entries = state["active_entries"] + active_index = 0 + + if active_index >= len(active_entries): + continue + + item = active_entries[active_index] + state["active_index"] = active_index + 1 + item_depth = int(state.get("active_depth", 0)) + 1 + if normalized_depth is not None and item_depth > normalized_depth: + continue + if item["is_dir"]: + if normalized_depth is None or item_depth <= normalized_depth: + state["pending_dirs"].append((item["absolute_path"], item_depth)) + if include_dirs: + entries.append( + { + "path": item["path"], + "is_dir": True, + "size": 0, + "modified_at": item["modified_at"], + "depth": item_depth, + } + ) + elif include_files: + entries.append( + { + "path": item["path"], + "is_dir": False, + "size": item["size"], + "modified_at": item["modified_at"], + "depth": item_depth, + } + ) + + if len(entries) >= page_limit: + truncated = True + break + + has_more = bool(state.get("pending_dirs")) or ( + int(state.get("active_index", 0)) < len(state.get("active_entries", [])) + ) + if has_more: + next_cursor = cursor or uuid4().hex + state["last_accessed_at"] = now + self._tree_sessions[next_cursor] = state + else: + next_cursor = None + if cursor: + self._tree_sessions.pop(cursor, None) + + return { + "entries": entries, + "next_cursor": next_cursor, + "has_more": has_more, + "truncated": truncated, + } + + async def alist_tree( + self, + path: str = "/", + *, + max_depth: int | None = 8, + page_size: int = 500, + cursor: str | None = None, + include_files: bool = True, + include_dirs: bool = True, + ) -> dict[str, Any]: + return await asyncio.to_thread( + self.list_tree, + path, + max_depth=max_depth, + page_size=page_size, + cursor=cursor, + include_files=include_files, + include_dirs=include_dirs, + ) + + def move( + self, + source_path: str, + destination_path: str, + overwrite: bool = False, + ) -> WriteResult: + try: + source = self._resolve_virtual(source_path) + destination = self._resolve_virtual(destination_path) + except ValueError: + return WriteResult( + error=( + f"Error: invalid source '{source_path}' or destination " + f"'{destination_path}' path" + ) + ) + if source == destination: + return WriteResult(error="Error: source and destination paths are the same") + with self._acquire_path_locks(source_path, destination_path): + if not source.exists(): + return WriteResult(error=f"Error: source path '{source_path}' not found") + if destination.exists(): + if not overwrite: + return WriteResult( + error=( + f"Error: destination path '{destination_path}' already exists. " + "Set overwrite=True to replace files." + ) + ) + if source.is_dir() or destination.is_dir(): + return WriteResult( + error=( + "Error: overwrite=True is only supported for file-to-file moves." + ) + ) + destination.parent.mkdir(parents=True, exist_ok=True) + try: + if overwrite: + os.replace(source, destination) + else: + source.rename(destination) + except OSError as exc: + return WriteResult(error=f"Error: failed to move '{source_path}': {exc}") + return WriteResult(path=self._to_virtual(destination, self._root), files_update=None) + + async def amove( + self, + source_path: str, + destination_path: str, + overwrite: bool = False, + ) -> WriteResult: + return await asyncio.to_thread( + self.move, source_path, destination_path, overwrite + ) + def edit( self, file_path: str, diff --git a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py index 12632f00f..6760d76f0 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py @@ -3,6 +3,8 @@ from __future__ import annotations import asyncio +import base64 +import json from pathlib import Path from typing import Any @@ -107,6 +109,28 @@ class MultiRootLocalFolderBackend: for mount in self._mount_order ] + @staticmethod + def _encode_tree_cursor(mount: str, local_cursor: str) -> str: + payload = json.dumps( + {"mount": mount, "cursor": local_cursor}, + separators=(",", ":"), + ).encode("utf-8") + return base64.urlsafe_b64encode(payload).decode("ascii") + + @staticmethod + def _decode_tree_cursor(cursor: str) -> tuple[str, str]: + try: + padded = cursor + "=" * ((4 - len(cursor) % 4) % 4) + data = base64.urlsafe_b64decode(padded.encode("ascii")) + parsed = json.loads(data.decode("utf-8")) + except Exception as exc: + raise ValueError("Invalid cursor") from exc + mount = parsed.get("mount") + local_cursor = parsed.get("cursor") + if not isinstance(mount, str) or not isinstance(local_cursor, str): + raise ValueError("Invalid cursor") + return mount, local_cursor + def _transform_infos(self, mount: str, infos: list[FileInfo]) -> list[FileInfo]: transformed: list[FileInfo] = [] for info in infos: @@ -132,6 +156,103 @@ class MultiRootLocalFolderBackend: async def als_info(self, path: str) -> list[FileInfo]: return await asyncio.to_thread(self.ls_info, path) + def list_tree( + self, + path: str = "/", + *, + max_depth: int | None = 8, + page_size: int = 500, + cursor: str | None = None, + include_files: bool = True, + include_dirs: bool = True, + ) -> dict[str, Any]: + if path == "/" and not cursor: + entries = [ + { + "path": f"/{mount}", + "is_dir": True, + "size": 0, + "modified_at": "0", + "depth": 0, + } + for mount in self._mount_order + ] + return { + "entries": entries if include_dirs else [], + "next_cursor": None, + "has_more": False, + "truncated": False, + } + + try: + if cursor: + mount, local_cursor = self._decode_tree_cursor(cursor) + if mount not in self._mount_to_backend: + return {"error": "Invalid or expired cursor"} + local_path = "/" + else: + mount, local_path = self._split_mount_path(path) + local_cursor = None + except ValueError as exc: + return {"error": f"Error: {exc}"} + + result = self._mount_to_backend[mount].list_tree( + local_path, + max_depth=max_depth, + page_size=page_size, + cursor=local_cursor, + include_files=include_files, + include_dirs=include_dirs, + ) + if result.get("error"): + return result + + entries: list[dict[str, Any]] = [] + for entry in result.get("entries", []): + raw_path = self._get_str(entry, "path") + entries.append( + { + "path": self._prefix_mount_path(mount, raw_path), + "is_dir": self._get_bool(entry, "is_dir"), + "size": self._get_int(entry, "size"), + "modified_at": self._get_str(entry, "modified_at"), + "depth": self._get_int(entry, "depth"), + } + ) + + local_next_cursor = self._get_str(result, "next_cursor") + next_cursor = ( + self._encode_tree_cursor(mount, local_next_cursor) + if local_next_cursor + else None + ) + return { + "entries": entries, + "next_cursor": next_cursor, + "has_more": self._get_bool(result, "has_more"), + "truncated": self._get_bool(result, "truncated"), + } + + async def alist_tree( + self, + path: str = "/", + *, + max_depth: int | None = 8, + page_size: int = 500, + cursor: str | None = None, + include_files: bool = True, + include_dirs: bool = True, + ) -> dict[str, Any]: + return await asyncio.to_thread( + self.list_tree, + path, + max_depth=max_depth, + page_size=page_size, + cursor=cursor, + include_files=include_files, + include_dirs=include_dirs, + ) + def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str: try: mount, local_path = self._split_mount_path(file_path) @@ -165,6 +286,48 @@ class MultiRootLocalFolderBackend: async def awrite(self, file_path: str, content: str) -> WriteResult: return await asyncio.to_thread(self.write, file_path, content) + def move( + self, + source_path: str, + destination_path: str, + overwrite: bool = False, + ) -> WriteResult: + try: + source_mount, source_local_path = self._split_mount_path(source_path) + destination_mount, destination_local_path = self._split_mount_path( + destination_path + ) + except ValueError as exc: + return WriteResult(error=f"Error: {exc}") + if source_mount != destination_mount: + return WriteResult( + error=( + "Error: cross-mount moves are not supported. " + "Source and destination must be under the same mounted root." + ) + ) + result = self._mount_to_backend[source_mount].move( + source_local_path, + destination_local_path, + overwrite=overwrite, + ) + if result.path: + result.path = self._prefix_mount_path(source_mount, result.path) + return result + + async def amove( + self, + source_path: str, + destination_path: str, + overwrite: bool = False, + ) -> WriteResult: + return await asyncio.to_thread( + self.move, + source_path, + destination_path, + overwrite, + ) + def edit( self, file_path: str, From 056870464ab77dff01ccbc7a484db27174c4a0d8 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 19:25:20 +0200 Subject: [PATCH 195/299] Accept optional user_images on regenerate and apply them when resolving the model turn. --- surfsense_backend/app/routes/new_chat_routes.py | 3 +++ surfsense_backend/app/schemas/new_chat.py | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 854627d4b..cbc660222 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1456,6 +1456,9 @@ async def regenerate_response( user_query_to_use ) + if request.user_images is not None: + regenerate_image_urls = [p.as_data_url() for p in request.user_images] + if user_query_to_use is None: raise HTTPException( status_code=400, diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index e757ce178..477fdf2ca 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -238,6 +238,9 @@ class RegenerateRequest(BaseModel): 2. Reload: Leave user_query empty to regenerate the last AI response with the same query Both operations rewind the LangGraph checkpointer to the appropriate state. + + For edit, optional user_images (when not None) replaces image URLs resolved from + checkpoint/DB so the client can send the full user turn (text and/or images). """ search_space_id: int @@ -250,6 +253,16 @@ class RegenerateRequest(BaseModel): filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud" client_platform: Literal["web", "desktop"] = "web" local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None + user_images: list[NewChatUserImagePart] | None = Field( + default=None, + description="If set, use these images for the regenerated turn (edit); overrides checkpoint/DB", + ) + + @model_validator(mode="after") + def _validate_regenerate_user_images(self) -> Self: + if self.user_images is not None and len(self.user_images) > MAX_NEW_CHAT_IMAGES: + raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed") + return self # ============================================================================= From a07c44f4965daafc70512f9dcbc64028c6e28a4e Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 19:25:26 +0200 Subject: [PATCH 196/299] Send edited user images and full message content in chat regenerate while leaving reload on server-resolved turns. --- .../new-chat/[[...chat_id]]/page.tsx | 74 +++++++++++-------- 1 file changed, 45 insertions(+), 29 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index da134c4cf..10abe13b1 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -77,7 +77,10 @@ import { type ThreadListResponse, type ThreadRecord, } from "@/lib/chat/thread-persistence"; -import { extractUserTurnForNewChatApi } from "@/lib/chat/user-turn-api-parts"; +import { + extractUserTurnForNewChatApi, + type NewChatUserImagePayload, +} from "@/lib/chat/user-turn-api-parts"; import { NotFoundError } from "@/lib/error"; import { trackChatCreated, @@ -1337,15 +1340,24 @@ export default function NewChatPage() { * Handle regeneration (edit or reload) by calling the regenerate endpoint * and streaming the response. This rewinds the LangGraph checkpointer state. * - * @param newUserQuery - The new user query (for edit). Pass null/undefined for reload. + * @param newUserQuery - `null` = reload with same turn from the server. A string = edit + * (including an empty string when the edited turn is images-only); pass `editExtras` for images/content. */ const handleRegenerate = useCallback( - async (newUserQuery?: string | null) => { + async ( + newUserQuery: string | null, + editExtras?: { + userMessageContent: ThreadMessageLike["content"]; + userImages: NewChatUserImagePayload[]; + } + ) => { if (!threadId) { toast.error("Cannot regenerate: no active chat thread"); return; } + const isEdit = newUserQuery !== null; + // Abort any previous streaming request if (abortControllerRef.current) { abortControllerRef.current.abort(); @@ -1359,11 +1371,11 @@ export default function NewChatPage() { } // Extract the original user query BEFORE removing messages (for reload mode) - let userQueryToDisplay = newUserQuery; + let userQueryToDisplay: string | undefined; let originalUserMessageContent: ThreadMessageLike["content"] | null = null; let originalUserMessageMetadata: ThreadMessageLike["metadata"] | undefined; - if (!newUserQuery) { + if (!isEdit) { // Reload mode - find and preserve the last user message content const lastUserMessage = [...messages].reverse().find((m) => m.role === "user"); if (lastUserMessage) { @@ -1377,6 +1389,8 @@ export default function NewChatPage() { } } } + } else { + userQueryToDisplay = newUserQuery; } // Remove the last two messages (user + assistant) from the UI immediately @@ -1412,11 +1426,13 @@ export default function NewChatPage() { const userMessage: ThreadMessageLike = { id: userMsgId, role: "user", - content: newUserQuery - ? [{ type: "text", text: newUserQuery }] + content: isEdit + ? (editExtras?.userMessageContent ?? [ + { type: "text", text: newUserQuery ?? "" }, + ]) : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }], createdAt: new Date(), - metadata: newUserQuery ? undefined : originalUserMessageMetadata, + metadata: isEdit ? undefined : originalUserMessageMetadata, }; setMessages((prev) => [...prev, userMessage]); @@ -1433,20 +1449,24 @@ export default function NewChatPage() { try { const selection = await getAgentFilesystemSelection(); + const requestBody: Record = { + search_space_id: searchSpaceId, + user_query: newUserQuery, + disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, + filesystem_mode: selection.filesystem_mode, + client_platform: selection.client_platform, + local_filesystem_mounts: selection.local_filesystem_mounts, + }; + if (isEdit) { + requestBody.user_images = editExtras?.userImages ?? []; + } const response = await fetch(getRegenerateUrl(threadId), { method: "POST", headers: { "Content-Type": "application/json", Authorization: `Bearer ${token}`, }, - body: JSON.stringify({ - search_space_id: searchSpaceId, - user_query: newUserQuery || null, - disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, - filesystem_mode: selection.filesystem_mode, - client_platform: selection.client_platform, - local_filesystem_mounts: selection.local_filesystem_mounts, - }), + body: JSON.stringify(requestBody), signal: controller.signal, }); @@ -1536,8 +1556,10 @@ export default function NewChatPage() { if (contentParts.length > 0) { try { // Persist user message (for both edit and reload modes, since backend deleted it) - const userContentToPersist = newUserQuery - ? [{ type: "text", text: newUserQuery }] + const userContentToPersist = isEdit + ? (editExtras?.userMessageContent ?? [ + { type: "text", text: newUserQuery ?? "" }, + ]) : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }]; const savedUserMessage = await appendMessage(threadId, { @@ -1602,21 +1624,15 @@ export default function NewChatPage() { // Handle editing a message - truncates history and regenerates with new query const onEdit = useCallback( async (message: AppendMessage) => { - // Extract the new user query from the message content - let newUserQuery = ""; - for (const part of message.content) { - if (part.type === "text") { - newUserQuery += part.text; - } - } - - if (!newUserQuery.trim()) { + const { userQuery, userImages } = extractUserTurnForNewChatApi(message, []); + const queryForApi = userQuery.trim(); + if (!queryForApi && userImages.length === 0) { toast.error("Cannot edit with empty message"); return; } - // Call regenerate with the new query - await handleRegenerate(newUserQuery.trim()); + const userMessageContent = message.content as unknown as ThreadMessageLike["content"]; + await handleRegenerate(queryForApi, { userMessageContent, userImages }); }, [handleRegenerate] ); From 8b542ca3dd77d2905d0a357cc438d385695cf42d Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 19:25:39 +0200 Subject: [PATCH 197/299] Deduplicate user-turn images by full base64 data and update desktop permissions copy for Screenshot Assist. --- surfsense_web/app/desktop/permissions/page.tsx | 8 +++++--- surfsense_web/lib/chat/user-turn-api-parts.ts | 5 ++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/surfsense_web/app/desktop/permissions/page.tsx b/surfsense_web/app/desktop/permissions/page.tsx index a2fadc8ff..e30a76f83 100644 --- a/surfsense_web/app/desktop/permissions/page.tsx +++ b/surfsense_web/app/desktop/permissions/page.tsx @@ -19,14 +19,15 @@ const STEPS = [ id: "screen-recording", title: "Screen Recording", description: - "Lets SurfSense capture your screen to understand context and provide smart writing suggestions.", + "Lets SurfSense capture a region of your screen, full display, or browser (where supported) to attach to chat in Screenshot Assist, or to capture the full display from the composer.", action: "requestScreenRecording", field: "screenRecording" as const, }, { id: "accessibility", title: "Accessibility", - description: "Lets SurfSense insert suggestions seamlessly, right where you\u2019re typing.", + description: + "Lets SurfSense bring the app to the foreground and work with the active application (for example Quick Assist) when you use desktop shortcuts.", action: "requestAccessibility", field: "accessibility" as const, }, @@ -131,7 +132,8 @@ export default function DesktopPermissionsPage() {

System Permissions

- SurfSense needs two macOS permissions to provide context-aware writing suggestions. + SurfSense needs two macOS permissions for Screenshot Assist and for desktop features that + require focusing the app or the active application.

diff --git a/surfsense_web/lib/chat/user-turn-api-parts.ts b/surfsense_web/lib/chat/user-turn-api-parts.ts index 48d27a7ba..5e063492f 100644 --- a/surfsense_web/lib/chat/user-turn-api-parts.ts +++ b/surfsense_web/lib/chat/user-turn-api-parts.ts @@ -46,9 +46,8 @@ export function extractUserTurnForNewChatApi( for (const url of merged) { const p = dataUrlToPayload(url); if (!p) continue; - const key = `${p.media_type}:${p.data.length}`; - if (seen.has(key)) continue; - seen.add(key); + if (seen.has(p.data)) continue; + seen.add(p.data); payloads.push(p); if (payloads.length >= MAX_IMAGES) break; } From f330d1431cb8476f066424651d5f6f6afa1c5d37 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 27 Apr 2026 23:08:32 +0530 Subject: [PATCH 198/299] feat(filesystem): implement filesystem tree watch functionality using chokidar for real-time updates on local folder changes --- surfsense_desktop/src/ipc/channels.ts | 3 + surfsense_desktop/src/ipc/handlers.ts | 17 + .../modules/agent-filesystem-tree-watcher.ts | 302 ++++++++++++++++++ surfsense_desktop/src/preload.ts | 32 ++ .../components/assistant-ui/markdown-text.tsx | 78 ++++- .../components/editor-panel/editor-panel.tsx | 64 +++- .../ui/sidebar/LocalFilesystemBrowser.tsx | 88 ++++- surfsense_web/types/window.d.ts | 22 ++ 8 files changed, 583 insertions(+), 23 deletions(-) create mode 100644 surfsense_desktop/src/modules/agent-filesystem-tree-watcher.ts diff --git a/surfsense_desktop/src/ipc/channels.ts b/surfsense_desktop/src/ipc/channels.ts index ec676fba8..ed4b49fad 100644 --- a/surfsense_desktop/src/ipc/channels.ts +++ b/surfsense_desktop/src/ipc/channels.ts @@ -57,6 +57,9 @@ export const IPC_CHANNELS = { AGENT_FILESYSTEM_GET_SETTINGS: 'agent-filesystem:get-settings', AGENT_FILESYSTEM_GET_MOUNTS: 'agent-filesystem:get-mounts', AGENT_FILESYSTEM_LIST_FILES: 'agent-filesystem:list-files', + AGENT_FILESYSTEM_TREE_WATCH_START: 'agent-filesystem:tree-watch-start', + AGENT_FILESYSTEM_TREE_WATCH_STOP: 'agent-filesystem:tree-watch-stop', + AGENT_FILESYSTEM_TREE_DIRTY: 'agent-filesystem:tree-dirty', AGENT_FILESYSTEM_SET_SETTINGS: 'agent-filesystem:set-settings', AGENT_FILESYSTEM_PICK_ROOT: 'agent-filesystem:pick-root', } as const; diff --git a/surfsense_desktop/src/ipc/handlers.ts b/surfsense_desktop/src/ipc/handlers.ts index 4054255f4..2b06c7fb0 100644 --- a/surfsense_desktop/src/ipc/handlers.ts +++ b/surfsense_desktop/src/ipc/handlers.ts @@ -45,6 +45,11 @@ import { pickAgentFilesystemRoot, setAgentFilesystemSettings, } from '../modules/agent-filesystem'; +import { + startAgentFilesystemTreeWatch, + stopAgentFilesystemTreeWatch, + type AgentFilesystemTreeWatchOptions, +} from '../modules/agent-filesystem-tree-watcher'; let authTokens: { bearer: string; refresh: string } | null = null; @@ -263,4 +268,16 @@ export function registerIpcHandlers(): void { ipcMain.handle(IPC_CHANNELS.AGENT_FILESYSTEM_PICK_ROOT, () => pickAgentFilesystemRoot() ); + + ipcMain.handle( + IPC_CHANNELS.AGENT_FILESYSTEM_TREE_WATCH_START, + (_event, options: AgentFilesystemTreeWatchOptions) => + startAgentFilesystemTreeWatch(options) + ); + + ipcMain.handle( + IPC_CHANNELS.AGENT_FILESYSTEM_TREE_WATCH_STOP, + (_event, searchSpaceId?: number | null) => + stopAgentFilesystemTreeWatch(searchSpaceId) + ); } diff --git a/surfsense_desktop/src/modules/agent-filesystem-tree-watcher.ts b/surfsense_desktop/src/modules/agent-filesystem-tree-watcher.ts new file mode 100644 index 000000000..600f84fd5 --- /dev/null +++ b/surfsense_desktop/src/modules/agent-filesystem-tree-watcher.ts @@ -0,0 +1,302 @@ +import { BrowserWindow } from 'electron'; +import chokidar, { type FSWatcher } from 'chokidar'; +import { resolve } from 'node:path'; +import { IPC_CHANNELS } from '../ipc/channels'; +import { listAgentFilesystemFiles } from './agent-filesystem'; + +const SAFETY_POLL_MS = 60_000; +const EVENT_DEBOUNCE_MS = 700; + +export type AgentFilesystemTreeWatchOptions = { + searchSpaceId?: number | null; + rootPaths: string[]; + excludePatterns?: string[] | null; + fileExtensions?: string[] | null; +}; + +type TreeDirtyReason = 'watcher_event' | 'safety_poll'; + +type TreeDirtyEvent = { + searchSpaceId: number | null; + reason: TreeDirtyReason; + rootPath: string; + changedPath: string | null; + timestamp: number; +}; + +type WatchSession = { + searchSpaceId: number | null; + optionsSignature: string; + rootPaths: string[]; + excludePatterns: string[]; + fileExtensions: string[] | null; + watchers: FSWatcher[]; + pollTimer: NodeJS.Timeout | null; + emitTimer: NodeJS.Timeout | null; + rootSnapshotByPath: Map; + pendingDirtyByRoot: Map; + disposed: boolean; +}; + +const sessions = new Map(); + +function normalizeSearchSpaceId(searchSpaceId?: number | null): number | null { + if (typeof searchSpaceId === 'number' && Number.isFinite(searchSpaceId) && searchSpaceId > 0) { + return searchSpaceId; + } + return null; +} + +function getSessionKey(searchSpaceId?: number | null): string { + const normalized = normalizeSearchSpaceId(searchSpaceId); + return normalized === null ? 'default' : String(normalized); +} + +function normalizeRootPath(pathValue: string): string { + const normalized = resolve(pathValue.trim()); + return process.platform === 'win32' ? normalized.toLowerCase() : normalized; +} + +function normalizeList(value: string[] | null | undefined): string[] { + if (!value || value.length === 0) return []; + return value + .filter((entry): entry is string => typeof entry === 'string') + .map((entry) => entry.trim()) + .filter(Boolean); +} + +function normalizeExtensions(value: string[] | null | undefined): string[] | null { + const normalized = normalizeList(value).map((entry) => entry.toLowerCase()); + return normalized.length > 0 ? normalized : null; +} + +function buildOptionsSignature( + searchSpaceId: number | null, + rootPaths: string[], + excludePatterns: string[], + fileExtensions: string[] | null +): string { + return JSON.stringify({ + searchSpaceId, + rootPaths: [...rootPaths].sort(), + excludePatterns: [...excludePatterns].sort(), + fileExtensions: fileExtensions ? [...fileExtensions].sort() : null, + }); +} + +function hashText(input: string, seed: number): number { + let hash = seed >>> 0; + for (let i = 0; i < input.length; i += 1) { + hash ^= input.charCodeAt(i); + hash = Math.imul(hash, 16777619); + hash >>>= 0; + } + return hash; +} + +async function buildRootSnapshotSignature( + session: WatchSession, + rootPath: string +): Promise { + let hash = 2166136261; + hash = hashText(`space:${session.searchSpaceId ?? 'default'}|root:${rootPath}`, hash); + const files = await listAgentFilesystemFiles({ + rootPath, + searchSpaceId: session.searchSpaceId, + excludePatterns: session.excludePatterns, + fileExtensions: session.fileExtensions, + }); + const sortedFiles = [...files].sort((a, b) => a.relativePath.localeCompare(b.relativePath)); + hash = hashText(`count:${sortedFiles.length}`, hash); + for (const file of sortedFiles) { + hash = hashText( + `${file.relativePath}|${Math.round(file.mtimeMs)}|${file.size}`, + hash + ); + } + return hash.toString(16); +} + +function sendTreeDirtyEvent( + searchSpaceId: number | null, + reason: TreeDirtyReason, + rootPath: string, + changedPath: string | null +): void { + const payload: TreeDirtyEvent = { + searchSpaceId, + reason, + rootPath, + changedPath, + timestamp: Date.now(), + }; + for (const win of BrowserWindow.getAllWindows()) { + if (!win.isDestroyed()) { + win.webContents.send(IPC_CHANNELS.AGENT_FILESYSTEM_TREE_DIRTY, payload); + } + } +} + +function scheduleDirtyEmit( + session: WatchSession, + reason: TreeDirtyReason, + rootPath: string, + changedPath: string | null = null +): void { + if (session.disposed) return; + const existing = session.pendingDirtyByRoot.get(rootPath); + if (!existing || existing.reason === 'safety_poll') { + session.pendingDirtyByRoot.set(rootPath, { reason, changedPath }); + } + if (session.emitTimer) { + clearTimeout(session.emitTimer); + } + session.emitTimer = setTimeout(() => { + session.emitTimer = null; + if (session.disposed) return; + const pending = Array.from(session.pendingDirtyByRoot.entries()); + session.pendingDirtyByRoot.clear(); + for (const [pendingRootPath, payload] of pending) { + sendTreeDirtyEvent( + session.searchSpaceId, + payload.reason, + pendingRootPath, + payload.changedPath + ); + } + }, EVENT_DEBOUNCE_MS); +} + +async function closeSession(session: WatchSession): Promise { + session.disposed = true; + if (session.emitTimer) { + clearTimeout(session.emitTimer); + session.emitTimer = null; + } + if (session.pollTimer) { + clearInterval(session.pollTimer); + session.pollTimer = null; + } + await Promise.allSettled(session.watchers.map((watcher) => watcher.close())); +} + +export async function startAgentFilesystemTreeWatch( + options: AgentFilesystemTreeWatchOptions +): Promise<{ ok: true }> { + const searchSpaceId = normalizeSearchSpaceId(options.searchSpaceId); + const rootPaths = Array.from( + new Set(normalizeList(options.rootPaths).map((rootPath) => normalizeRootPath(rootPath))) + ); + const excludePatterns = Array.from(new Set(normalizeList(options.excludePatterns))); + const fileExtensions = normalizeExtensions(options.fileExtensions); + const sessionKey = getSessionKey(searchSpaceId); + + if (rootPaths.length === 0) { + await stopAgentFilesystemTreeWatch(searchSpaceId); + return { ok: true }; + } + + const optionsSignature = buildOptionsSignature( + searchSpaceId, + rootPaths, + excludePatterns, + fileExtensions + ); + const existing = sessions.get(sessionKey); + if (existing && existing.optionsSignature === optionsSignature) { + return { ok: true }; + } + if (existing) { + await closeSession(existing); + sessions.delete(sessionKey); + } + + const ignored = [ + /(^|[/\\])\../, + ...excludePatterns.map((pattern) => `**/${pattern}/**`), + ]; + const watchers = rootPaths.map((rootPath) => + chokidar.watch(rootPath, { + persistent: true, + ignoreInitial: true, + awaitWriteFinish: { + stabilityThreshold: 500, + pollInterval: 100, + }, + ignored, + }) + ); + + const session: WatchSession = { + searchSpaceId, + optionsSignature, + rootPaths, + excludePatterns, + fileExtensions, + watchers, + pollTimer: null, + emitTimer: null, + rootSnapshotByPath: new Map(), + pendingDirtyByRoot: new Map(), + disposed: false, + }; + + for (let index = 0; index < watchers.length; index += 1) { + const watcher = watchers[index]; + const rootPath = rootPaths[index]; + watcher.on('add', (filePath) => scheduleDirtyEmit(session, 'watcher_event', rootPath, filePath)); + watcher.on('change', (filePath) => + scheduleDirtyEmit(session, 'watcher_event', rootPath, filePath) + ); + watcher.on('unlink', (filePath) => + scheduleDirtyEmit(session, 'watcher_event', rootPath, filePath) + ); + watcher.on('addDir', (filePath) => + scheduleDirtyEmit(session, 'watcher_event', rootPath, filePath) + ); + watcher.on('unlinkDir', (filePath) => + scheduleDirtyEmit(session, 'watcher_event', rootPath, filePath) + ); + } + + for (const rootPath of rootPaths) { + try { + const signature = await buildRootSnapshotSignature(session, rootPath); + session.rootSnapshotByPath.set(rootPath, signature); + } catch { + session.rootSnapshotByPath.set(rootPath, ''); + } + } + + session.pollTimer = setInterval(() => { + void (async () => { + if (session.disposed) return; + for (const rootPath of session.rootPaths) { + try { + const nextSignature = await buildRootSnapshotSignature(session, rootPath); + const previousSignature = session.rootSnapshotByPath.get(rootPath) ?? ''; + if (nextSignature !== previousSignature) { + session.rootSnapshotByPath.set(rootPath, nextSignature); + scheduleDirtyEmit(session, 'safety_poll', rootPath, null); + } + } catch { + // Keep watcher resilient on transient IO errors. + } + } + })(); + }, SAFETY_POLL_MS); + + sessions.set(sessionKey, session); + return { ok: true }; +} + +export async function stopAgentFilesystemTreeWatch( + searchSpaceId?: number | null +): Promise<{ ok: true }> { + const sessionKey = getSessionKey(searchSpaceId); + const session = sessions.get(sessionKey); + if (!session) return { ok: true }; + sessions.delete(sessionKey); + await closeSession(session); + return { ok: true }; +} diff --git a/surfsense_desktop/src/preload.ts b/surfsense_desktop/src/preload.ts index 8e5c2f56b..100825c0f 100644 --- a/surfsense_desktop/src/preload.ts +++ b/surfsense_desktop/src/preload.ts @@ -116,6 +116,38 @@ contextBridge.exposeInMainWorld('electronAPI', { excludePatterns?: string[] | null; fileExtensions?: string[] | null; }) => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_LIST_FILES, options), + startAgentFilesystemTreeWatch: (options: { + searchSpaceId?: number | null; + rootPaths: string[]; + excludePatterns?: string[] | null; + fileExtensions?: string[] | null; + }) => ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_TREE_WATCH_START, options), + stopAgentFilesystemTreeWatch: (searchSpaceId?: number | null) => + ipcRenderer.invoke(IPC_CHANNELS.AGENT_FILESYSTEM_TREE_WATCH_STOP, searchSpaceId), + onAgentFilesystemTreeDirty: ( + callback: (data: { + searchSpaceId: number | null; + reason: 'watcher_event' | 'safety_poll'; + rootPath: string; + changedPath: string | null; + timestamp: number; + }) => void + ) => { + const listener = ( + _event: unknown, + data: { + searchSpaceId: number | null; + reason: 'watcher_event' | 'safety_poll'; + rootPath: string; + changedPath: string | null; + timestamp: number; + } + ) => callback(data); + ipcRenderer.on(IPC_CHANNELS.AGENT_FILESYSTEM_TREE_DIRTY, listener); + return () => { + ipcRenderer.removeListener(IPC_CHANNELS.AGENT_FILESYSTEM_TREE_DIRTY, listener); + }; + }, setAgentFilesystemSettings: (settings: { mode?: "cloud" | "desktop_local_folder"; localRootPaths?: string[] | null; diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index a15ff1cd7..2707e8956 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -229,6 +229,44 @@ function extractDomain(url: string): string { // Canonical local-file virtual paths are mount-prefixed: // const LOCAL_FILE_PATH_REGEX = /^\/[a-z0-9_-]+\/[^\s`]+(?:\/[^\s`]+)*$/; +type AgentFilesystemMount = { + mount: string; + rootPath: string; +}; + +function normalizeLocalVirtualPathForEditor( + candidatePath: string, + mounts: AgentFilesystemMount[] +): string { + const normalizedCandidate = candidatePath.trim().replace(/\\/g, "/").replace(/\/+/g, "/"); + if (!normalizedCandidate) { + return candidatePath; + } + const defaultMount = mounts[0]?.mount; + if (!defaultMount) { + return normalizedCandidate.startsWith("/") + ? normalizedCandidate + : `/${normalizedCandidate.replace(/^\/+/, "")}`; + } + + const mountNames = new Set(mounts.map((entry) => entry.mount)); + if (normalizedCandidate.startsWith("/")) { + const relative = normalizedCandidate.replace(/^\/+/, ""); + const [firstSegment] = relative.split("/", 1); + if (mountNames.has(firstSegment)) { + return `/${relative}`; + } + return `/${defaultMount}/${relative}`; + } + + const relative = normalizedCandidate.replace(/^\/+/, ""); + const [firstSegment] = relative.split("/", 1); + if (mountNames.has(firstSegment)) { + return `/${relative}`; + } + return `/${defaultMount}/${relative}`; +} + function isVirtualFilePathToken(value: string): boolean { if (!LOCAL_FILE_PATH_REGEX.test(value) || value.startsWith("//")) { return false; @@ -421,8 +459,15 @@ const defaultComponents = memoizeMarkdownComponents({ !codeString.includes("\n"); if (!isCodeBlock) { const inlineValue = String(children ?? "").trim(); + const normalizedInlinePath = inlineValue.replace(/\/+$/, ""); + const leafSegment = normalizedInlinePath.split("/").filter(Boolean).at(-1) ?? ""; + const isLikelyFolder = + inlineValue.endsWith("/") || !leafSegment || !leafSegment.includes("."); const isLocalPath = - !!electronAPI && isVirtualFilePathToken(inlineValue) && !inlineValue.startsWith("//"); + !!electronAPI && + isVirtualFilePathToken(inlineValue) && + !inlineValue.startsWith("//") && + !isLikelyFolder; const displayLocalPath = inlineValue.replace(/^\/+/, ""); const searchSpaceIdParam = params?.search_space_id; const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam) @@ -438,14 +483,31 @@ const defaultComponents = memoizeMarkdownComponents({ onClick={(event) => { event.preventDefault(); event.stopPropagation(); - openEditorPanel({ - kind: "local_file", - localFilePath: inlineValue, - title: inlineValue.split("/").pop() || inlineValue, - searchSpaceId: Number.isFinite(parsedSearchSpaceId) + void (async () => { + let resolvedLocalPath = inlineValue; + const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId) ? parsedSearchSpaceId - : undefined, - }); + : undefined; + if (electronAPI?.getAgentFilesystemMounts) { + try { + const mounts = (await electronAPI.getAgentFilesystemMounts( + resolvedSearchSpaceId + )) as AgentFilesystemMount[]; + resolvedLocalPath = normalizeLocalVirtualPathForEditor( + inlineValue, + mounts + ); + } catch { + // Fall back to the raw inline path if mount lookup fails. + } + } + openEditorPanel({ + kind: "local_file", + localFilePath: resolvedLocalPath, + title: resolvedLocalPath.split("/").pop() || resolvedLocalPath, + searchSpaceId: resolvedSearchSpaceId, + }); + })(); }} title="Open in editor panel" > diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index a9fe886e1..9b1383d7f 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -47,6 +47,42 @@ interface EditorContent { const EDITABLE_DOCUMENT_TYPES = new Set(["FILE", "NOTE"]); type EditorRenderMode = "rich_markdown" | "source_code"; +type AgentFilesystemMount = { + mount: string; + rootPath: string; +}; + +function normalizeLocalVirtualPathForEditor( + candidatePath: string, + mounts: AgentFilesystemMount[] +): string { + const normalizedCandidate = candidatePath.trim().replace(/\\/g, "/").replace(/\/+/g, "/"); + if (!normalizedCandidate) return candidatePath; + const defaultMount = mounts[0]?.mount; + if (!defaultMount) { + return normalizedCandidate.startsWith("/") + ? normalizedCandidate + : `/${normalizedCandidate.replace(/^\/+/, "")}`; + } + + const mountNames = new Set(mounts.map((entry) => entry.mount)); + if (normalizedCandidate.startsWith("/")) { + const relative = normalizedCandidate.replace(/^\/+/, ""); + const [firstSegment] = relative.split("/", 1); + if (mountNames.has(firstSegment)) { + return `/${relative}`; + } + return `/${defaultMount}/${relative}`; + } + + const relative = normalizedCandidate.replace(/^\/+/, ""); + const [firstSegment] = relative.split("/", 1); + if (mountNames.has(firstSegment)) { + return `/${relative}`; + } + return `/${defaultMount}/${relative}`; +} + function EditorPanelSkeleton() { return (
@@ -100,6 +136,22 @@ export function EditorPanelContent({ const [displayTitle, setDisplayTitle] = useState(title || "Untitled"); const isLocalFileMode = kind === "local_file"; const editorRenderMode: EditorRenderMode = isLocalFileMode ? "source_code" : "rich_markdown"; + const resolveLocalVirtualPath = useCallback( + async (candidatePath: string): Promise => { + if (!electronAPI?.getAgentFilesystemMounts) { + return candidatePath; + } + try { + const mounts = (await electronAPI.getAgentFilesystemMounts( + searchSpaceId + )) as AgentFilesystemMount[]; + return normalizeLocalVirtualPathForEditor(candidatePath, mounts); + } catch { + return candidatePath; + } + }, + [electronAPI, searchSpaceId] + ); const isLargeDocument = (editorDoc?.content_size_bytes ?? 0) > LARGE_DOCUMENT_THRESHOLD; @@ -124,14 +176,15 @@ export function EditorPanelContent({ if (!electronAPI?.readAgentLocalFileText) { throw new Error("Local file editor is available only in desktop mode."); } + const resolvedLocalPath = await resolveLocalVirtualPath(localFilePath); const readResult = await electronAPI.readAgentLocalFileText( - localFilePath, + resolvedLocalPath, searchSpaceId ); if (!readResult.ok) { throw new Error(readResult.error || "Failed to read local file"); } - const inferredTitle = localFilePath.split("/").pop() || localFilePath; + const inferredTitle = resolvedLocalPath.split("/").pop() || resolvedLocalPath; const content: EditorContent = { document_id: -1, title: inferredTitle, @@ -195,7 +248,7 @@ export function EditorPanelContent({ doFetch().catch(() => {}); return () => controller.abort(); - }, [documentId, electronAPI, isLocalFileMode, localFilePath, searchSpaceId, title]); + }, [documentId, electronAPI, isLocalFileMode, localFilePath, resolveLocalVirtualPath, searchSpaceId, title]); useEffect(() => { return () => { @@ -239,9 +292,10 @@ export function EditorPanelContent({ if (!electronAPI?.writeAgentLocalFileText) { throw new Error("Local file editor is available only in desktop mode."); } + const resolvedLocalPath = await resolveLocalVirtualPath(localFilePath); const contentToSave = markdownRef.current; const writeResult = await electronAPI.writeAgentLocalFileText( - localFilePath, + resolvedLocalPath, contentToSave, searchSpaceId ); @@ -290,7 +344,7 @@ export function EditorPanelContent({ } finally { setSaving(false); } - }, [documentId, electronAPI, isLocalFileMode, localFilePath, searchSpaceId]); + }, [documentId, electronAPI, isLocalFileMode, localFilePath, resolveLocalVirtualPath, searchSpaceId]); const isEditableType = editorDoc ? (editorRenderMode === "source_code" || diff --git a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx index add7cd2d9..d1146338d 100644 --- a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx +++ b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx @@ -86,7 +86,8 @@ export function LocalFilesystemBrowser({ const [mountByRootKey, setMountByRootKey] = useState>(new Map()); const [mountStatus, setMountStatus] = useState("idle"); const [mountRefreshInFlight, setMountRefreshInFlight] = useState(false); - const lastLoadedRootsSignatureRef = useRef(""); + const [reloadNonceByRoot, setReloadNonceByRoot] = useState>({}); + const lastLoadedSignatureByRootRef = useRef>(new Map()); const hasLoadedMountsOnceRef = useRef(false); const hasResolvedAtLeastOneRootRef = useRef(false); const supportedExtensions = useMemo(() => Array.from(getSupportedExtensionsSet()), []); @@ -107,18 +108,34 @@ export function LocalFilesystemBrowser({ } return; } - const rootsSignature = rootPaths - .map((rootPath) => normalizeRootPathForLookup(rootPath, isWindowsPlatform)) - .sort() - .join("|"); - const settingsSignature = `${searchSpaceId}:${rootsSignature}`; - if (settingsSignature === lastLoadedRootsSignatureRef.current) { + const rootEntries = rootPaths.map((rootPath) => ({ + rootPath, + rootKey: normalizeRootPathForLookup(rootPath, isWindowsPlatform), + })); + const activeRootKeys = new Set(rootEntries.map((entry) => entry.rootKey)); + for (const key of Array.from(lastLoadedSignatureByRootRef.current.keys())) { + if (!activeRootKeys.has(key)) { + lastLoadedSignatureByRootRef.current.delete(key); + } + } + const rootsToReload = rootEntries.filter(({ rootKey }) => { + const nonce = reloadNonceByRoot[rootKey] ?? 0; + const signature = `${searchSpaceId}:${rootKey}:${nonce}`; + return lastLoadedSignatureByRootRef.current.get(rootKey) !== signature; + }); + if (rootsToReload.length === 0) { return; } - lastLoadedRootsSignatureRef.current = settingsSignature; + for (const { rootKey } of rootsToReload) { + const nonce = reloadNonceByRoot[rootKey] ?? 0; + lastLoadedSignatureByRootRef.current.set( + rootKey, + `${searchSpaceId}:${rootKey}:${nonce}` + ); + } let cancelled = false; - for (const rootPath of rootPaths) { + for (const { rootPath } of rootsToReload) { setRootStateMap((prev) => ({ ...prev, [rootPath]: { @@ -130,7 +147,7 @@ export function LocalFilesystemBrowser({ } void Promise.all( - rootPaths.map(async (rootPath) => { + rootsToReload.map(async ({ rootPath }) => { try { const files = (await electronAPI.listAgentFilesystemFiles({ rootPath, @@ -164,6 +181,57 @@ export function LocalFilesystemBrowser({ return () => { cancelled = true; }; + }, [active, electronAPI, isWindowsPlatform, reloadNonceByRoot, rootPaths, searchSpaceId, supportedExtensions]); + + useEffect(() => { + if (active) return; + lastLoadedSignatureByRootRef.current.clear(); + }, [active]); + + useEffect(() => { + if (!electronAPI?.startAgentFilesystemTreeWatch) return; + if (!electronAPI?.stopAgentFilesystemTreeWatch) return; + if (!electronAPI?.onAgentFilesystemTreeDirty) return; + if (!active) return; + if (rootPaths.length === 0) { + void electronAPI.stopAgentFilesystemTreeWatch(searchSpaceId); + return; + } + + const unsubscribe = electronAPI.onAgentFilesystemTreeDirty((event) => { + if ((event.searchSpaceId ?? null) !== (searchSpaceId ?? null)) { + return; + } + const eventRootKey = normalizeRootPathForLookup(event.rootPath, isWindowsPlatform); + const knownRootKeys = new Set( + rootPaths.map((rootPath) => normalizeRootPathForLookup(rootPath, isWindowsPlatform)) + ); + if (!knownRootKeys.has(eventRootKey)) { + setReloadNonceByRoot((prev) => { + const next = { ...prev }; + for (const rootKey of knownRootKeys) { + next[rootKey] = (prev[rootKey] ?? 0) + 1; + } + return next; + }); + return; + } + setReloadNonceByRoot((prev) => ({ + ...prev, + [eventRootKey]: (prev[eventRootKey] ?? 0) + 1, + })); + }); + void electronAPI.startAgentFilesystemTreeWatch({ + searchSpaceId, + rootPaths, + excludePatterns: DEFAULT_EXCLUDE_PATTERNS, + fileExtensions: supportedExtensions, + }); + + return () => { + unsubscribe(); + void electronAPI.stopAgentFilesystemTreeWatch(searchSpaceId); + }; }, [active, electronAPI, isWindowsPlatform, rootPaths, searchSpaceId, supportedExtensions]); useEffect(() => { diff --git a/surfsense_web/types/window.d.ts b/surfsense_web/types/window.d.ts index d3356d4d1..5840d7a04 100644 --- a/surfsense_web/types/window.d.ts +++ b/surfsense_web/types/window.d.ts @@ -61,6 +61,21 @@ interface AgentFilesystemListOptions { fileExtensions?: string[] | null; } +interface AgentFilesystemTreeWatchOptions { + searchSpaceId?: number | null; + rootPaths: string[]; + excludePatterns?: string[] | null; + fileExtensions?: string[] | null; +} + +interface AgentFilesystemTreeDirtyEvent { + searchSpaceId: number | null; + reason: "watcher_event" | "safety_poll"; + rootPath: string; + changedPath: string | null; + timestamp: number; +} + interface LocalTextFileResult { ok: boolean; path: string; @@ -167,6 +182,13 @@ interface ElectronAPI { listAgentFilesystemFiles: ( options: AgentFilesystemListOptions ) => Promise; + startAgentFilesystemTreeWatch: ( + options: AgentFilesystemTreeWatchOptions + ) => Promise<{ ok: true }>; + stopAgentFilesystemTreeWatch: (searchSpaceId?: number | null) => Promise<{ ok: true }>; + onAgentFilesystemTreeDirty: ( + callback: (data: AgentFilesystemTreeDirtyEvent) => void + ) => () => void; setAgentFilesystemSettings: (settings: { mode?: AgentFilesystemMode; localRootPaths?: string[] | null; From 9cd4daa6b394e33b9a3325baa59c6244dce59d64 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 20:35:47 +0200 Subject: [PATCH 199/299] Add a single-session desktop window picker and route screenshot assist, region crop, and fullscreen capture through the cached frame. --- surfsense_desktop/scripts/build-electron.mjs | 6 + surfsense_desktop/src/ipc/channels.ts | 3 + surfsense_desktop/src/ipc/handlers.ts | 5 +- .../src/modules/screen-region-picker.ts | 174 ++++++++----- .../src/modules/screenshot-assist.ts | 14 +- .../src/modules/window-picker.ts | 244 ++++++++++++++++++ .../src/window-picker-preload.ts | 15 ++ 7 files changed, 396 insertions(+), 65 deletions(-) create mode 100644 surfsense_desktop/src/modules/window-picker.ts create mode 100644 surfsense_desktop/src/window-picker-preload.ts diff --git a/surfsense_desktop/scripts/build-electron.mjs b/surfsense_desktop/scripts/build-electron.mjs index 0c8f08d52..ca17e4c48 100644 --- a/surfsense_desktop/scripts/build-electron.mjs +++ b/surfsense_desktop/scripts/build-electron.mjs @@ -138,6 +138,12 @@ async function buildElectron() { outfile: 'dist/screen-region-preload.js', }); + await build({ + ...shared, + entryPoints: ['src/window-picker-preload.ts'], + outfile: 'dist/window-picker-preload.js', + }); + console.log('Electron build complete'); resolveStandaloneSymlinks(); } diff --git a/surfsense_desktop/src/ipc/channels.ts b/surfsense_desktop/src/ipc/channels.ts index 9f084af85..1007e3a37 100644 --- a/surfsense_desktop/src/ipc/channels.ts +++ b/surfsense_desktop/src/ipc/channels.ts @@ -14,6 +14,9 @@ export const IPC_CHANNELS = { CAPTURE_FULL_SCREEN: 'capture-full-screen', SCREEN_REGION_SUBMIT: 'screen-region:submit', SCREEN_REGION_CANCEL: 'screen-region:cancel', + WINDOW_PICK_LIST: 'window-pick:list', + WINDOW_PICK_SUBMIT: 'window-pick:submit', + WINDOW_PICK_CANCEL: 'window-pick:cancel', CHAT_SCREEN_CAPTURE: 'chat:screen-capture', // Folder sync channels FOLDER_SYNC_SELECT_FOLDER: 'folder-sync:select-folder', diff --git a/surfsense_desktop/src/ipc/handlers.ts b/surfsense_desktop/src/ipc/handlers.ts index 8361b9a38..d68d4a5bf 100644 --- a/surfsense_desktop/src/ipc/handlers.ts +++ b/surfsense_desktop/src/ipc/handlers.ts @@ -7,7 +7,7 @@ import { requestScreenRecording, restartApp, } from '../modules/permissions'; -import { captureCurrentDisplayDataUrl } from '../modules/screen-region-picker'; +import { pickOpenWindowCapture } from '../modules/window-picker'; import { selectFolder, addWatchedFolder, @@ -85,7 +85,8 @@ export function registerIpcHandlers(): void { requestScreenRecording(); return null; } - return captureCurrentDisplayDataUrl(); + const picked = await pickOpenWindowCapture(); + return picked?.dataUrl ?? null; }); // Folder sync handlers diff --git a/surfsense_desktop/src/modules/screen-region-picker.ts b/surfsense_desktop/src/modules/screen-region-picker.ts index cc9303040..1c4b77195 100644 --- a/surfsense_desktop/src/modules/screen-region-picker.ts +++ b/surfsense_desktop/src/modules/screen-region-picker.ts @@ -1,6 +1,17 @@ import { BrowserWindow, desktopCapturer, nativeImage, screen } from 'electron'; import path from 'path'; import { IPC_CHANNELS } from '../ipc/channels'; +function fitNativeImageToWorkArea(img: Electron.NativeImage, display: Electron.Display): Electron.NativeImage { + const wa = display.workArea; + const { width: iw, height: ih } = img.getSize(); + const scale = Math.min(1, wa.width / iw, wa.height / ih); + if (scale >= 1) return img; + return img.resize({ + width: Math.max(1, Math.floor(iw * scale)), + height: Math.max(1, Math.floor(ih * scale)), + quality: 'best', + }); +} // One getSources per pick; overlay and final crop share that bitmap (avoids a second portal session, e.g. Wayland). @@ -141,7 +152,7 @@ function buildInjectScript(dataUrl: string, iw: number, ih: number): string { })();`; } -export function pickScreenRegion(): Promise { +export function pickScreenRegion(opts?: { windowDataUrl?: string }): Promise { if (pickInProgress) return Promise.resolve(null); pickInProgress = true; @@ -175,6 +186,7 @@ export function pickScreenRegion(): Promise { }; let snapshot: DisplayCaptureSnapshot | null = null; + let cropSource: Electron.NativeImage | null = null; const onSubmit = ( _event: Electron.IpcMainEvent, @@ -185,17 +197,25 @@ export function pickScreenRegion(): Promise { finish(null); return; } - if (!snapshot) { + if (!snapshot || !cropSource) { finish(null); return; } try { - const full = nativeImage.createFromDataURL(snapshot.dataUrl); - const cropped = full.crop({ - x: Math.floor(rect.x), - y: Math.floor(rect.y), - width: Math.floor(rect.width), - height: Math.floor(rect.height), + const iw = snapshot.width; + const ih = snapshot.height; + const { width: cw, height: ch } = cropSource.getSize(); + const scaleX = cw / iw; + const scaleY = ch / ih; + const ox = Math.floor(rect.x * scaleX); + const oy = Math.floor(rect.y * scaleY); + const ow = Math.min(Math.floor(rect.width * scaleX), cw - ox); + const oh = Math.min(Math.floor(rect.height * scaleY), ch - oy); + const cropped = cropSource.crop({ + x: ox, + y: oy, + width: Math.max(1, ow), + height: Math.max(1, oh), }); finish(cropped.toDataURL()); } catch { @@ -214,66 +234,102 @@ export function pickScreenRegion(): Promise { } }; - void captureDisplaySnapshot(display) - .then((cap) => { + const openOverlay = ( + cap: DisplayCaptureSnapshot, + crop: Electron.NativeImage, + bounds: { x: number; y: number; width: number; height: number } + ) => { + snapshot = cap; + cropSource = crop; + + overlay = new BrowserWindow({ + x: bounds.x, + y: bounds.y, + width: bounds.width, + height: bounds.height, + frame: false, + transparent: true, + fullscreenable: false, + skipTaskbar: true, + alwaysOnTop: true, + focusable: true, + show: false, + autoHideMenuBar: true, + backgroundColor: '#00000000', + webPreferences: { + preload: path.join(__dirname, 'screen-region-preload.js'), + contextIsolation: true, + nodeIntegration: false, + sandbox: true, + }, + }); + + overlayWc = overlay.webContents; + overlayWc.on('before-input-event', onBeforeInput); + overlayWc.ipc.on(IPC_CHANNELS.SCREEN_REGION_SUBMIT, onSubmit); + overlayWc.ipc.on(IPC_CHANNELS.SCREEN_REGION_CANCEL, onCancel); + + overlay.setIgnoreMouseEvents(false); + overlay.loadURL( + 'data:text/html;charset=utf-8,' + + encodeURIComponent('') + ); + + overlay.on('closed', () => { + if (!settled) finish(null); + }); + + overlay.webContents.once('did-finish-load', () => { + if (!overlay || overlay.isDestroyed()) return; + overlay.webContents + .executeJavaScript(buildInjectScript(cap.dataUrl, cap.width, cap.height), true) + .then(() => { + overlay?.show(); + overlay?.focus(); + }) + .catch(() => { + finish(null); + }); + }); + }; + + void (async () => { + try { + if (opts?.windowDataUrl) { + const fullRes = nativeImage.createFromDataURL(opts.windowDataUrl); + if (fullRes.isEmpty()) { + finish(null); + return; + } + const fitted = fitNativeImageToWorkArea(fullRes, display); + const fw = fitted.getSize().width; + const fh = fitted.getSize().height; + const wa = display.workArea; + const x = wa.x + Math.floor((wa.width - fw) / 2); + const y = wa.y + Math.floor((wa.height - fh) / 2); + openOverlay( + { dataUrl: fitted.toDataURL(), width: fw, height: fh }, + fullRes, + { x, y, width: fw, height: fh } + ); + return; + } + + const cap = await captureDisplaySnapshot(display); if (!cap) { finish(null); return; } - snapshot = cap; - - overlay = new BrowserWindow({ + const crop = nativeImage.createFromDataURL(cap.dataUrl); + openOverlay(cap, crop, { x: display.bounds.x, y: display.bounds.y, width: display.bounds.width, height: display.bounds.height, - frame: false, - transparent: true, - fullscreenable: false, - skipTaskbar: true, - alwaysOnTop: true, - focusable: true, - show: false, - autoHideMenuBar: true, - backgroundColor: '#00000000', - webPreferences: { - preload: path.join(__dirname, 'screen-region-preload.js'), - contextIsolation: true, - nodeIntegration: false, - sandbox: true, - }, }); - - overlayWc = overlay.webContents; - overlayWc.on('before-input-event', onBeforeInput); - overlayWc.ipc.on(IPC_CHANNELS.SCREEN_REGION_SUBMIT, onSubmit); - overlayWc.ipc.on(IPC_CHANNELS.SCREEN_REGION_CANCEL, onCancel); - - overlay.setIgnoreMouseEvents(false); - overlay.loadURL( - 'data:text/html;charset=utf-8,' + - encodeURIComponent('') - ); - - overlay.on('closed', () => { - if (!settled) finish(null); - }); - - overlay.webContents.once('did-finish-load', () => { - if (!overlay || overlay.isDestroyed()) return; - overlay.webContents - .executeJavaScript(buildInjectScript(cap.dataUrl, cap.width, cap.height), true) - .then(() => { - overlay?.show(); - overlay?.focus(); - }) - .catch(() => { - finish(null); - }); - }); - }) - .catch(() => { + } catch { finish(null); - }); + } + })(); }); } diff --git a/surfsense_desktop/src/modules/screenshot-assist.ts b/surfsense_desktop/src/modules/screenshot-assist.ts index 2500bf1d5..34fd0f489 100644 --- a/surfsense_desktop/src/modules/screenshot-assist.ts +++ b/surfsense_desktop/src/modules/screenshot-assist.ts @@ -1,19 +1,25 @@ import { IPC_CHANNELS } from '../ipc/channels'; import { trackEvent } from './analytics'; import { pickScreenRegion } from './screen-region-picker'; +import { pickOpenWindowCapture } from './window-picker'; import { getMainWindow, showMainWindow } from './window'; import { hasScreenRecordingPermission, requestScreenRecording } from './permissions'; export async function runScreenshotAssistShortcut(): Promise { - showMainWindow('shortcut'); - await new Promise((r) => setTimeout(r, 400)); if (!hasScreenRecordingPermission()) { requestScreenRecording(); return; } - const url = await pickScreenRegion(); + + const picked = await pickOpenWindowCapture(); + if (!picked) return; + + const url = await pickScreenRegion({ windowDataUrl: picked.dataUrl }); + if (!url) return; + + showMainWindow('shortcut'); const mw = getMainWindow(); - if (url && mw && !mw.isDestroyed()) { + if (mw && !mw.isDestroyed()) { mw.webContents.send(IPC_CHANNELS.CHAT_SCREEN_CAPTURE, url); trackEvent('desktop_screenshot_assist_region_to_chat', {}); } diff --git a/surfsense_desktop/src/modules/window-picker.ts b/surfsense_desktop/src/modules/window-picker.ts new file mode 100644 index 000000000..0e8505bcb --- /dev/null +++ b/surfsense_desktop/src/modules/window-picker.ts @@ -0,0 +1,244 @@ +import { BrowserWindow, desktopCapturer, ipcMain, screen } from 'electron'; +import path from 'path'; +import { IPC_CHANNELS } from '../ipc/channels'; + +let pickInProgress = false; + +const PREVIEW_THUMB = { width: 280, height: 180 } as const; + +function maxCaptureThumbSize(): { width: number; height: number } { + const d = screen.getPrimaryDisplay(); + const sf = d.scaleFactor || 1; + const w = Math.min(3840, Math.max(1280, Math.round(d.size.width * sf))); + const h = Math.min(2160, Math.max(720, Math.round(d.size.height * sf))); + return { width: w, height: h }; +} + +function isDesktopWindowSourceId(s: string): boolean { + return typeof s === 'string' && s.startsWith('window:'); +} + +export type PickedWindowResult = { + sourceId: string; + /** Same pixels as the one `desktopCapturer` snapshot (max thumbnail size). */ + dataUrl: string; +}; + +function buildPickerInjectScript(): string { + return `(async function () { + const api = window.surfsenseWindowPick; + if (!api) return; + const items = await api.list(); + document.body.style.cssText = + 'margin:0;font-family:system-ui,-apple-system,sans-serif;background:#0f172a;color:#e2e8f0;min-height:100vh;padding:16px;box-sizing:border-box;'; + const top = document.createElement('div'); + top.style.cssText = + 'display:flex;justify-content:space-between;align-items:center;margin-bottom:12px;flex-wrap:wrap;gap:8px;'; + const t = document.createElement('strong'); + t.textContent = 'Open windows'; + const hint = document.createElement('span'); + hint.style.cssText = 'opacity:0.75;font-size:13px;'; + hint.textContent = 'Click a window · Esc to cancel'; + top.appendChild(t); + top.appendChild(hint); + document.body.appendChild(top); + if (!items || !items.length) { + const p = document.createElement('p'); + p.style.cssText = 'line-height:1.5;max-width:42rem;'; + p.textContent = + 'No windows were returned by the system. On Linux, allow screen capture when prompted. If other apps are open, try again.'; + document.body.appendChild(p); + return; + } + const grid = document.createElement('div'); + grid.style.cssText = + 'display:grid;grid-template-columns:repeat(auto-fill,minmax(200px,1fr));gap:12px;max-height:calc(100vh - 88px);overflow:auto;padding-bottom:8px;'; + for (const it of items) { + const card = document.createElement('button'); + card.type = 'button'; + card.style.cssText = + 'text-align:left;background:#1e293b;border:1px solid #334155;border-radius:8px;padding:8px;cursor:pointer;color:inherit;'; + card.addEventListener('mouseenter', function () { + card.style.borderColor = '#38bdf8'; + }); + card.addEventListener('mouseleave', function () { + card.style.borderColor = '#334155'; + }); + const img = document.createElement('img'); + img.alt = ''; + img.src = + it.thumbDataUrl || + 'data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///ywAAAAAAQABAAACAUwAOw=='; + img.style.cssText = + 'width:100%;height:100px;object-fit:cover;border-radius:4px;background:#000;display:block;'; + const cap = document.createElement('div'); + cap.textContent = it.name || '(untitled)'; + cap.style.cssText = + 'margin-top:6px;font-size:12px;line-height:1.35;overflow:hidden;text-overflow:ellipsis;display:-webkit-box;-webkit-line-clamp:2;-webkit-box-orient:vertical;'; + card.appendChild(img); + card.appendChild(cap); + card.addEventListener('click', function () { + api.submit(it.id); + }); + grid.appendChild(card); + } + document.body.appendChild(grid); + window.addEventListener('keydown', function (e) { + if (e.key === 'Escape') api.cancel(); + }); + })();`; +} + +/** + * One OS / Chromium capture session: `getSources` runs once (important on Wayland / + * PipeWire so the portal is not opened again for the same flow). Opens our grid to + * choose a window; resolves with the chosen snapshot for region or full-frame use. + */ +export function pickOpenWindowCapture(): Promise { + if (pickInProgress) return Promise.resolve(null); + pickInProgress = true; + + return new Promise((resolve) => { + let settled = false; + let picker: BrowserWindow | null = null; + let pickerWc: Electron.WebContents | null = null; + /** Filled once before the grid runs — reused for list + final image (no second getSources). */ + let sessionSources: Electron.DesktopCapturerSource[] = []; + + const finish = (result: PickedWindowResult | null) => { + if (settled) return; + settled = true; + pickInProgress = false; + ipcMain.removeHandler(IPC_CHANNELS.WINDOW_PICK_LIST); + const wc = pickerWc; + pickerWc = null; + if (wc && !wc.isDestroyed()) { + wc.removeListener('before-input-event', onBeforeInput); + wc.ipc.removeListener(IPC_CHANNELS.WINDOW_PICK_SUBMIT, onSubmit); + wc.ipc.removeListener(IPC_CHANNELS.WINDOW_PICK_CANCEL, onCancel); + } + if (picker && !picker.isDestroyed()) { + picker.removeAllListeners('closed'); + picker.close(); + } + picker = null; + resolve(result); + }; + + const onSubmit = (_event: Electron.IpcMainEvent, sourceId: string) => { + if (settled || !picker || picker.isDestroyed()) return; + if (!isDesktopWindowSourceId(sourceId)) { + finish(null); + return; + } + const hit = sessionSources.find((s) => s.id === sourceId); + if (!hit || hit.thumbnail.isEmpty()) { + finish(null); + return; + } + finish({ sourceId, dataUrl: hit.thumbnail.toDataURL() }); + }; + + const onCancel = () => { + if (settled || !picker || picker.isDestroyed()) return; + finish(null); + }; + + const onBeforeInput = (_event: Electron.Event, input: Electron.Input) => { + if (input.type === 'keyDown' && input.key === 'Escape') { + finish(null); + } + }; + + ipcMain.handle(IPC_CHANNELS.WINDOW_PICK_LIST, async () => { + return sessionSources.map((s, i) => { + let thumbDataUrl = ''; + if (!s.thumbnail.isEmpty()) { + try { + const sm = s.thumbnail.resize({ + width: PREVIEW_THUMB.width, + height: PREVIEW_THUMB.height, + quality: 'good', + }); + thumbDataUrl = sm.toDataURL(); + } catch { + thumbDataUrl = s.thumbnail.toDataURL(); + } + } + return { + id: s.id, + name: (s.name || '').trim() || `Window ${i + 1}`, + thumbDataUrl, + }; + }); + }); + + picker = new BrowserWindow({ + width: 760, + height: 560, + show: false, + center: true, + autoHideMenuBar: true, + title: 'SurfSense — choose window', + webPreferences: { + preload: path.join(__dirname, 'window-picker-preload.js'), + contextIsolation: true, + nodeIntegration: false, + sandbox: true, + }, + }); + + pickerWc = picker.webContents; + + pickerWc.on('before-input-event', onBeforeInput); + pickerWc.ipc.on(IPC_CHANNELS.WINDOW_PICK_SUBMIT, onSubmit); + pickerWc.ipc.on(IPC_CHANNELS.WINDOW_PICK_CANCEL, onCancel); + + picker.on('closed', () => { + if (!settled) finish(null); + }); + + picker + .loadURL( + 'data:text/html;charset=utf-8,' + + encodeURIComponent('') + ) + .catch(() => finish(null)); + + picker.webContents.once('did-finish-load', () => { + void (async () => { + if (!picker || picker.isDestroyed()) return; + let selfId = ''; + try { + selfId = picker.getMediaSourceId(); + } catch { + selfId = ''; + } + try { + const { width, height } = maxCaptureThumbSize(); + const sources = await desktopCapturer.getSources({ + types: ['window'], + thumbnailSize: { width, height }, + fetchWindowIcons: false, + }); + sessionSources = sources.filter((s) => !(selfId && s.id === selfId)); + } catch { + sessionSources = []; + } + if (sessionSources.length === 1) { + const only = sessionSources[0]; + if (!only.thumbnail.isEmpty()) { + finish({ sourceId: only.id, dataUrl: only.thumbnail.toDataURL() }); + return; + } + } + try { + await picker.webContents.executeJavaScript(buildPickerInjectScript(), true); + if (!picker.isDestroyed()) picker.show(); + } catch { + finish(null); + } + })(); + }); + }); +} diff --git a/surfsense_desktop/src/window-picker-preload.ts b/surfsense_desktop/src/window-picker-preload.ts new file mode 100644 index 000000000..9b582cede --- /dev/null +++ b/surfsense_desktop/src/window-picker-preload.ts @@ -0,0 +1,15 @@ +import { contextBridge, ipcRenderer } from 'electron'; +import { IPC_CHANNELS } from './ipc/channels'; + +contextBridge.exposeInMainWorld('surfsenseWindowPick', { + list: () => + ipcRenderer.invoke(IPC_CHANNELS.WINDOW_PICK_LIST) as Promise< + { id: string; name: string; thumbDataUrl: string }[] + >, + submit: (sourceId: string) => { + ipcRenderer.send(IPC_CHANNELS.WINDOW_PICK_SUBMIT, sourceId); + }, + cancel: () => { + ipcRenderer.send(IPC_CHANNELS.WINDOW_PICK_CANCEL); + }, +}); From d4caae6de96274917fb383b3d68473071b2ed5a4 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 27 Apr 2026 20:39:03 +0200 Subject: [PATCH 200/299] Move desktop screen capture into modules/screen-capture and align preload build paths and imports. --- surfsense_desktop/scripts/build-electron.mjs | 8 ++++---- surfsense_desktop/src/ipc/handlers.ts | 2 +- surfsense_desktop/src/modules/screen-capture/index.ts | 7 +++++++ .../modules/{ => screen-capture}/screen-region-picker.ts | 4 ++-- .../{ => modules/screen-capture}/screen-region-preload.ts | 2 +- .../src/modules/{ => screen-capture}/screenshot-assist.ts | 8 ++++---- .../{ => modules/screen-capture}/window-picker-preload.ts | 2 +- .../src/modules/{ => screen-capture}/window-picker.ts | 4 ++-- surfsense_desktop/src/modules/tray.ts | 2 +- 9 files changed, 23 insertions(+), 16 deletions(-) create mode 100644 surfsense_desktop/src/modules/screen-capture/index.ts rename surfsense_desktop/src/modules/{ => screen-capture}/screen-region-picker.ts (98%) rename surfsense_desktop/src/{ => modules/screen-capture}/screen-region-preload.ts (87%) rename surfsense_desktop/src/modules/{ => screen-capture}/screenshot-assist.ts (80%) rename surfsense_desktop/src/{ => modules/screen-capture}/window-picker-preload.ts (89%) rename surfsense_desktop/src/modules/{ => screen-capture}/window-picker.ts (98%) diff --git a/surfsense_desktop/scripts/build-electron.mjs b/surfsense_desktop/scripts/build-electron.mjs index ca17e4c48..75a3cdf61 100644 --- a/surfsense_desktop/scripts/build-electron.mjs +++ b/surfsense_desktop/scripts/build-electron.mjs @@ -134,14 +134,14 @@ async function buildElectron() { await build({ ...shared, - entryPoints: ['src/screen-region-preload.ts'], - outfile: 'dist/screen-region-preload.js', + entryPoints: ['src/modules/screen-capture/screen-region-preload.ts'], + outfile: 'dist/modules/screen-capture/screen-region-preload.js', }); await build({ ...shared, - entryPoints: ['src/window-picker-preload.ts'], - outfile: 'dist/window-picker-preload.js', + entryPoints: ['src/modules/screen-capture/window-picker-preload.ts'], + outfile: 'dist/modules/screen-capture/window-picker-preload.js', }); console.log('Electron build complete'); diff --git a/surfsense_desktop/src/ipc/handlers.ts b/surfsense_desktop/src/ipc/handlers.ts index d68d4a5bf..b524a91a1 100644 --- a/surfsense_desktop/src/ipc/handlers.ts +++ b/surfsense_desktop/src/ipc/handlers.ts @@ -7,7 +7,7 @@ import { requestScreenRecording, restartApp, } from '../modules/permissions'; -import { pickOpenWindowCapture } from '../modules/window-picker'; +import { pickOpenWindowCapture } from '../modules/screen-capture'; import { selectFolder, addWatchedFolder, diff --git a/surfsense_desktop/src/modules/screen-capture/index.ts b/surfsense_desktop/src/modules/screen-capture/index.ts new file mode 100644 index 000000000..6c1c75509 --- /dev/null +++ b/surfsense_desktop/src/modules/screen-capture/index.ts @@ -0,0 +1,7 @@ +/** + * Window capture for Screenshot Assist and chat fullscreen: single-session + * desktopCapturer, region overlay, and shortcut entry point. + */ +export { pickOpenWindowCapture, type PickedWindowResult } from './window-picker'; +export { pickScreenRegion, captureCurrentDisplayDataUrl } from './screen-region-picker'; +export { runScreenshotAssistShortcut } from './screenshot-assist'; diff --git a/surfsense_desktop/src/modules/screen-region-picker.ts b/surfsense_desktop/src/modules/screen-capture/screen-region-picker.ts similarity index 98% rename from surfsense_desktop/src/modules/screen-region-picker.ts rename to surfsense_desktop/src/modules/screen-capture/screen-region-picker.ts index 1c4b77195..fd771b0f7 100644 --- a/surfsense_desktop/src/modules/screen-region-picker.ts +++ b/surfsense_desktop/src/modules/screen-capture/screen-region-picker.ts @@ -1,6 +1,6 @@ import { BrowserWindow, desktopCapturer, nativeImage, screen } from 'electron'; import path from 'path'; -import { IPC_CHANNELS } from '../ipc/channels'; +import { IPC_CHANNELS } from '../../ipc/channels'; function fitNativeImageToWorkArea(img: Electron.NativeImage, display: Electron.Display): Electron.NativeImage { const wa = display.workArea; const { width: iw, height: ih } = img.getSize(); @@ -257,7 +257,7 @@ export function pickScreenRegion(opts?: { windowDataUrl?: string }): Promise { diff --git a/surfsense_desktop/src/modules/screenshot-assist.ts b/surfsense_desktop/src/modules/screen-capture/screenshot-assist.ts similarity index 80% rename from surfsense_desktop/src/modules/screenshot-assist.ts rename to surfsense_desktop/src/modules/screen-capture/screenshot-assist.ts index 34fd0f489..171b98a57 100644 --- a/surfsense_desktop/src/modules/screenshot-assist.ts +++ b/surfsense_desktop/src/modules/screen-capture/screenshot-assist.ts @@ -1,9 +1,9 @@ -import { IPC_CHANNELS } from '../ipc/channels'; -import { trackEvent } from './analytics'; +import { IPC_CHANNELS } from '../../ipc/channels'; +import { trackEvent } from '../analytics'; import { pickScreenRegion } from './screen-region-picker'; import { pickOpenWindowCapture } from './window-picker'; -import { getMainWindow, showMainWindow } from './window'; -import { hasScreenRecordingPermission, requestScreenRecording } from './permissions'; +import { getMainWindow, showMainWindow } from '../window'; +import { hasScreenRecordingPermission, requestScreenRecording } from '../permissions'; export async function runScreenshotAssistShortcut(): Promise { if (!hasScreenRecordingPermission()) { diff --git a/surfsense_desktop/src/window-picker-preload.ts b/surfsense_desktop/src/modules/screen-capture/window-picker-preload.ts similarity index 89% rename from surfsense_desktop/src/window-picker-preload.ts rename to surfsense_desktop/src/modules/screen-capture/window-picker-preload.ts index 9b582cede..dd0acd81e 100644 --- a/surfsense_desktop/src/window-picker-preload.ts +++ b/surfsense_desktop/src/modules/screen-capture/window-picker-preload.ts @@ -1,5 +1,5 @@ import { contextBridge, ipcRenderer } from 'electron'; -import { IPC_CHANNELS } from './ipc/channels'; +import { IPC_CHANNELS } from '../../ipc/channels'; contextBridge.exposeInMainWorld('surfsenseWindowPick', { list: () => diff --git a/surfsense_desktop/src/modules/window-picker.ts b/surfsense_desktop/src/modules/screen-capture/window-picker.ts similarity index 98% rename from surfsense_desktop/src/modules/window-picker.ts rename to surfsense_desktop/src/modules/screen-capture/window-picker.ts index 0e8505bcb..b66e23c5c 100644 --- a/surfsense_desktop/src/modules/window-picker.ts +++ b/surfsense_desktop/src/modules/screen-capture/window-picker.ts @@ -1,6 +1,6 @@ import { BrowserWindow, desktopCapturer, ipcMain, screen } from 'electron'; import path from 'path'; -import { IPC_CHANNELS } from '../ipc/channels'; +import { IPC_CHANNELS } from '../../ipc/channels'; let pickInProgress = false; @@ -181,7 +181,7 @@ export function pickOpenWindowCapture(): Promise { autoHideMenuBar: true, title: 'SurfSense — choose window', webPreferences: { - preload: path.join(__dirname, 'window-picker-preload.js'), + preload: path.join(__dirname, 'modules', 'screen-capture', 'window-picker-preload.js'), contextIsolation: true, nodeIntegration: false, sandbox: true, diff --git a/surfsense_desktop/src/modules/tray.ts b/surfsense_desktop/src/modules/tray.ts index 07b53bafb..5fb1acbdf 100644 --- a/surfsense_desktop/src/modules/tray.ts +++ b/surfsense_desktop/src/modules/tray.ts @@ -1,7 +1,7 @@ import { app, globalShortcut, Menu, nativeImage, Tray, type NativeImage } from 'electron'; import path from 'path'; import { runGeneralAssistShortcut } from './general-assist'; -import { runScreenshotAssistShortcut } from './screenshot-assist'; +import { runScreenshotAssistShortcut } from './screen-capture'; import { showMainWindow } from './window'; import { getShortcuts } from './shortcuts'; import { trackEvent } from './analytics'; From 7bcb6306c5be26da5b7a094340c001e8fd759303 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Tue, 28 Apr 2026 00:45:07 +0530 Subject: [PATCH 201/299] refactor(filesystem): streamline filesystem operations by removing cursor-based pagination and enhancing path normalization methods --- .../agents/new_chat/middleware/filesystem.py | 81 ++------ .../middleware/local_folder_backend.py | 178 +++++------------- .../multi_root_local_folder_backend.py | 49 +---- 3 files changed, 69 insertions(+), 239 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py index d7bb339bd..3622bbcdf 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py +++ b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py @@ -28,9 +28,6 @@ from langgraph.types import Command from sqlalchemy import delete, select from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( - MultiRootLocalFolderBackend, -) from app.agents.new_chat.sandbox import ( _evict_sandbox_cache, delete_sandbox, @@ -152,21 +149,19 @@ Notes: - Cross-mount moves are not supported. """ -SURFSENSE_LIST_TREE_TOOL_DESCRIPTION = """Lists files/folders recursively with cursor pagination. +SURFSENSE_LIST_TREE_TOOL_DESCRIPTION = """Lists files/folders recursively in a single bounded call. Use this in desktop local-folder mode to discover nested files at scale. Args: - path: absolute mount-prefixed path (e.g., //src) or "/" for mount roots. - max_depth: recursion depth limit (default 8). -- page_size: number of entries to return per page (max 1000). -- cursor: opaque continuation token from a previous call. +- page_size: maximum number of entries returned (max 1000). - include_files/include_dirs: filter returned entry types. Returns JSON with: - entries: [{path, is_dir, size, modified_at, depth}] -- next_cursor: continuation token or null -- has_more: whether additional pages exist +- truncated: true when additional entries were omitted due to page_size """ SURFSENSE_GLOB_TOOL_DESCRIPTION = """Find files matching a glob pattern. @@ -251,13 +246,13 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): if filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: system_prompt += ( "\n- move_file: move or rename files/folders in local-folder mode." - "\n- list_tree: recursively list nested local paths with cursor pagination." + "\n- list_tree: recursively list nested local paths in one bounded response." "\n\n## Local Folder Mode" "\n\nThis chat is running in desktop local-folder mode." " Keep all file operations local. Do not use save_document." " Always use mount-prefixed absolute paths like //file.ext." " If you are unsure which mounts are available, call ls('/') first." - " For big trees: use list_tree pages, then grep, then read_file." + " For big trees: use list_tree, then grep, then read_file." ) super().__init__( @@ -812,35 +807,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): """Only cloud mode persists file content to Document/Chunk tables.""" return self._filesystem_mode == FilesystemMode.CLOUD - def _default_mount_prefix(self, runtime: ToolRuntime[None, FilesystemState]) -> str: - backend = self._get_backend(runtime) - if isinstance(backend, MultiRootLocalFolderBackend): - return f"/{backend.default_mount()}" - return "" - - def _normalize_local_mount_path( - self, candidate: str, runtime: ToolRuntime[None, FilesystemState] - ) -> str: - backend = self._get_backend(runtime) - mount_prefix = self._default_mount_prefix(runtime) - normalized_candidate = re.sub(r"/+", "/", candidate.strip().replace("\\", "/")) - if not mount_prefix or not isinstance(backend, MultiRootLocalFolderBackend): - if normalized_candidate.startswith("/"): - return normalized_candidate - return f"/{normalized_candidate.lstrip('/')}" - - mount_names = set(backend.list_mounts()) - if normalized_candidate.startswith("/"): - first_segment = normalized_candidate.lstrip("/").split("/", 1)[0] - if first_segment in mount_names: - return normalized_candidate - return f"{mount_prefix}{normalized_candidate}" - - relative = normalized_candidate.lstrip("/") - first_segment = relative.split("/", 1)[0] - if first_segment in mount_names: - return f"/{relative}" - return f"{mount_prefix}/{relative}" + @staticmethod + def _normalize_absolute_path(candidate: str) -> str: + normalized = re.sub(r"/+", "/", candidate.strip().replace("\\", "/")) + if not normalized: + return "/" + if normalized.startswith("/"): + return normalized + return f"/{normalized.lstrip('/')}" def _get_contract_suggested_path( self, runtime: ToolRuntime[None, FilesystemState] @@ -848,14 +822,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): contract = runtime.state.get("file_operation_contract") or {} suggested = contract.get("suggested_path") if isinstance(suggested, str) and suggested.strip(): - cleaned = suggested.strip() - if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - return self._normalize_local_mount_path(cleaned, runtime) - return cleaned - if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - mount_prefix = self._default_mount_prefix(runtime) - if mount_prefix: - return f"{mount_prefix}/notes.md" + return self._normalize_absolute_path(suggested) return "/notes.md" def _resolve_write_target_path( @@ -867,7 +834,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): if not candidate: return self._get_contract_suggested_path(runtime) if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - return self._normalize_local_mount_path(candidate, runtime) + return self._normalize_absolute_path(candidate) if not candidate.startswith("/"): return f"/{candidate.lstrip('/')}" return candidate @@ -881,7 +848,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): if not candidate: return "" if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - return self._normalize_local_mount_path(candidate, runtime) + return self._normalize_absolute_path(candidate) if not candidate.startswith("/"): return f"/{candidate.lstrip('/')}" return candidate @@ -895,7 +862,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): if candidate == "/": return "/" if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - return self._normalize_local_mount_path(candidate, runtime) + return self._normalize_absolute_path(candidate) if not candidate.startswith("/"): return f"/{candidate.lstrip('/')}" return candidate @@ -1136,12 +1103,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): ] = 8, page_size: Annotated[ int, - "Number of entries to return per page. Defaults to 500 (max 1000).", + "Maximum number of entries to return. Defaults to 500 (max 1000).", ] = 500, - cursor: Annotated[ - str | None, - "Opaque cursor from a previous list_tree call.", - ] = None, include_files: Annotated[ bool, "Whether file entries should be included.", @@ -1171,7 +1134,6 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): validated_path, max_depth=max_depth, page_size=page_size, - cursor=cursor, include_files=include_files, include_dirs=include_dirs, ) @@ -1193,12 +1155,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): ] = 8, page_size: Annotated[ int, - "Number of entries to return per page. Defaults to 500 (max 1000).", + "Maximum number of entries to return. Defaults to 500 (max 1000).", ] = 500, - cursor: Annotated[ - str | None, - "Opaque cursor from a previous list_tree call.", - ] = None, include_files: Annotated[ bool, "Whether file entries should be included.", @@ -1228,7 +1186,6 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): validated_path, max_depth=max_depth, page_size=page_size, - cursor=cursor, include_files=include_files, include_dirs=include_dirs, ) diff --git a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py index ef6a1657d..4f149a756 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py @@ -9,9 +9,7 @@ import threading from collections import deque from contextlib import ExitStack from pathlib import Path -from time import time from typing import Any -from uuid import uuid4 from deepagents.backends.protocol import ( EditResult, @@ -43,8 +41,6 @@ class LocalFolderBackend: self._root = root self._locks: dict[str, threading.Lock] = {} self._locks_mu = threading.Lock() - self._tree_sessions: dict[str, dict[str, Any]] = {} - self._tree_sessions_ttl_s = 900 def _lock_for(self, path: str) -> threading.Lock: with self._locks_mu: @@ -89,16 +85,6 @@ class LocalFolderBackend: def _clamp_page_size(page_size: int) -> int: return max(1, min(page_size, 1000)) - def _prune_expired_tree_sessions(self) -> None: - now = time() - expired = [ - cursor - for cursor, session in self._tree_sessions.items() - if now - float(session.get("last_accessed_at", now)) > self._tree_sessions_ttl_s - ] - for cursor in expired: - self._tree_sessions.pop(cursor, None) - def _read_dir_entries(self, directory_path: str) -> list[dict[str, Any]]: directory = Path(directory_path) try: @@ -206,148 +192,82 @@ class LocalFolderBackend: *, max_depth: int | None = 8, page_size: int = 500, - cursor: str | None = None, include_files: bool = True, include_dirs: bool = True, ) -> dict[str, Any]: - self._prune_expired_tree_sessions() if not include_files and not include_dirs: return { "entries": [], - "next_cursor": None, - "has_more": False, "truncated": False, } normalized_depth = None if max_depth is None else max(0, int(max_depth)) page_limit = self._clamp_page_size(int(page_size)) - now = time() - - if cursor: - session = self._tree_sessions.get(cursor) - if not session: - return {"error": "Invalid or expired cursor"} - if ( - session.get("path") != path - or session.get("max_depth") != normalized_depth - or session.get("include_files") != include_files - or session.get("include_dirs") != include_dirs - ): - return {"error": "Cursor options do not match request options"} - state = session - else: - try: - start = self._resolve_virtual(path, allow_root=True) - except ValueError: - return {"error": f"Error: invalid path '{path}'"} - if not start.exists(): - return {"error": f"Error: path '{path}' not found"} - if start.is_file(): - stat_result = start.stat() - if include_files: - return { - "entries": [ - { - "path": self._to_virtual(start, self._root), - "is_dir": False, - "size": stat_result.st_size, - "modified_at": str(stat_result.st_mtime), - "depth": 0, - } - ], - "next_cursor": None, - "has_more": False, - "truncated": False, - } + try: + start = self._resolve_virtual(path, allow_root=True) + except ValueError: + return {"error": f"Error: invalid path '{path}'"} + if not start.exists(): + return {"error": f"Error: path '{path}' not found"} + if start.is_file(): + stat_result = start.stat() + if include_files: return { - "entries": [], - "next_cursor": None, - "has_more": False, + "entries": [ + { + "path": self._to_virtual(start, self._root), + "is_dir": False, + "size": stat_result.st_size, + "modified_at": str(stat_result.st_mtime), + "depth": 0, + } + ], "truncated": False, } - state = { - "path": path, - "max_depth": normalized_depth, - "include_files": include_files, - "include_dirs": include_dirs, - "pending_dirs": deque([(str(start), 0)]), - "active_dir": None, - "active_depth": 0, - "active_entries": [], - "active_index": 0, + return { + "entries": [], + "truncated": False, } + pending_dirs: deque[tuple[str, int]] = deque([(str(start), 0)]) entries: list[dict[str, Any]] = [] truncated = False - while len(entries) < page_limit: - active_entries = state.get("active_entries", []) - active_index = int(state.get("active_index", 0)) - if active_index >= len(active_entries): - pending_dirs = state.get("pending_dirs", []) - if not pending_dirs: - state["active_entries"] = [] - state["active_index"] = 0 - break - next_dir_path, next_depth = pending_dirs.popleft() - state["active_dir"] = next_dir_path - state["active_depth"] = next_depth - state["active_entries"] = self._read_dir_entries(next_dir_path) - state["active_index"] = 0 - active_entries = state["active_entries"] - active_index = 0 - - if active_index >= len(active_entries): - continue - - item = active_entries[active_index] - state["active_index"] = active_index + 1 - item_depth = int(state.get("active_depth", 0)) + 1 - if normalized_depth is not None and item_depth > normalized_depth: - continue - if item["is_dir"]: - if normalized_depth is None or item_depth <= normalized_depth: - state["pending_dirs"].append((item["absolute_path"], item_depth)) - if include_dirs: + while pending_dirs and not truncated: + next_dir_path, next_depth = pending_dirs.popleft() + active_entries = self._read_dir_entries(next_dir_path) + for item in active_entries: + item_depth = next_depth + 1 + if normalized_depth is not None and item_depth > normalized_depth: + continue + if item["is_dir"]: + if normalized_depth is None or item_depth <= normalized_depth: + pending_dirs.append((item["absolute_path"], item_depth)) + if include_dirs: + entries.append( + { + "path": item["path"], + "is_dir": True, + "size": 0, + "modified_at": item["modified_at"], + "depth": item_depth, + } + ) + elif include_files: entries.append( { "path": item["path"], - "is_dir": True, - "size": 0, + "is_dir": False, + "size": item["size"], "modified_at": item["modified_at"], "depth": item_depth, } ) - elif include_files: - entries.append( - { - "path": item["path"], - "is_dir": False, - "size": item["size"], - "modified_at": item["modified_at"], - "depth": item_depth, - } - ) - - if len(entries) >= page_limit: - truncated = True - break - - has_more = bool(state.get("pending_dirs")) or ( - int(state.get("active_index", 0)) < len(state.get("active_entries", [])) - ) - if has_more: - next_cursor = cursor or uuid4().hex - state["last_accessed_at"] = now - self._tree_sessions[next_cursor] = state - else: - next_cursor = None - if cursor: - self._tree_sessions.pop(cursor, None) + if len(entries) >= page_limit: + truncated = True + break return { "entries": entries, - "next_cursor": next_cursor, - "has_more": has_more, "truncated": truncated, } @@ -357,7 +277,6 @@ class LocalFolderBackend: *, max_depth: int | None = 8, page_size: int = 500, - cursor: str | None = None, include_files: bool = True, include_dirs: bool = True, ) -> dict[str, Any]: @@ -366,7 +285,6 @@ class LocalFolderBackend: path, max_depth=max_depth, page_size=page_size, - cursor=cursor, include_files=include_files, include_dirs=include_dirs, ) diff --git a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py index 6760d76f0..82914f9ce 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py @@ -3,8 +3,6 @@ from __future__ import annotations import asyncio -import base64 -import json from pathlib import Path from typing import Any @@ -109,28 +107,6 @@ class MultiRootLocalFolderBackend: for mount in self._mount_order ] - @staticmethod - def _encode_tree_cursor(mount: str, local_cursor: str) -> str: - payload = json.dumps( - {"mount": mount, "cursor": local_cursor}, - separators=(",", ":"), - ).encode("utf-8") - return base64.urlsafe_b64encode(payload).decode("ascii") - - @staticmethod - def _decode_tree_cursor(cursor: str) -> tuple[str, str]: - try: - padded = cursor + "=" * ((4 - len(cursor) % 4) % 4) - data = base64.urlsafe_b64decode(padded.encode("ascii")) - parsed = json.loads(data.decode("utf-8")) - except Exception as exc: - raise ValueError("Invalid cursor") from exc - mount = parsed.get("mount") - local_cursor = parsed.get("cursor") - if not isinstance(mount, str) or not isinstance(local_cursor, str): - raise ValueError("Invalid cursor") - return mount, local_cursor - def _transform_infos(self, mount: str, infos: list[FileInfo]) -> list[FileInfo]: transformed: list[FileInfo] = [] for info in infos: @@ -162,11 +138,10 @@ class MultiRootLocalFolderBackend: *, max_depth: int | None = 8, page_size: int = 500, - cursor: str | None = None, include_files: bool = True, include_dirs: bool = True, ) -> dict[str, Any]: - if path == "/" and not cursor: + if path == "/": entries = [ { "path": f"/{mount}", @@ -179,20 +154,11 @@ class MultiRootLocalFolderBackend: ] return { "entries": entries if include_dirs else [], - "next_cursor": None, - "has_more": False, "truncated": False, } try: - if cursor: - mount, local_cursor = self._decode_tree_cursor(cursor) - if mount not in self._mount_to_backend: - return {"error": "Invalid or expired cursor"} - local_path = "/" - else: - mount, local_path = self._split_mount_path(path) - local_cursor = None + mount, local_path = self._split_mount_path(path) except ValueError as exc: return {"error": f"Error: {exc}"} @@ -200,7 +166,6 @@ class MultiRootLocalFolderBackend: local_path, max_depth=max_depth, page_size=page_size, - cursor=local_cursor, include_files=include_files, include_dirs=include_dirs, ) @@ -220,16 +185,8 @@ class MultiRootLocalFolderBackend: } ) - local_next_cursor = self._get_str(result, "next_cursor") - next_cursor = ( - self._encode_tree_cursor(mount, local_next_cursor) - if local_next_cursor - else None - ) return { "entries": entries, - "next_cursor": next_cursor, - "has_more": self._get_bool(result, "has_more"), "truncated": self._get_bool(result, "truncated"), } @@ -239,7 +196,6 @@ class MultiRootLocalFolderBackend: *, max_depth: int | None = 8, page_size: int = 500, - cursor: str | None = None, include_files: bool = True, include_dirs: bool = True, ) -> dict[str, Any]: @@ -248,7 +204,6 @@ class MultiRootLocalFolderBackend: path, max_depth=max_depth, page_size=page_size, - cursor=cursor, include_files=include_files, include_dirs=include_dirs, ) From 7134b0feae70fb7a6eec4172d6cdccf5a9780bad Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Tue, 28 Apr 2026 00:57:07 +0530 Subject: [PATCH 202/299] refactor(file_intent): remove _infer_text_file_extension function and standardize fallback filename to 'notes.md' --- .../agents/new_chat/middleware/file_intent.py | 36 ++----------------- .../ui/sidebar/LocalFilesystemBrowser.tsx | 5 +-- 2 files changed, 5 insertions(+), 36 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/middleware/file_intent.py b/surfsense_backend/app/agents/new_chat/middleware/file_intent.py index 1e5fd0ede..4bf5dcfe4 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/file_intent.py +++ b/surfsense_backend/app/agents/new_chat/middleware/file_intent.py @@ -109,37 +109,6 @@ def _sanitize_path_segment(value: str) -> str: return segment -def _infer_text_file_extension(user_text: str) -> str: - lowered = user_text.lower() - if any(token in lowered for token in ("json", ".json")): - return ".json" - if any(token in lowered for token in ("yaml", "yml", ".yaml", ".yml")): - return ".yaml" - if any(token in lowered for token in ("csv", ".csv")): - return ".csv" - if any(token in lowered for token in ("python", ".py")): - return ".py" - if any(token in lowered for token in ("typescript", ".ts", ".tsx")): - return ".ts" - if any(token in lowered for token in ("javascript", ".js", ".mjs", ".cjs")): - return ".js" - if any(token in lowered for token in ("html", ".html")): - return ".html" - if any(token in lowered for token in ("css", ".css")): - return ".css" - if any(token in lowered for token in ("sql", ".sql")): - return ".sql" - if any(token in lowered for token in ("toml", ".toml")): - return ".toml" - if any(token in lowered for token in ("ini", ".ini")): - return ".ini" - if any(token in lowered for token in ("xml", ".xml")): - return ".xml" - if any(token in lowered for token in ("markdown", ".md", "readme")): - return ".md" - return ".md" - - def _normalize_directory(value: str) -> str: raw = value.strip().replace("\\", "/") raw = raw.strip("/") @@ -193,7 +162,6 @@ def _fallback_path( suggested_path: str | None = None, user_text: str, ) -> str: - default_extension = _infer_text_file_extension(user_text) inferred_dir = _infer_directory_from_user_text(user_text) sanitized_filename = "" @@ -202,9 +170,9 @@ def _fallback_path( if sanitized_filename.lower().endswith(".txt"): sanitized_filename = f"{sanitized_filename[:-4]}.md" if not sanitized_filename: - sanitized_filename = f"notes{default_extension}" + sanitized_filename = "notes.md" elif "." not in sanitized_filename: - sanitized_filename = f"{sanitized_filename}{default_extension}" + sanitized_filename = f"{sanitized_filename}.md" normalized_suggested_path = ( _normalize_file_path(suggested_path) if suggested_path else "" diff --git a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx index d1146338d..a808d5a31 100644 --- a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx +++ b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx @@ -1,6 +1,6 @@ "use client"; -import { ChevronDown, ChevronRight, FileText, Folder } from "lucide-react"; +import { ChevronDown, ChevronRight, FileText, Folder, FolderOpen } from "lucide-react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { DEFAULT_EXCLUDE_PATTERNS } from "@/components/sources/FolderWatchDialog"; import { Skeleton } from "@/components/ui/skeleton"; @@ -329,6 +329,7 @@ export function LocalFilesystemBrowser({ const renderFolder = useCallback( (folder: LocalFolderNode, depth: number, mount: string) => { const isExpanded = expandedFolderKeys.has(folder.key); + const FolderIcon = isExpanded ? FolderOpen : Folder; const childFolders = Array.from(folder.folders.values()).sort((a, b) => a.name.localeCompare(b.name) ); @@ -347,7 +348,7 @@ export function LocalFilesystemBrowser({ ) : ( )} - + {folder.name} {isExpanded && ( From b85b7cbae0d6a97bd1a0a437a8fdb58f2b250bc4 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Tue, 28 Apr 2026 01:12:15 +0530 Subject: [PATCH 203/299] feat(filesystem): introduce support for local openable text file extensions and enhance folder expansion persistence in the UI --- .../src/modules/agent-filesystem.ts | 56 +++++++ surfsense_web/atoms/documents/folder.atoms.ts | 9 ++ .../components/editor/source-code-editor.tsx | 5 + .../ui/sidebar/DesktopLocalTabContent.tsx | 20 ++- .../ui/sidebar/LocalFilesystemBrowser.tsx | 140 ++++++++++++++---- 5 files changed, 202 insertions(+), 28 deletions(-) diff --git a/surfsense_desktop/src/modules/agent-filesystem.ts b/surfsense_desktop/src/modules/agent-filesystem.ts index d8c64b79a..608f8c4a4 100644 --- a/surfsense_desktop/src/modules/agent-filesystem.ts +++ b/surfsense_desktop/src/modules/agent-filesystem.ts @@ -21,6 +21,51 @@ const MAX_LOCAL_ROOTS = 10; const DEFAULT_SPACE_KEY = "default"; let cachedSettingsStore: AgentFilesystemSettingsStore | null = null; +const LOCAL_OPENABLE_TEXT_EXTENSIONS = new Set([ + ".md", + ".markdown", + ".txt", + ".json", + ".yaml", + ".yml", + ".csv", + ".tsv", + ".xml", + ".html", + ".htm", + ".css", + ".scss", + ".sass", + ".sql", + ".toml", + ".ini", + ".conf", + ".log", + ".py", + ".js", + ".jsx", + ".mjs", + ".cjs", + ".ts", + ".tsx", + ".java", + ".kt", + ".kts", + ".go", + ".rs", + ".rb", + ".php", + ".swift", + ".r", + ".lua", + ".sh", + ".bash", + ".zsh", + ".fish", + ".env", + ".mk", +]); + function getSettingsPath(): string { return join(app.getPath("userData"), SETTINGS_FILENAME); } @@ -229,6 +274,16 @@ function toVirtualPath(rootPath: string, absolutePath: string): string { return `/${rel.replace(/\\/g, "/")}`; } +function assertLocalOpenableTextFile(absolutePath: string): void { + const extension = extname(absolutePath).toLowerCase(); + if (!LOCAL_OPENABLE_TEXT_EXTENSIONS.has(extension)) { + throw new Error( + `Unsupported local file type '${extension || "(no extension)"}'. ` + + "Only text/code files can be opened in local mode." + ); + } +} + export type LocalRootMount = { mount: string; rootPath: string; @@ -441,6 +496,7 @@ export async function readAgentLocalFileText( ); } const absolutePath = resolveVirtualPath(rootMount.rootPath, subPath); + assertLocalOpenableTextFile(absolutePath); const content = await readFile(absolutePath, "utf8"); return { path: toMountedVirtualPath(rootMount.mount, rootMount.rootPath, absolutePath), diff --git a/surfsense_web/atoms/documents/folder.atoms.ts b/surfsense_web/atoms/documents/folder.atoms.ts index fe7d556eb..bbdc58e4e 100644 --- a/surfsense_web/atoms/documents/folder.atoms.ts +++ b/surfsense_web/atoms/documents/folder.atoms.ts @@ -12,6 +12,15 @@ export const expandedFolderIdsAtom = atomWithStorage>( {} ); +/** + * Expanded folder keys for Local filesystem tree, keyed by search space ID. + * Persisted so local tree expansion survives remounts/reloads. + */ +export const localExpandedFolderKeysAtom = atomWithStorage>( + "surfsense:localExpandedFolderKeys", + {} +); + /** * Folder currently being renamed (inline edit mode). * null means no folder is being renamed. diff --git a/surfsense_web/components/editor/source-code-editor.tsx b/surfsense_web/components/editor/source-code-editor.tsx index 5cab8e5b1..27734005e 100644 --- a/surfsense_web/components/editor/source-code-editor.tsx +++ b/surfsense_web/components/editor/source-code-editor.tsx @@ -143,6 +143,11 @@ export function SourceCodeEditor({ fontFamily: "ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, Liberation Mono, monospace", renderWhitespace: "selection", + unicodeHighlight: { + ambiguousCharacters: false, + invisibleCharacters: false, + nonBasicASCII: false, + }, smoothScrolling: true, readOnly, }} diff --git a/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx b/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx index 6fd4e48f8..dd7520d24 100644 --- a/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx @@ -1,7 +1,9 @@ "use client"; import { Folder, FolderPlus, Search, X } from "lucide-react"; -import { useRef, useState } from "react"; +import { useAtom } from "jotai"; +import { useCallback, useMemo, useRef, useState } from "react"; +import { localExpandedFolderKeysAtom } from "@/atoms/documents/folder.atoms"; import { Input } from "@/components/ui/input"; import { Separator } from "@/components/ui/separator"; import { @@ -45,6 +47,20 @@ export function DesktopLocalTabContent({ const [localSearch, setLocalSearch] = useState(""); const debouncedLocalSearch = useDebouncedValue(localSearch, 250); const localSearchInputRef = useRef(null); + const [expandedFolderKeyMap, setExpandedFolderKeyMap] = useAtom(localExpandedFolderKeysAtom); + const expandedFolderKeys = useMemo( + () => new Set(expandedFolderKeyMap[searchSpaceId] ?? []), + [expandedFolderKeyMap, searchSpaceId] + ); + const handleExpandedFolderKeysChange = useCallback( + (nextExpandedKeys: Set) => { + setExpandedFolderKeyMap((prev) => ({ + ...prev, + [searchSpaceId]: Array.from(nextExpandedKeys), + })); + }, + [searchSpaceId, setExpandedFolderKeyMap] + ); return (
@@ -181,6 +197,8 @@ export function DesktopLocalTabContent({ active searchQuery={debouncedLocalSearch.trim() || undefined} onOpenFile={onOpenLocalFile} + expandedFolderKeys={expandedFolderKeys} + onExpandedFolderKeysChange={handleExpandedFolderKeysChange} />
); diff --git a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx index a808d5a31..6bfb1d3f1 100644 --- a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx +++ b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx @@ -6,7 +6,6 @@ import { DEFAULT_EXCLUDE_PATTERNS } from "@/components/sources/FolderWatchDialog import { Skeleton } from "@/components/ui/skeleton"; import { Spinner } from "@/components/ui/spinner"; import { useElectronAPI } from "@/hooks/use-platform"; -import { getSupportedExtensionsSet } from "@/lib/supported-extensions"; interface LocalFilesystemBrowserProps { rootPaths: string[]; @@ -14,6 +13,8 @@ interface LocalFilesystemBrowserProps { active?: boolean; searchQuery?: string; onOpenFile: (fullPath: string) => void; + expandedFolderKeys?: Set; + onExpandedFolderKeysChange?: (nextExpandedKeys: Set) => void; } interface LocalFolderFileEntry { @@ -43,6 +44,51 @@ type LocalRootMount = { type MountLoadStatus = "idle" | "loading" | "complete" | "error"; +const LOCAL_OPENABLE_EXTENSIONS = [ + ".md", + ".markdown", + ".txt", + ".json", + ".yaml", + ".yml", + ".csv", + ".tsv", + ".xml", + ".html", + ".htm", + ".css", + ".scss", + ".sass", + ".sql", + ".toml", + ".ini", + ".conf", + ".log", + ".py", + ".js", + ".jsx", + ".mjs", + ".cjs", + ".ts", + ".tsx", + ".java", + ".kt", + ".kts", + ".go", + ".rs", + ".rb", + ".php", + ".swift", + ".r", + ".lua", + ".sh", + ".bash", + ".zsh", + ".fish", + ".env", + ".mk", +]; + const getFolderDisplayName = (rootPath: string): string => rootPath.split(/[\\/]/).at(-1) || rootPath; @@ -73,16 +119,29 @@ function toMountedVirtualPath(mount: string, relativePath: string): string { return `/${mount}${toVirtualPath(relativePath)}`; } +function getNormalizedExtension(pathValue: string): string { + const fileName = getFileName(pathValue).toLowerCase(); + if (!fileName) return ""; + if (fileName === "dockerfile" || fileName === "makefile") { + return `.${fileName}`; + } + const dotIndex = fileName.lastIndexOf("."); + if (dotIndex <= 0) return ""; + return fileName.slice(dotIndex); +} + export function LocalFilesystemBrowser({ rootPaths, searchSpaceId, active = true, searchQuery, onOpenFile, + expandedFolderKeys, + onExpandedFolderKeysChange, }: LocalFilesystemBrowserProps) { const electronAPI = useElectronAPI(); const [rootStateMap, setRootStateMap] = useState>({}); - const [expandedFolderKeys, setExpandedFolderKeys] = useState>(new Set()); + const [internalExpandedFolderKeys, setInternalExpandedFolderKeys] = useState>(new Set()); const [mountByRootKey, setMountByRootKey] = useState>(new Map()); const [mountStatus, setMountStatus] = useState("idle"); const [mountRefreshInFlight, setMountRefreshInFlight] = useState(false); @@ -90,8 +149,9 @@ export function LocalFilesystemBrowser({ const lastLoadedSignatureByRootRef = useRef>(new Map()); const hasLoadedMountsOnceRef = useRef(false); const hasResolvedAtLeastOneRootRef = useRef(false); - const supportedExtensions = useMemo(() => Array.from(getSupportedExtensionsSet()), []); + const openableExtensions = useMemo(() => new Set(LOCAL_OPENABLE_EXTENSIONS), []); const isWindowsPlatform = electronAPI?.versions.platform === "win32"; + const effectiveExpandedFolderKeys = expandedFolderKeys ?? internalExpandedFolderKeys; useEffect(() => { if (!active) return; @@ -153,7 +213,6 @@ export function LocalFilesystemBrowser({ rootPath, searchSpaceId, excludePatterns: DEFAULT_EXCLUDE_PATTERNS, - fileExtensions: supportedExtensions, })) as LocalFolderFileEntry[]; if (cancelled) return; setRootStateMap((prev) => ({ @@ -181,7 +240,7 @@ export function LocalFilesystemBrowser({ return () => { cancelled = true; }; - }, [active, electronAPI, isWindowsPlatform, reloadNonceByRoot, rootPaths, searchSpaceId, supportedExtensions]); + }, [active, electronAPI, isWindowsPlatform, reloadNonceByRoot, rootPaths, searchSpaceId]); useEffect(() => { if (active) return; @@ -198,7 +257,13 @@ export function LocalFilesystemBrowser({ return; } - const unsubscribe = electronAPI.onAgentFilesystemTreeDirty((event) => { + const unsubscribe = electronAPI.onAgentFilesystemTreeDirty((event: { + searchSpaceId: number | null; + reason: "watcher_event" | "safety_poll"; + rootPath: string; + changedPath: string | null; + timestamp: number; + }) => { if ((event.searchSpaceId ?? null) !== (searchSpaceId ?? null)) { return; } @@ -225,14 +290,13 @@ export function LocalFilesystemBrowser({ searchSpaceId, rootPaths, excludePatterns: DEFAULT_EXCLUDE_PATTERNS, - fileExtensions: supportedExtensions, }); return () => { unsubscribe(); void electronAPI.stopAgentFilesystemTreeWatch(searchSpaceId); }; - }, [active, electronAPI, isWindowsPlatform, rootPaths, searchSpaceId, supportedExtensions]); + }, [active, electronAPI, isWindowsPlatform, rootPaths, searchSpaceId]); useEffect(() => { if (!electronAPI?.getAgentFilesystemMounts) { @@ -315,7 +379,7 @@ export function LocalFilesystemBrowser({ }, [rootPaths, rootStateMap, searchQuery]); const toggleFolder = useCallback((folderKey: string) => { - setExpandedFolderKeys((prev) => { + const update = (prev: Set) => { const next = new Set(prev); if (next.has(folderKey)) { next.delete(folderKey); @@ -323,12 +387,17 @@ export function LocalFilesystemBrowser({ next.add(folderKey); } return next; - }); - }, []); + }; + if (onExpandedFolderKeysChange) { + onExpandedFolderKeysChange(update(effectiveExpandedFolderKeys)); + return; + } + setInternalExpandedFolderKeys(update); + }, [effectiveExpandedFolderKeys, onExpandedFolderKeysChange]); const renderFolder = useCallback( (folder: LocalFolderNode, depth: number, mount: string) => { - const isExpanded = expandedFolderKeys.has(folder.key); + const isExpanded = effectiveExpandedFolderKeys.has(folder.key); const FolderIcon = isExpanded ? FolderOpen : Folder; const childFolders = Array.from(folder.folders.values()).sort((a, b) => a.name.localeCompare(b.name) @@ -354,26 +423,43 @@ export function LocalFilesystemBrowser({ {isExpanded && ( <> {childFolders.map((childFolder) => renderFolder(childFolder, depth + 1, mount))} - {files.map((file) => ( - - ))} + {files.map((file) => { + const extension = getNormalizedExtension(file.relativePath); + const isOpenable = openableExtensions.has(extension); + return ( + + ); + })} )}
); }, - [expandedFolderKeys, onOpenFile, toggleFolder] + [effectiveExpandedFolderKeys, onOpenFile, openableExtensions, toggleFolder] ); if (rootPaths.length === 0) { From 8c0670929595ce45c087fa2ef2da8dcdc578f037 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Tue, 28 Apr 2026 01:26:33 +0530 Subject: [PATCH 204/299] refactor(source-code-editor): update editor settings by adjusting line number display, disabling folding, and modifying whitespace rendering options --- .../components/editor/source-code-editor.tsx | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/surfsense_web/components/editor/source-code-editor.tsx b/surfsense_web/components/editor/source-code-editor.tsx index 27734005e..dd4b3bd8e 100644 --- a/surfsense_web/components/editor/source-code-editor.tsx +++ b/surfsense_web/components/editor/source-code-editor.tsx @@ -114,10 +114,10 @@ export function SourceCodeEditor({ automaticLayout: true, minimap: { enabled: false }, lineNumbers: "on", - lineNumbersMinChars: 3, - lineDecorationsWidth: 12, + lineNumbersMinChars: 4, + lineDecorationsWidth: 20, glyphMargin: false, - folding: true, + folding: false, overviewRulerLanes: 0, hideCursorInOverviewRuler: true, scrollBeyondLastLine: false, @@ -142,7 +142,12 @@ export function SourceCodeEditor({ fontSize, fontFamily: "ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, Liberation Mono, monospace", - renderWhitespace: "selection", + renderWhitespace: "none", + renderValidationDecorations: "off", + colorDecorators: false, + codeLens: false, + hover: { enabled: false }, + stickyScroll: { enabled: false }, unicodeHighlight: { ambiguousCharacters: false, invisibleCharacters: false, From c238a671c8ec315348965009a2d6334874dbcd61 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Tue, 28 Apr 2026 01:54:26 +0530 Subject: [PATCH 205/299] feat(filesystem): enhance local mount path normalization and add error handling for missing parent directories --- .../agents/new_chat/middleware/filesystem.py | 92 ++++++++++++++++++- .../middleware/local_folder_backend.py | 8 ++ .../middleware/test_file_intent_middleware.py | 4 +- .../test_filesystem_verification.py | 49 ++++++++++ .../middleware/test_local_folder_backend.py | 12 +++ 5 files changed, 160 insertions(+), 5 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py index 3622bbcdf..8dfa89ef2 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py +++ b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py @@ -28,6 +28,9 @@ from langgraph.types import Command from sqlalchemy import delete, select from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( + MultiRootLocalFolderBackend, +) from app.agents.new_chat.sandbox import ( _evict_sandbox_cache, delete_sandbox, @@ -816,6 +819,89 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): return normalized return f"/{normalized.lstrip('/')}" + @staticmethod + def _extract_mount_from_path(path: str, mounts: tuple[str, ...]) -> str | None: + rel = path.lstrip("/") + if not rel: + return None + mount, _, _ = rel.partition("/") + if mount in mounts: + return mount + return None + + @staticmethod + def _local_parent_path(path: str) -> str: + rel = path.lstrip("/") + if "/" not in rel: + return "/" + parent = rel.rsplit("/", 1)[0].strip("/") + if not parent: + return "/" + return f"/{parent}" + + @staticmethod + def _path_exists_under_mount( + backend: MultiRootLocalFolderBackend, + mount: str, + local_path: str, + ) -> bool: + result = backend.list_tree( + f"/{mount}{local_path}", + max_depth=0, + page_size=1, + include_files=True, + include_dirs=True, + ) + return not bool(result.get("error")) + + def _normalize_local_mount_path( + self, + candidate: str, + runtime: ToolRuntime[None, FilesystemState], + ) -> str: + normalized = self._normalize_absolute_path(candidate) + backend = self._get_backend(runtime) + if not isinstance(backend, MultiRootLocalFolderBackend): + return normalized + + mounts = backend.list_mounts() + explicit_mount = self._extract_mount_from_path(normalized, mounts) + if explicit_mount: + return normalized + + if len(mounts) == 1: + return f"/{mounts[0]}{normalized}" + + suggested_mount: str | None = None + contract = runtime.state.get("file_operation_contract") or {} + suggested_path = contract.get("suggested_path") + if isinstance(suggested_path, str) and suggested_path.strip(): + normalized_suggested = self._normalize_absolute_path(suggested_path) + suggested_mount = self._extract_mount_from_path(normalized_suggested, mounts) + + matching_mounts = [ + mount + for mount in mounts + if self._path_exists_under_mount(backend, mount, normalized) + ] + if len(matching_mounts) == 1: + return f"/{matching_mounts[0]}{normalized}" + + parent_path = self._local_parent_path(normalized) + if parent_path != "/": + parent_matching_mounts = [ + mount + for mount in mounts + if self._path_exists_under_mount(backend, mount, parent_path) + ] + if len(parent_matching_mounts) == 1: + return f"/{parent_matching_mounts[0]}{normalized}" + + if suggested_mount: + return f"/{suggested_mount}{normalized}" + + return f"/{backend.default_mount()}{normalized}" + def _get_contract_suggested_path( self, runtime: ToolRuntime[None, FilesystemState] ) -> str: @@ -834,7 +920,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): if not candidate: return self._get_contract_suggested_path(runtime) if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - return self._normalize_absolute_path(candidate) + return self._normalize_local_mount_path(candidate, runtime) if not candidate.startswith("/"): return f"/{candidate.lstrip('/')}" return candidate @@ -848,7 +934,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): if not candidate: return "" if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - return self._normalize_absolute_path(candidate) + return self._normalize_local_mount_path(candidate, runtime) if not candidate.startswith("/"): return f"/{candidate.lstrip('/')}" return candidate @@ -862,7 +948,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): if candidate == "/": return "/" if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - return self._normalize_absolute_path(candidate) + return self._normalize_local_mount_path(candidate, runtime) if not candidate.startswith("/"): return f"/{candidate.lstrip('/')}" return candidate diff --git a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py index 4f149a756..0cee3e007 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py @@ -180,6 +180,14 @@ class LocalFolderBackend: "Read and then make an edit, or write to a new path." ) ) + parent = path.parent + if not parent.exists() or not parent.is_dir(): + return WriteResult( + error=( + f"Error: parent directory for '{file_path}' does not exist. " + "Create the folder first or write to an existing directory." + ) + ) self._write_text_atomic(path, content) return WriteResult(path=file_path, files_update=None) diff --git a/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py b/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py index c0281fa29..673331b0a 100644 --- a/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py +++ b/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py @@ -79,7 +79,7 @@ async def test_file_write_null_filename_uses_semantic_default_path(): @pytest.mark.asyncio -async def test_file_write_null_filename_infers_json_extension(): +async def test_file_write_null_filename_defaults_to_markdown_path(): llm = _FakeLLM( '{"intent":"file_write","confidence":0.71,"suggested_filename":null}' ) @@ -94,7 +94,7 @@ async def test_file_write_null_filename_infers_json_extension(): assert result is not None contract = result["file_operation_contract"] assert contract["intent"] == FileOperationIntent.FILE_WRITE.value - assert contract["suggested_path"] == "/notes.json" + assert contract["suggested_path"] == "/notes.md" @pytest.mark.asyncio diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py b/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py index 7b4119bb5..d00365032 100644 --- a/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py @@ -34,6 +34,11 @@ class _RuntimeNoSuggestedPath: state = {"file_operation_contract": {}} +class _RuntimeWithSuggestedPath: + def __init__(self, suggested_path: str) -> None: + self.state = {"file_operation_contract": {"suggested_path": suggested_path}} + + def test_verify_written_content_prefers_raw_sync() -> None: middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) expected = "line1\nline2" @@ -162,3 +167,47 @@ def test_normalize_local_mount_path_prefixes_posix_absolute_path_for_linux_and_m resolved = middleware._normalize_local_mount_path("/var/log/app.log", runtime) # type: ignore[arg-type] assert resolved == "/pc_backups/var/log/app.log" + + +def test_normalize_local_mount_path_prefers_unique_existing_parent_mount( + tmp_path: Path, +) -> None: + root_a = tmp_path / "RootA" + root_b = tmp_path / "RootB" + (root_a / "other").mkdir(parents=True) + (root_b / "nested" / "deep").mkdir(parents=True) + backend = MultiRootLocalFolderBackend( + (("root_a", str(root_a)), ("root_b", str(root_b))) + ) + runtime = _RuntimeNoSuggestedPath() + middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) + middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign] + + resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type] + "/nested/deep/new-note.md", + runtime, + ) + + assert resolved == "/root_b/nested/deep/new-note.md" + + +def test_normalize_local_mount_path_uses_suggested_mount_when_ambiguous( + tmp_path: Path, +) -> None: + root_a = tmp_path / "RootA" + root_b = tmp_path / "RootB" + root_a.mkdir(parents=True) + root_b.mkdir(parents=True) + backend = MultiRootLocalFolderBackend( + (("root_a", str(root_a)), ("root_b", str(root_b))) + ) + runtime = _RuntimeWithSuggestedPath("/root_b/notes/context.md") + middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) + middleware._get_backend = lambda _runtime: backend # type: ignore[method-assign] + + resolved = middleware._normalize_local_mount_path( # type: ignore[arg-type] + "/brand-new-note.md", + runtime, + ) + + assert resolved == "/root_b/brand-new-note.md" diff --git a/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py b/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py index 3484a2cc4..7dfc68402 100644 --- a/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py +++ b/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py @@ -9,6 +9,7 @@ pytestmark = pytest.mark.unit def test_local_backend_write_read_edit_roundtrip(tmp_path: Path): backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "notes").mkdir() write = backend.write("/notes/test.md", "line1\nline2") assert write.error is None @@ -51,9 +52,20 @@ def test_local_backend_glob_and_grep(tmp_path: Path): def test_local_backend_read_raw_returns_exact_content(tmp_path: Path): backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "notes").mkdir() expected = "# Title\n\nline 1\nline 2\n" write = backend.write("/notes/raw.md", expected) assert write.error is None raw = backend.read_raw("/notes/raw.md") assert raw == expected + + +def test_local_backend_write_rejects_missing_parent_directory(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + + write = backend.write("/tempoo/new-note.md", "# New note") + + assert write.error is not None + assert "parent directory" in write.error + assert not (tmp_path / "tempoo").exists() From 8d50f90060f8e53c4a5f2ddda88bed2198981938 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Mon, 27 Apr 2026 14:04:50 -0700 Subject: [PATCH 206/299] chore: linting --- .../app/agents/new_chat/chat_deepagent.py | 12 +- .../agents/new_chat/middleware/__init__.py | 6 +- .../agents/new_chat/middleware/file_intent.py | 9 +- .../agents/new_chat/middleware/filesystem.py | 36 +++- .../new_chat/middleware/knowledge_search.py | 2 +- .../middleware/local_folder_backend.py | 44 +++-- .../multi_root_local_folder_backend.py | 28 +++- .../new_chat/tools/connected_accounts.py | 12 +- .../agents/new_chat/tools/discord/_auth.py | 3 +- .../new_chat/tools/discord/list_channels.py | 32 +++- .../new_chat/tools/discord/read_messages.py | 32 +++- .../new_chat/tools/discord/send_message.py | 35 +++- .../agents/new_chat/tools/gmail/read_email.py | 21 ++- .../new_chat/tools/gmail/search_emails.py | 45 +++-- .../tools/google_calendar/search_events.py | 54 ++++-- .../app/agents/new_chat/tools/hitl.py | 4 +- .../app/agents/new_chat/tools/luma/_auth.py | 3 +- .../new_chat/tools/luma/create_event.py | 21 ++- .../agents/new_chat/tools/luma/list_events.py | 37 ++-- .../agents/new_chat/tools/luma/read_event.py | 16 +- .../app/agents/new_chat/tools/mcp_client.py | 6 +- .../app/agents/new_chat/tools/mcp_tool.py | 158 +++++++++++------- .../app/agents/new_chat/tools/registry.py | 13 +- .../app/agents/new_chat/tools/teams/_auth.py | 3 +- .../new_chat/tools/teams/list_channels.py | 33 +++- .../new_chat/tools/teams/read_messages.py | 32 ++-- .../new_chat/tools/teams/send_message.py | 24 ++- .../agents/new_chat/tools/tool_response.py | 5 +- .../app/connectors/exceptions.py | 1 - surfsense_backend/app/routes/__init__.py | 4 +- .../app/routes/mcp_oauth_route.py | 130 ++++++++++---- .../app/routes/new_chat_routes.py | 2 +- .../app/routes/oauth_connector_base.py | 29 ++-- .../routes/search_source_connectors_routes.py | 4 +- .../app/services/mcp_oauth/discovery.py | 4 +- .../app/services/mcp_oauth/registry.py | 62 ++++--- .../app/services/obsidian_plugin_indexer.py | 9 +- .../app/tasks/chat/stream_new_chat.py | 41 ++--- surfsense_backend/app/utils/async_retry.py | 9 +- .../app/utils/connector_naming.py | 5 +- .../test_obsidian_plugin_routes.py | 16 +- .../middleware/test_file_intent_middleware.py | 10 +- .../test_filesystem_verification.py | 4 +- .../unit/test_obsidian_plugin_indexer.py | 7 +- .../unit/test_stream_new_chat_contract.py | 1 - .../new-chat/[[...chat_id]]/page.tsx | 23 +-- .../components/DesktopShortcutsContent.tsx | 50 +++--- surfsense_web/app/desktop/login/page.tsx | 11 +- .../assistant-ui/connector-popup.tsx | 50 +++--- .../components/mcp-connect-form.tsx | 14 +- .../components/mcp-config.tsx | 14 +- .../components/teams-config.tsx | 6 +- .../views/connector-edit-view.tsx | 12 +- .../views/indexing-configuration-view.tsx | 5 +- .../tabs/active-connectors-tab.tsx | 6 +- .../views/connector-accounts-list-view.tsx | 129 +++++++------- .../components/assistant-ui/markdown-text.tsx | 7 +- .../components/editor-panel/editor-panel.tsx | 141 +++++++++------- .../editor/plugins/fixed-toolbar-kit.tsx | 3 +- .../components/editor/source-code-editor.tsx | 2 +- .../layout/ui/right-panel/RightPanel.tsx | 8 +- .../ui/sidebar/DesktopLocalTabContent.tsx | 6 +- .../layout/ui/sidebar/DocumentsSidebar.tsx | 62 ++++--- .../ui/sidebar/LocalFilesystemBrowser.tsx | 109 ++++++------ .../layout/ui/tabs/DocumentTabContent.tsx | 4 +- .../components/new-chat/model-selector.tsx | 10 +- .../components/report-panel/report-panel.tsx | 3 +- .../settings/agent-model-manager.tsx | 10 +- .../components/settings/roles-manager.tsx | 32 +++- .../settings/user-settings-dialog.tsx | 17 +- .../tool-ui/generic-hitl-approval.tsx | 4 +- .../tool-ui/google-calendar/create-event.tsx | 9 +- surfsense_web/contracts/enums/toolIcons.tsx | 2 +- surfsense_web/types/window.d.ts | 15 +- 74 files changed, 1135 insertions(+), 693 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 73a39ccbf..ddf87cf2a 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -50,7 +50,10 @@ from app.agents.new_chat.system_prompt import ( build_configurable_system_prompt, build_surfsense_system_prompt, ) -from app.agents.new_chat.tools.registry import build_tools_async, get_connector_gated_tools +from app.agents.new_chat.tools.registry import ( + build_tools_async, + get_connector_gated_tools, +) from app.db import ChatVisibility from app.services.connector_service import ConnectorService from app.utils.perf import get_perf_logger @@ -294,9 +297,7 @@ async def create_surfsense_deep_agent( } modified_disabled_tools = list(disabled_tools) if disabled_tools else [] - modified_disabled_tools.extend( - get_connector_gated_tools(available_connectors) - ) + modified_disabled_tools.extend(get_connector_gated_tools(available_connectors)) # Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware. if "search_knowledge_base" not in modified_disabled_tools: @@ -328,7 +329,8 @@ async def create_surfsense_deep_agent( meta = getattr(t, "metadata", None) or {} if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"): _mcp_connector_tools.setdefault( - meta["mcp_connector_name"], [], + meta["mcp_connector_name"], + [], ).append(t.name) if _mcp_connector_tools: diff --git a/surfsense_backend/app/agents/new_chat/middleware/__init__.py b/surfsense_backend/app/agents/new_chat/middleware/__init__.py index 5a24b2f9e..6e4542e1a 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/__init__.py +++ b/surfsense_backend/app/agents/new_chat/middleware/__init__.py @@ -3,12 +3,12 @@ from app.agents.new_chat.middleware.dedup_tool_calls import ( DedupHITLToolCallsMiddleware, ) -from app.agents.new_chat.middleware.filesystem import ( - SurfSenseFilesystemMiddleware, -) from app.agents.new_chat.middleware.file_intent import ( FileIntentMiddleware, ) +from app.agents.new_chat.middleware.filesystem import ( + SurfSenseFilesystemMiddleware, +) from app.agents.new_chat.middleware.knowledge_search import ( KnowledgeBaseSearchMiddleware, ) diff --git a/surfsense_backend/app/agents/new_chat/middleware/file_intent.py b/surfsense_backend/app/agents/new_chat/middleware/file_intent.py index 4bf5dcfe4..05cb230ce 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/file_intent.py +++ b/surfsense_backend/app/agents/new_chat/middleware/file_intent.py @@ -213,7 +213,9 @@ def _build_classifier_prompt(*, recent_conversation: str, user_text: str) -> str ) -def _build_recent_conversation(messages: list[BaseMessage], *, max_messages: int = 6) -> str: +def _build_recent_conversation( + messages: list[BaseMessage], *, max_messages: int = 6 +) -> str: rows: list[str] = [] for msg in messages[-max_messages:]: role = "user" if isinstance(msg, HumanMessage) else "assistant" @@ -246,7 +248,9 @@ class FileIntentMiddleware(AgentMiddleware): # type: ignore[type-arg] [HumanMessage(content=prompt)], config={"tags": ["surfsense:internal"]}, ) - payload = json.loads(_extract_json_payload(_extract_text_from_message(response))) + payload = json.loads( + _extract_json_payload(_extract_text_from_message(response)) + ) plan = FileIntentPlan.model_validate(payload) return plan except (json.JSONDecodeError, ValidationError, ValueError) as exc: @@ -317,4 +321,3 @@ class FileIntentMiddleware(AgentMiddleware): # type: ignore[type-arg] insert_at = max(len(new_messages) - 1, 0) new_messages.insert(insert_at, contract_msg) return {"messages": new_messages, "file_operation_contract": contract} - diff --git a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py index 8dfa89ef2..cb50693f1 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py +++ b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py @@ -877,7 +877,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): suggested_path = contract.get("suggested_path") if isinstance(suggested_path, str) and suggested_path.strip(): normalized_suggested = self._normalize_absolute_path(suggested_path) - suggested_mount = self._extract_mount_from_path(normalized_suggested, mounts) + suggested_mount = self._extract_mount_from_path( + normalized_suggested, mounts + ) matching_mounts = [ mount @@ -1071,14 +1073,18 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): ] = False, ) -> Command | str: if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: - return "Error: move_file is only available in desktop local-folder mode." + return ( + "Error: move_file is only available in desktop local-folder mode." + ) if not source_path.strip() or not destination_path.strip(): return "Error: source_path and destination_path are required." resolved_backend = self._get_backend(runtime) source_target = self._resolve_move_target_path(source_path, runtime) - destination_target = self._resolve_move_target_path(destination_path, runtime) + destination_target = self._resolve_move_target_path( + destination_path, runtime + ) try: validated_source = validate_path(source_target) validated_destination = validate_path(destination_target) @@ -1106,7 +1112,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): ], } ) - return f"Moved '{validated_source}' to '{res.path or validated_destination}'" + return ( + f"Moved '{validated_source}' to '{res.path or validated_destination}'" + ) async def async_move_file( source_path: Annotated[ @@ -1125,14 +1133,18 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): ] = False, ) -> Command | str: if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: - return "Error: move_file is only available in desktop local-folder mode." + return ( + "Error: move_file is only available in desktop local-folder mode." + ) if not source_path.strip() or not destination_path.strip(): return "Error: source_path and destination_path are required." resolved_backend = self._get_backend(runtime) source_target = self._resolve_move_target_path(source_path, runtime) - destination_target = self._resolve_move_target_path(destination_path, runtime) + destination_target = self._resolve_move_target_path( + destination_path, runtime + ) try: validated_source = validate_path(source_target) validated_destination = validate_path(destination_target) @@ -1160,7 +1172,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): ], } ) - return f"Moved '{validated_source}' to '{res.path or validated_destination}'" + return ( + f"Moved '{validated_source}' to '{res.path or validated_destination}'" + ) return StructuredTool.from_function( name="move_file", @@ -1201,7 +1215,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): ] = True, ) -> str: if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: - return "Error: list_tree is only available in desktop local-folder mode." + return ( + "Error: list_tree is only available in desktop local-folder mode." + ) if max_depth < 0: return "Error: max_depth must be >= 0." if page_size < 1: @@ -1253,7 +1269,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): ] = True, ) -> str: if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: - return "Error: list_tree is only available in desktop local-folder mode." + return ( + "Error: list_tree is only available in desktop local-folder mode." + ) if max_depth < 0: return "Error: max_depth must be >= 0." if page_size < 1: diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py index 51378a013..6df317aaa 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py @@ -27,8 +27,8 @@ from pydantic import BaseModel, Field, ValidationError from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range from app.db import ( NATIVE_TO_LEGACY_DOCTYPE, Chunk, diff --git a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py index 0cee3e007..565fcb48b 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py @@ -120,7 +120,9 @@ class LocalFolderBackend: if not target.exists() or not target.is_dir(): return [] infos: list[FileInfo] = [] - for child in sorted(target.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())): + for child in sorted( + target.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower()) + ): infos.append( FileInfo( path=self._to_virtual(child, self._root), @@ -317,7 +319,9 @@ class LocalFolderBackend: return WriteResult(error="Error: source and destination paths are the same") with self._acquire_path_locks(source_path, destination_path): if not source.exists(): - return WriteResult(error=f"Error: source path '{source_path}' not found") + return WriteResult( + error=f"Error: source path '{source_path}' not found" + ) if destination.exists(): if not overwrite: return WriteResult( @@ -339,8 +343,12 @@ class LocalFolderBackend: else: source.rename(destination) except OSError as exc: - return WriteResult(error=f"Error: failed to move '{source_path}': {exc}") - return WriteResult(path=self._to_virtual(destination, self._root), files_update=None) + return WriteResult( + error=f"Error: failed to move '{source_path}': {exc}" + ) + return WriteResult( + path=self._to_virtual(destination, self._root), files_update=None + ) async def amove( self, @@ -368,12 +376,16 @@ class LocalFolderBackend: if not path.exists() or not path.is_file(): return EditResult(error=f"Error: File '{file_path}' not found") content = path.read_text(encoding="utf-8", errors="replace") - result = perform_string_replacement(content, old_string, new_string, replace_all) + result = perform_string_replacement( + content, old_string, new_string, replace_all + ) if isinstance(result, str): return EditResult(error=result) updated_content, occurrences = result self._write_text_atomic(path, updated_content) - return EditResult(path=file_path, files_update=None, occurrences=int(occurrences)) + return EditResult( + path=file_path, files_update=None, occurrences=int(occurrences) + ) async def aedit( self, @@ -447,7 +459,9 @@ class LocalFolderBackend: matches: list[GrepMatch] = [] for file_path in self._iter_candidate_files(path, glob): try: - lines = file_path.read_text(encoding="utf-8", errors="replace").splitlines() + lines = file_path.read_text( + encoding="utf-8", errors="replace" + ).splitlines() except Exception: continue for idx, line in enumerate(lines, start=1): @@ -481,12 +495,18 @@ class LocalFolderBackend: FileUploadResponse(path=virtual_path, error=_FILE_NOT_FOUND) ) except IsADirectoryError: - responses.append(FileUploadResponse(path=virtual_path, error=_IS_DIRECTORY)) + responses.append( + FileUploadResponse(path=virtual_path, error=_IS_DIRECTORY) + ) except Exception: - responses.append(FileUploadResponse(path=virtual_path, error=_INVALID_PATH)) + responses.append( + FileUploadResponse(path=virtual_path, error=_INVALID_PATH) + ) return responses - async def aupload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]: + async def aupload_files( + self, files: list[tuple[str, bytes]] + ) -> list[FileUploadResponse]: return await asyncio.to_thread(self.upload_files, files) def download_files(self, paths: list[str]) -> list[FileDownloadResponse]: @@ -515,7 +535,9 @@ class LocalFolderBackend: ) except Exception: responses.append( - FileDownloadResponse(path=virtual_path, content=None, error=_INVALID_PATH) + FileDownloadResponse( + path=virtual_path, content=None, error=_INVALID_PATH + ) ) return responses diff --git a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py index 82914f9ce..93eabe6ff 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py @@ -127,7 +127,9 @@ class MultiRootLocalFolderBackend: mount, local_path = self._split_mount_path(path) except ValueError: return [] - return self._transform_infos(mount, self._mount_to_backend[mount].ls_info(local_path)) + return self._transform_infos( + mount, self._mount_to_backend[mount].ls_info(local_path) + ) async def als_info(self, path: str) -> list[FileInfo]: return await asyncio.to_thread(self.ls_info, path) @@ -355,7 +357,9 @@ class MultiRootLocalFolderBackend: all_matches.extend( [ GrepMatch( - path=self._prefix_mount_path(mount, self._get_str(match, "path")), + path=self._prefix_mount_path( + mount, self._get_str(match, "path") + ), line=self._get_int(match, "line"), text=self._get_str(match, "text"), ) @@ -394,7 +398,9 @@ class MultiRootLocalFolderBackend: try: mount, local_path = self._split_mount_path(virtual_path) except ValueError: - invalid.append(FileUploadResponse(path=virtual_path, error=_INVALID_PATH)) + invalid.append( + FileUploadResponse(path=virtual_path, error=_INVALID_PATH) + ) continue grouped.setdefault(mount, []).append((local_path, content)) @@ -404,7 +410,9 @@ class MultiRootLocalFolderBackend: responses.extend( [ FileUploadResponse( - path=self._prefix_mount_path(mount, self._get_str(item, "path")), + path=self._prefix_mount_path( + mount, self._get_str(item, "path") + ), error=self._get_str(item, "error") or None, ) for item in result @@ -412,7 +420,9 @@ class MultiRootLocalFolderBackend: ) return responses - async def aupload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]: + async def aupload_files( + self, files: list[tuple[str, bytes]] + ) -> list[FileUploadResponse]: return await asyncio.to_thread(self.upload_files, files) def download_files(self, paths: list[str]) -> list[FileDownloadResponse]: @@ -423,7 +433,9 @@ class MultiRootLocalFolderBackend: mount, local_path = self._split_mount_path(virtual_path) except ValueError: invalid.append( - FileDownloadResponse(path=virtual_path, content=None, error=_INVALID_PATH) + FileDownloadResponse( + path=virtual_path, content=None, error=_INVALID_PATH + ) ) continue grouped.setdefault(mount, []).append(local_path) @@ -434,7 +446,9 @@ class MultiRootLocalFolderBackend: responses.extend( [ FileDownloadResponse( - path=self._prefix_mount_path(mount, self._get_str(item, "path")), + path=self._prefix_mount_path( + mount, self._get_str(item, "path") + ), content=self._get_value(item, "content"), error=self._get_str(item, "error") or None, ) diff --git a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py index e0b1978e1..5675a42e6 100644 --- a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py +++ b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py @@ -57,7 +57,11 @@ def create_get_connected_accounts_tool( async def _run(service: str) -> list[dict[str, Any]]: svc_cfg = MCP_SERVICES.get(service) if not svc_cfg: - return [{"error": f"Unknown service '{service}'. Valid: {', '.join(sorted(MCP_SERVICES.keys()))}"}] + return [ + { + "error": f"Unknown service '{service}'. Valid: {', '.join(sorted(MCP_SERVICES.keys()))}" + } + ] try: connector_type = SearchSourceConnectorType(svc_cfg.connector_type) @@ -74,7 +78,11 @@ def create_get_connected_accounts_tool( connectors = result.scalars().all() if not connectors: - return [{"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."}] + return [ + { + "error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings." + } + ] is_multi = len(connectors) > 1 diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py b/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py index 1f51e3660..c345f8a5e 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py @@ -19,7 +19,8 @@ async def get_discord_connector( select(SearchSourceConnector).filter( SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.DISCORD_CONNECTOR, ) ) return result.scalars().first() diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py index a33b88aa0..3cc99ac17 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py @@ -23,16 +23,24 @@ def create_list_discord_channels_tool( Dictionary with status and a list of channels (id, name). """ if db_session is None or search_space_id is None or user_id is None: - return {"status": "error", "message": "Discord tool not properly configured."} + return { + "status": "error", + "message": "Discord tool not properly configured.", + } try: - connector = await get_discord_connector(db_session, search_space_id, user_id) + connector = await get_discord_connector( + db_session, search_space_id, user_id + ) if not connector: return {"status": "error", "message": "No Discord connector found."} guild_id = get_guild_id(connector) if not guild_id: - return {"status": "error", "message": "No guild ID in Discord connector config."} + return { + "status": "error", + "message": "No guild ID in Discord connector config.", + } token = get_bot_token(connector) @@ -44,9 +52,16 @@ def create_list_discord_channels_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"} + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } if resp.status_code != 200: - return {"status": "error", "message": f"Discord API error: {resp.status_code}"} + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } # Type 0 = text channel channels = [ @@ -54,7 +69,12 @@ def create_list_discord_channels_tool( for ch in resp.json() if ch.get("type") == 0 ] - return {"status": "success", "guild_id": guild_id, "channels": channels, "total": len(channels)} + return { + "status": "success", + "guild_id": guild_id, + "channels": channels, + "total": len(channels), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py index 852a9297b..d8bf989a1 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py @@ -31,12 +31,17 @@ def create_read_discord_messages_tool( id, author, content, timestamp. """ if db_session is None or search_space_id is None or user_id is None: - return {"status": "error", "message": "Discord tool not properly configured."} + return { + "status": "error", + "message": "Discord tool not properly configured.", + } limit = min(limit, 50) try: - connector = await get_discord_connector(db_session, search_space_id, user_id) + connector = await get_discord_connector( + db_session, search_space_id, user_id + ) if not connector: return {"status": "error", "message": "No Discord connector found."} @@ -51,11 +56,21 @@ def create_read_discord_messages_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"} + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } if resp.status_code == 403: - return {"status": "error", "message": "Bot lacks permission to read this channel."} + return { + "status": "error", + "message": "Bot lacks permission to read this channel.", + } if resp.status_code != 200: - return {"status": "error", "message": f"Discord API error: {resp.status_code}"} + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } messages = [ { @@ -67,7 +82,12 @@ def create_read_discord_messages_tool( for m in resp.json() ] - return {"status": "success", "channel_id": channel_id, "messages": messages, "total": len(messages)} + return { + "status": "success", + "channel_id": channel_id, + "messages": messages, + "total": len(messages), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py index be4e6fdb2..236cd017a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py @@ -35,13 +35,21 @@ def create_send_discord_message_tool( - If status is "rejected", the user explicitly declined. Do NOT retry. """ if db_session is None or search_space_id is None or user_id is None: - return {"status": "error", "message": "Discord tool not properly configured."} + return { + "status": "error", + "message": "Discord tool not properly configured.", + } if len(content) > 2000: - return {"status": "error", "message": "Message exceeds Discord's 2000-character limit."} + return { + "status": "error", + "message": "Message exceeds Discord's 2000-character limit.", + } try: - connector = await get_discord_connector(db_session, search_space_id, user_id) + connector = await get_discord_connector( + db_session, search_space_id, user_id + ) if not connector: return {"status": "error", "message": "No Discord connector found."} @@ -53,7 +61,10 @@ def create_send_discord_message_tool( ) if result.rejected: - return {"status": "rejected", "message": "User declined. Message was not sent."} + return { + "status": "rejected", + "message": "User declined. Message was not sent.", + } final_content = result.params.get("content", content) final_channel = result.params.get("channel_id", channel_id) @@ -72,11 +83,21 @@ def create_send_discord_message_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"} + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } if resp.status_code == 403: - return {"status": "error", "message": "Bot lacks permission to send messages in this channel."} + return { + "status": "error", + "message": "Bot lacks permission to send messages in this channel.", + } if resp.status_code not in (200, 201): - return {"status": "error", "message": f"Discord API error: {resp.status_code}"} + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } msg_data = resp.json() return { diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py index 9071f129a..deec1627c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py @@ -65,12 +65,22 @@ def create_read_gmail_email_tool( detail, error = await gmail.get_message_details(message_id) if error: - if "re-authenticate" in error.lower() or "authentication failed" in error.lower(): - return {"status": "auth_error", "message": error, "connector_type": "gmail"} + if ( + "re-authenticate" in error.lower() + or "authentication failed" in error.lower() + ): + return { + "status": "auth_error", + "message": error, + "connector_type": "gmail", + } return {"status": "error", "message": error} if not detail: - return {"status": "not_found", "message": f"Email with ID '{message_id}' not found."} + return { + "status": "not_found", + "message": f"Email with ID '{message_id}' not found.", + } content = gmail.format_message_to_markdown(detail) @@ -82,6 +92,9 @@ def create_read_gmail_email_tool( if isinstance(e, GraphInterrupt): raise logger.error("Error reading Gmail email: %s", e, exc_info=True) - return {"status": "error", "message": "Failed to read email. Please try again."} + return { + "status": "error", + "message": "Failed to read email. Please try again.", + } return read_gmail_email diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py index de43f03d0..2e363609e 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py @@ -125,12 +125,24 @@ def create_search_gmail_tool( max_results=max_results, query=query ) if error: - if "re-authenticate" in error.lower() or "authentication failed" in error.lower(): - return {"status": "auth_error", "message": error, "connector_type": "gmail"} + if ( + "re-authenticate" in error.lower() + or "authentication failed" in error.lower() + ): + return { + "status": "auth_error", + "message": error, + "connector_type": "gmail", + } return {"status": "error", "message": error} if not messages_list: - return {"status": "success", "emails": [], "total": 0, "message": "No emails found."} + return { + "status": "success", + "emails": [], + "total": 0, + "message": "No emails found.", + } emails = [] for msg in messages_list: @@ -141,16 +153,18 @@ def create_search_gmail_tool( h["name"].lower(): h["value"] for h in detail.get("payload", {}).get("headers", []) } - emails.append({ - "message_id": detail.get("id"), - "thread_id": detail.get("threadId"), - "subject": headers.get("subject", "No Subject"), - "from": headers.get("from", "Unknown"), - "to": headers.get("to", ""), - "date": headers.get("date", ""), - "snippet": detail.get("snippet", ""), - "labels": detail.get("labelIds", []), - }) + emails.append( + { + "message_id": detail.get("id"), + "thread_id": detail.get("threadId"), + "subject": headers.get("subject", "No Subject"), + "from": headers.get("from", "Unknown"), + "to": headers.get("to", ""), + "date": headers.get("date", ""), + "snippet": detail.get("snippet", ""), + "labels": detail.get("labelIds", []), + } + ) return {"status": "success", "emails": emails, "total": len(emails)} @@ -160,6 +174,9 @@ def create_search_gmail_tool( if isinstance(e, GraphInterrupt): raise logger.error("Error searching Gmail: %s", e, exc_info=True) - return {"status": "error", "message": "Failed to search Gmail. Please try again."} + return { + "status": "error", + "message": "Failed to search Gmail. Please try again.", + } return search_gmail diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py index a622b0efa..dc6adb822 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py @@ -39,7 +39,10 @@ def create_search_calendar_events_tool( event_id, summary, start, end, location, attendees. """ if db_session is None or search_space_id is None or user_id is None: - return {"status": "error", "message": "Calendar tool not properly configured."} + return { + "status": "error", + "message": "Calendar tool not properly configured.", + } max_results = min(max_results, 50) @@ -76,10 +79,22 @@ def create_search_calendar_events_tool( ) if error: - if "re-authenticate" in error.lower() or "authentication failed" in error.lower(): - return {"status": "auth_error", "message": error, "connector_type": "google_calendar"} + if ( + "re-authenticate" in error.lower() + or "authentication failed" in error.lower() + ): + return { + "status": "auth_error", + "message": error, + "connector_type": "google_calendar", + } if "no events found" in error.lower(): - return {"status": "success", "events": [], "total": 0, "message": error} + return { + "status": "success", + "events": [], + "total": 0, + "message": error, + } return {"status": "error", "message": error} events = [] @@ -87,19 +102,19 @@ def create_search_calendar_events_tool( start = ev.get("start", {}) end = ev.get("end", {}) attendees_raw = ev.get("attendees", []) - events.append({ - "event_id": ev.get("id"), - "summary": ev.get("summary", "No Title"), - "start": start.get("dateTime") or start.get("date", ""), - "end": end.get("dateTime") or end.get("date", ""), - "location": ev.get("location", ""), - "description": ev.get("description", ""), - "html_link": ev.get("htmlLink", ""), - "attendees": [ - a.get("email", "") for a in attendees_raw[:10] - ], - "status": ev.get("status", ""), - }) + events.append( + { + "event_id": ev.get("id"), + "summary": ev.get("summary", "No Title"), + "start": start.get("dateTime") or start.get("date", ""), + "end": end.get("dateTime") or end.get("date", ""), + "location": ev.get("location", ""), + "description": ev.get("description", ""), + "html_link": ev.get("htmlLink", ""), + "attendees": [a.get("email", "") for a in attendees_raw[:10]], + "status": ev.get("status", ""), + } + ) return {"status": "success", "events": events, "total": len(events)} @@ -109,6 +124,9 @@ def create_search_calendar_events_tool( if isinstance(e, GraphInterrupt): raise logger.error("Error searching calendar events: %s", e, exc_info=True) - return {"status": "error", "message": "Failed to search calendar events. Please try again."} + return { + "status": "error", + "message": "Failed to search calendar events. Please try again.", + } return search_calendar_events diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/new_chat/tools/hitl.py index 89f02abf6..8480e57b1 100644 --- a/surfsense_backend/app/agents/new_chat/tools/hitl.py +++ b/surfsense_backend/app/agents/new_chat/tools/hitl.py @@ -130,7 +130,9 @@ def request_approval( try: decision_type, edited_params = _parse_decision(approval) except ValueError: - logger.warning("No approval decision received for %s — rejecting for safety", tool_name) + logger.warning( + "No approval decision received for %s — rejecting for safety", tool_name + ) return HITLResult(rejected=True, decision_type="error", params=params) logger.info("User decision for %s: %s", tool_name, decision_type) diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py b/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py index 1d88161d6..37deb1525 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py @@ -17,7 +17,8 @@ async def get_luma_connector( select(SearchSourceConnector).filter( SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LUMA_CONNECTOR, ) ) return result.scalars().first() diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py index 2217d29e6..0a24a988f 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py @@ -62,7 +62,10 @@ def create_create_luma_event_tool( ) if result.rejected: - return {"status": "rejected", "message": "User declined. Event was not created."} + return { + "status": "rejected", + "message": "User declined. Event was not created.", + } final_name = result.params.get("name", name) final_start = result.params.get("start_at", start_at) @@ -90,11 +93,21 @@ def create_create_luma_event_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"} + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } if resp.status_code == 403: - return {"status": "error", "message": "Luma Plus subscription required to create events via API."} + return { + "status": "error", + "message": "Luma Plus subscription required to create events via API.", + } if resp.status_code not in (200, 201): - return {"status": "error", "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}"} + return { + "status": "error", + "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}", + } data = resp.json() event_id = data.get("api_id") or data.get("event", {}).get("api_id") diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py index cd4721758..aec5ad220 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py @@ -46,7 +46,9 @@ def create_list_luma_events_tool( async with httpx.AsyncClient(timeout=20.0) as client: while len(all_entries) < max_results: - params: dict[str, Any] = {"limit": min(100, max_results - len(all_entries))} + params: dict[str, Any] = { + "limit": min(100, max_results - len(all_entries)) + } if cursor: params["cursor"] = cursor @@ -57,9 +59,16 @@ def create_list_luma_events_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"} + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } if resp.status_code != 200: - return {"status": "error", "message": f"Luma API error: {resp.status_code}"} + return { + "status": "error", + "message": f"Luma API error: {resp.status_code}", + } data = resp.json() entries = data.get("entries", []) @@ -76,16 +85,18 @@ def create_list_luma_events_tool( for entry in all_entries[:max_results]: ev = entry.get("event", {}) geo = ev.get("geo_info", {}) - events.append({ - "event_id": entry.get("api_id"), - "name": ev.get("name", "Untitled"), - "start_at": ev.get("start_at", ""), - "end_at": ev.get("end_at", ""), - "timezone": ev.get("timezone", ""), - "location": geo.get("name", ""), - "url": ev.get("url", ""), - "visibility": ev.get("visibility", ""), - }) + events.append( + { + "event_id": entry.get("api_id"), + "name": ev.get("name", "Untitled"), + "start_at": ev.get("start_at", ""), + "end_at": ev.get("end_at", ""), + "timezone": ev.get("timezone", ""), + "location": geo.get("name", ""), + "url": ev.get("url", ""), + "visibility": ev.get("visibility", ""), + } + ) return {"status": "success", "events": events, "total": len(events)} diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py index eb3ac55c6..b37a9d617 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py @@ -44,11 +44,21 @@ def create_read_luma_event_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"} + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } if resp.status_code == 404: - return {"status": "not_found", "message": f"Event '{event_id}' not found."} + return { + "status": "not_found", + "message": f"Event '{event_id}' not found.", + } if resp.status_code != 200: - return {"status": "error", "message": f"Luma API error: {resp.status_code}"} + return { + "status": "error", + "message": f"Luma API error: {resp.status_code}", + } data = resp.json() ev = data.get("event", data) diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py index b46ddbcc5..e28ac8bda 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py @@ -220,10 +220,8 @@ class MCPClient: logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200]) return result_str - except asyncio.TimeoutError: - logger.error( - "MCP tool '%s' timed out after %.0fs", tool_name, timeout - ) + except TimeoutError: + logger.error("MCP tool '%s' timed out after %.0fs", tool_name, timeout) return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s" except RuntimeError as e: if "Invalid structured content" in str(e): diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index dfee24516..5b96ab374 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -35,7 +35,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.mcp_client import MCPClient -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type logger = logging.getLogger(__name__) @@ -105,13 +105,15 @@ def _create_dynamic_input_model_from_schema( description=( "Arguments to pass to this tool as a JSON object. " "Infer sensible key names from the tool name and description " - "(e.g. {\"search\": \"my query\"} for a search tool)." + '(e.g. {"search": "my query"} for a search tool).' ), ), ) model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input" - model = create_model(model_name, __config__=ConfigDict(extra="allow"), **field_definitions) + model = create_model( + model_name, __config__=ConfigDict(extra="allow"), **field_definitions + ) return model @@ -187,16 +189,23 @@ async def _create_mcp_tool_from_definition_stdio( except Exception as e: last_error = e if attempt < _TOOL_CALL_MAX_RETRIES - 1: - delay = _TOOL_CALL_RETRY_DELAY * (2 ** attempt) + delay = _TOOL_CALL_RETRY_DELAY * (2**attempt) logger.warning( "MCP tool '%s' failed (attempt %d/%d): %s. Retrying in %.1fs...", - tool_name, attempt + 1, _TOOL_CALL_MAX_RETRIES, e, delay, + tool_name, + attempt + 1, + _TOOL_CALL_MAX_RETRIES, + e, + delay, ) await asyncio.sleep(delay) else: logger.error( "MCP tool '%s' failed after %d attempts: %s", - tool_name, _TOOL_CALL_MAX_RETRIES, e, exc_info=True, + tool_name, + _TOOL_CALL_MAX_RETRIES, + e, + exc_info=True, ) return f"Error: MCP tool '{tool_name}' failed after {_TOOL_CALL_MAX_RETRIES} attempts: {last_error!s}" @@ -318,17 +327,22 @@ async def _create_mcp_tool_from_definition_http( try: result_str = await _do_mcp_call(headers, call_kwargs) - logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str)) + logger.debug( + "MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str) + ) return result_str except Exception as first_err: if not _is_auth_error(first_err) or connector_id is None: - logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, first_err) + logger.exception( + "MCP HTTP tool '%s' execution failed: %s", exposed_name, first_err + ) return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {first_err!s}" logger.warning( "MCP HTTP tool '%s' got 401 — attempting token refresh for connector %s", - exposed_name, connector_id, + exposed_name, + connector_id, ) fresh_headers = await _force_refresh_and_get_headers(connector_id) if fresh_headers is None: @@ -348,7 +362,8 @@ async def _create_mcp_tool_from_definition_http( except Exception as retry_err: logger.exception( "MCP HTTP tool '%s' still failing after token refresh: %s", - exposed_name, retry_err, + exposed_name, + retry_err, ) if _is_auth_error(retry_err): await _mark_connector_auth_expired(connector_id) @@ -393,7 +408,8 @@ async def _load_stdio_mcp_tools( if not command or not isinstance(command, str): logger.warning( "MCP connector %d (name: '%s') missing or invalid command field, skipping", - connector_id, connector_name, + connector_id, + connector_name, ) return tools @@ -401,7 +417,8 @@ async def _load_stdio_mcp_tools( if not isinstance(args, list): logger.warning( "MCP connector %d (name: '%s') has invalid args field (must be list), skipping", - connector_id, connector_name, + connector_id, + connector_name, ) return tools @@ -409,7 +426,8 @@ async def _load_stdio_mcp_tools( if not isinstance(env, dict): logger.warning( "MCP connector %d (name: '%s') has invalid env field (must be dict), skipping", - connector_id, connector_name, + connector_id, + connector_name, ) return tools @@ -420,7 +438,9 @@ async def _load_stdio_mcp_tools( logger.info( "Discovered %d tools from stdio MCP server '%s' (connector %d)", - len(tool_definitions), command, connector_id, + len(tool_definitions), + command, + connector_id, ) for tool_def in tool_definitions: @@ -436,7 +456,9 @@ async def _load_stdio_mcp_tools( except Exception as e: logger.exception( "Failed to create tool '%s' from connector %d: %s", - tool_def.get("name"), connector_id, e, + tool_def.get("name"), + connector_id, + e, ) return tools @@ -468,7 +490,8 @@ async def _load_http_mcp_tools( if not url or not isinstance(url, str): logger.warning( "MCP connector %d (name: '%s') missing or invalid url field, skipping", - connector_id, connector_name, + connector_id, + connector_name, ) return tools @@ -476,7 +499,8 @@ async def _load_http_mcp_tools( if not isinstance(headers, dict): logger.warning( "MCP connector %d (name: '%s') has invalid headers field (must be dict), skipping", - connector_id, connector_name, + connector_id, + connector_name, ) return tools @@ -507,7 +531,9 @@ async def _load_http_mcp_tools( if not _is_auth_error(first_err) or connector_id is None: logger.exception( "Failed to connect to HTTP MCP server at '%s' (connector %d): %s", - url, connector_id, first_err, + url, + connector_id, + first_err, ) return tools @@ -534,7 +560,8 @@ async def _load_http_mcp_tools( except Exception as retry_err: logger.exception( "HTTP MCP discovery for connector %d still failing after refresh: %s", - connector_id, retry_err, + connector_id, + retry_err, ) if _is_auth_error(retry_err): await _mark_connector_auth_expired(connector_id) @@ -543,17 +570,20 @@ async def _load_http_mcp_tools( total_discovered = len(tool_definitions) if allowed_set: - tool_definitions = [ - td for td in tool_definitions if td["name"] in allowed_set - ] + tool_definitions = [td for td in tool_definitions if td["name"] in allowed_set] logger.info( "HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter", - url, connector_id, len(tool_definitions), total_discovered, + url, + connector_id, + len(tool_definitions), + total_discovered, ) else: logger.info( "Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all", - total_discovered, url, connector_id, + total_discovered, + url, + connector_id, ) for tool_def in tool_definitions: @@ -573,7 +603,9 @@ async def _load_http_mcp_tools( except Exception as e: logger.exception( "Failed to create HTTP tool '%s' from connector %d: %s", - tool_def.get("name"), connector_id, e, + tool_def.get("name"), + connector_id, + e, ) return tools @@ -628,7 +660,7 @@ def _inject_oauth_headers( async def _refresh_connector_token( session: AsyncSession, - connector: "SearchSourceConnector", + connector: SearchSourceConnector, ) -> str | None: """Refresh the OAuth token for an MCP connector and persist the result. @@ -692,12 +724,8 @@ async def _refresh_connector_token( updated_oauth = dict(mcp_oauth) updated_oauth["access_token"] = enc.encrypt_token(new_access) if token_json.get("refresh_token"): - updated_oauth["refresh_token"] = enc.encrypt_token( - token_json["refresh_token"] - ) - updated_oauth["expires_at"] = ( - new_expires_at.isoformat() if new_expires_at else None - ) + updated_oauth["refresh_token"] = enc.encrypt_token(token_json["refresh_token"]) + updated_oauth["expires_at"] = new_expires_at.isoformat() if new_expires_at else None updated_cfg = {**cfg, "mcp_oauth": updated_oauth} updated_cfg.pop("auth_expired", None) @@ -713,7 +741,7 @@ async def _refresh_connector_token( async def _maybe_refresh_mcp_oauth_token( session: AsyncSession, - connector: "SearchSourceConnector", + connector: SearchSourceConnector, cfg: dict[str, Any], server_config: dict[str, Any], ) -> dict[str, Any]: @@ -731,10 +759,11 @@ async def _maybe_refresh_mcp_oauth_token( try: expires_at = datetime.fromisoformat(expires_at_str) if expires_at.tzinfo is None: - from datetime import timezone - expires_at = expires_at.replace(tzinfo=timezone.utc) + expires_at = expires_at.replace(tzinfo=UTC) - if datetime.now(UTC) < expires_at - timedelta(seconds=_TOKEN_REFRESH_BUFFER_SECONDS): + if datetime.now(UTC) < expires_at - timedelta( + seconds=_TOKEN_REFRESH_BUFFER_SECONDS + ): return server_config except (ValueError, TypeError): return server_config @@ -744,7 +773,9 @@ async def _maybe_refresh_mcp_oauth_token( if not new_access: return server_config - logger.info("Proactively refreshed MCP OAuth token for connector %s", connector.id) + logger.info( + "Proactively refreshed MCP OAuth token for connector %s", connector.id + ) refreshed_config = dict(server_config) refreshed_config["headers"] = { @@ -920,7 +951,7 @@ async def load_mcp_tools( result = await session.execute( select(SearchSourceConnector).filter( SearchSourceConnector.search_space_id == search_space_id, - cast(SearchSourceConnector.config, JSONB).has_key("server_config"), # noqa: W601 + cast(SearchSourceConnector.config, JSONB).has_key("server_config"), ), ) @@ -956,13 +987,17 @@ async def load_mcp_tools( if not server_config or not isinstance(server_config, dict): logger.warning( "MCP connector %d (name: '%s') has invalid or missing server_config, skipping", - connector.id, connector.name, + connector.id, + connector.name, ) continue if cfg.get("mcp_oauth"): server_config = await _maybe_refresh_mcp_oauth_token( - session, connector, cfg, server_config, + session, + connector, + cfg, + server_config, ) cfg = connector.config or {} server_config = _inject_oauth_headers(cfg, server_config) @@ -995,22 +1030,25 @@ async def load_mcp_tools( if service_key: tool_name_prefix = f"{service_key}_{connector.id}" - discovery_tasks.append({ - "connector_id": connector.id, - "connector_name": connector.name, - "server_config": server_config, - "trusted_tools": trusted_tools, - "allowed_tools": allowed_tools, - "readonly_tools": readonly_tools, - "tool_name_prefix": tool_name_prefix, - "transport": server_config.get("transport", "stdio"), - "is_generic_mcp": svc_cfg is None, - }) + discovery_tasks.append( + { + "connector_id": connector.id, + "connector_name": connector.name, + "server_config": server_config, + "trusted_tools": trusted_tools, + "allowed_tools": allowed_tools, + "readonly_tools": readonly_tools, + "tool_name_prefix": tool_name_prefix, + "transport": server_config.get("transport", "stdio"), + "is_generic_mcp": svc_cfg is None, + } + ) except Exception as e: logger.exception( "Failed to prepare MCP connector %d: %s", - connector.id, e, + connector.id, + e, ) async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]: @@ -1039,23 +1077,23 @@ async def load_mcp_tools( ), timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, ) - except asyncio.TimeoutError: + except TimeoutError: logger.error( "MCP connector %d timed out after %ds during discovery", - task["connector_id"], _MCP_DISCOVERY_TIMEOUT_SECONDS, + task["connector_id"], + _MCP_DISCOVERY_TIMEOUT_SECONDS, ) return [] except Exception as e: logger.exception( "Failed to load tools from MCP connector %d: %s", - task["connector_id"], e, + task["connector_id"], + e, ) return [] results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks]) - tools: list[StructuredTool] = [ - tool for sublist in results for tool in sublist - ] + tools: list[StructuredTool] = [tool for sublist in results for tool in sublist] _mcp_tools_cache[search_space_id] = (now, tools) @@ -1063,7 +1101,9 @@ async def load_mcp_tools( oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0]) del _mcp_tools_cache[oldest_key] - logger.info("Loaded %d MCP tools for search space %d", len(tools), search_space_id) + logger.info( + "Loaded %d MCP tools for search space %d", len(tools), search_space_id + ) return tools except Exception as e: diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index 85c89b114..3ac8677b9 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -50,6 +50,7 @@ from .confluence import ( create_delete_confluence_page_tool, create_update_confluence_page_tool, ) +from .connected_accounts import create_get_connected_accounts_tool from .discord import ( create_list_discord_channels_tool, create_read_discord_messages_tool, @@ -78,7 +79,6 @@ from .google_drive import ( create_create_google_drive_file_tool, create_delete_google_drive_file_tool, ) -from .connected_accounts import create_get_connected_accounts_tool from .luma import ( create_create_luma_event_tool, create_list_luma_events_tool, @@ -675,10 +675,7 @@ def get_connector_gated_tools( available_connectors: list[str] | None, ) -> list[str]: """Return tool names to disable""" - if available_connectors is None: - available = set() - else: - available = set(available_connectors) + available = set() if available_connectors is None else set(available_connectors) disabled: list[str] = [] for tool_def in BUILTIN_TOOLS: @@ -829,14 +826,16 @@ async def build_tools_async( tools.extend(mcp_tools) logging.info( "Registered %d MCP tools: %s", - len(mcp_tools), [t.name for t in mcp_tools], + len(mcp_tools), + [t.name for t in mcp_tools], ) except Exception as e: logging.exception("Failed to load MCP tools: %s", e) logging.info( "Total tools for agent: %d — %s", - len(tools), [t.name for t in tools], + len(tools), + [t.name for t in tools], ) return tools diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py b/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py index f24f5502e..4345bb476 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py @@ -17,7 +17,8 @@ async def get_teams_connector( select(SearchSourceConnector).filter( SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.TEAMS_CONNECTOR, ) ) return result.scalars().first() diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py index a676595c1..d7b000853 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py @@ -35,12 +35,21 @@ def create_list_teams_channels_tool( headers = {"Authorization": f"Bearer {token}"} async with httpx.AsyncClient(timeout=20.0) as client: - teams_resp = await client.get(f"{GRAPH_API}/me/joinedTeams", headers=headers) + teams_resp = await client.get( + f"{GRAPH_API}/me/joinedTeams", headers=headers + ) if teams_resp.status_code == 401: - return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"} + return { + "status": "auth_error", + "message": "Teams token expired. Please re-authenticate.", + "connector_type": "teams", + } if teams_resp.status_code != 200: - return {"status": "error", "message": f"Graph API error: {teams_resp.status_code}"} + return { + "status": "error", + "message": f"Graph API error: {teams_resp.status_code}", + } teams_data = teams_resp.json().get("value", []) result_teams = [] @@ -58,13 +67,19 @@ def create_list_teams_channels_tool( {"id": ch["id"], "name": ch.get("displayName", "")} for ch in ch_resp.json().get("value", []) ] - result_teams.append({ - "team_id": team_id, - "team_name": team.get("displayName", ""), - "channels": channels, - }) + result_teams.append( + { + "team_id": team_id, + "team_name": team.get("displayName", ""), + "channels": channels, + } + ) - return {"status": "success", "teams": result_teams, "total_teams": len(result_teams)} + return { + "status": "success", + "teams": result_teams, + "total_teams": len(result_teams), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py index 90896cb95..d24a7e4d3 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py @@ -52,11 +52,21 @@ def create_read_teams_messages_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"} + return { + "status": "auth_error", + "message": "Teams token expired. Please re-authenticate.", + "connector_type": "teams", + } if resp.status_code == 403: - return {"status": "error", "message": "Insufficient permissions to read this channel."} + return { + "status": "error", + "message": "Insufficient permissions to read this channel.", + } if resp.status_code != 200: - return {"status": "error", "message": f"Graph API error: {resp.status_code}"} + return { + "status": "error", + "message": f"Graph API error: {resp.status_code}", + } raw_msgs = resp.json().get("value", []) messages = [] @@ -64,13 +74,15 @@ def create_read_teams_messages_tool( sender = m.get("from", {}) user_info = sender.get("user", {}) if sender else {} body = m.get("body", {}) - messages.append({ - "id": m.get("id"), - "sender": user_info.get("displayName", "Unknown"), - "content": body.get("content", ""), - "content_type": body.get("contentType", "text"), - "timestamp": m.get("createdDateTime", ""), - }) + messages.append( + { + "id": m.get("id"), + "sender": user_info.get("displayName", "Unknown"), + "content": body.get("content", ""), + "content_type": body.get("contentType", "text"), + "timestamp": m.get("createdDateTime", ""), + } + ) return { "status": "success", diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py index ba3a515d9..fd8d00870 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py @@ -50,12 +50,19 @@ def create_send_teams_message_tool( result = request_approval( action_type="teams_send_message", tool_name="send_teams_message", - params={"team_id": team_id, "channel_id": channel_id, "content": content}, + params={ + "team_id": team_id, + "channel_id": channel_id, + "content": content, + }, context={"connector_id": connector.id}, ) if result.rejected: - return {"status": "rejected", "message": "User declined. Message was not sent."} + return { + "status": "rejected", + "message": "User declined. Message was not sent.", + } final_content = result.params.get("content", content) final_team = result.params.get("team_id", team_id) @@ -74,20 +81,27 @@ def create_send_teams_message_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"} + return { + "status": "auth_error", + "message": "Teams token expired. Please re-authenticate.", + "connector_type": "teams", + } if resp.status_code == 403: return { "status": "insufficient_permissions", "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.", } if resp.status_code not in (200, 201): - return {"status": "error", "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}"} + return { + "status": "error", + "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}", + } msg_data = resp.json() return { "status": "success", "message_id": msg_data.get("id"), - "message": f"Message sent to Teams channel.", + "message": "Message sent to Teams channel.", } except Exception as e: diff --git a/surfsense_backend/app/agents/new_chat/tools/tool_response.py b/surfsense_backend/app/agents/new_chat/tools/tool_response.py index 5fb1864b7..8644ada5c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/tool_response.py +++ b/surfsense_backend/app/agents/new_chat/tools/tool_response.py @@ -6,7 +6,6 @@ from typing import Any class ToolResponse: - @staticmethod def success(message: str, **data: Any) -> dict[str, Any]: return {"status": "success", "message": message, **data} @@ -31,9 +30,7 @@ class ToolResponse: return {"status": "rejected", "message": message} @staticmethod - def not_found( - resource: str, identifier: str, **data: Any - ) -> dict[str, Any]: + def not_found(resource: str, identifier: str, **data: Any) -> dict[str, Any]: return { "status": "not_found", "error": f"{resource} '{identifier}' was not found.", diff --git a/surfsense_backend/app/connectors/exceptions.py b/surfsense_backend/app/connectors/exceptions.py index 32a1e7bdc..027adbb87 100644 --- a/surfsense_backend/app/connectors/exceptions.py +++ b/surfsense_backend/app/connectors/exceptions.py @@ -13,7 +13,6 @@ from typing import Any class ConnectorError(Exception): - def __init__( self, message: str, diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 8df930f30..de4e05423 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -98,7 +98,9 @@ router.include_router(logs_router) router.include_router(circleback_webhook_router) # Circleback meeting webhooks router.include_router(surfsense_docs_router) # Surfsense documentation for citations router.include_router(notifications_router) # Notifications with Zero sync -router.include_router(mcp_oauth_router) # MCP OAuth 2.1 for Linear, Jira, ClickUp, Slack, Airtable +router.include_router( + mcp_oauth_router +) # MCP OAuth 2.1 for Linear, Jira, ClickUp, Slack, Airtable router.include_router(composio_router) # Composio OAuth and toolkit management router.include_router(public_chat_router) # Public chat sharing and cloning router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py index e14be83d0..1abc1f1ec 100644 --- a/surfsense_backend/app/routes/mcp_oauth_route.py +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -29,7 +29,11 @@ from app.db import ( ) from app.users import current_active_user from app.utils.connector_naming import generate_unique_connector_name -from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_pkce_pair +from app.utils.oauth_security import ( + OAuthStateManager, + TokenEncryption, + generate_pkce_pair, +) logger = logging.getLogger(__name__) @@ -37,7 +41,9 @@ router = APIRouter() async def _fetch_account_metadata( - service_key: str, access_token: str, token_json: dict[str, Any], + service_key: str, + access_token: str, + token_json: dict[str, Any], ) -> dict[str, Any]: """Fetch display-friendly account metadata after a successful token exchange. @@ -86,7 +92,8 @@ async def _fetch_account_metadata( meta["display_name"] = whoami.get("email", "Airtable") else: logger.warning( - "Airtable whoami API returned %d (non-blocking)", resp.status_code, + "Airtable whoami API returned %d (non-blocking)", + resp.status_code, ) except Exception: @@ -98,6 +105,7 @@ async def _fetch_account_metadata( return meta + _state_manager: OAuthStateManager | None = None _token_encryption: TokenEncryption | None = None @@ -151,6 +159,7 @@ def _frontend_redirect( # /add — start MCP OAuth flow # --------------------------------------------------------------------------- + @router.get("/auth/mcp/{service}/connector/add") async def connect_mcp_service( service: str, @@ -170,9 +179,12 @@ async def connect_mcp_service( ) metadata = await discover_oauth_metadata( - svc.mcp_url, origin_override=svc.oauth_discovery_origin, + svc.mcp_url, + origin_override=svc.oauth_discovery_origin, + ) + auth_endpoint = svc.auth_endpoint_override or metadata.get( + "authorization_endpoint" ) - auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint") token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint") registration_endpoint = metadata.get("registration_endpoint") @@ -236,7 +248,9 @@ async def connect_mcp_service( logger.info( "Generated %s MCP OAuth URL for user %s, space %s", - svc.name, user.id, space_id, + svc.name, + user.id, + space_id, ) return {"auth_url": auth_url} @@ -245,7 +259,8 @@ async def connect_mcp_service( except Exception as e: logger.error("Failed to initiate %s MCP OAuth: %s", service, e, exc_info=True) raise HTTPException( - status_code=500, detail=f"Failed to initiate {service} MCP OAuth.", + status_code=500, + detail=f"Failed to initiate {service} MCP OAuth.", ) from e @@ -253,6 +268,7 @@ async def connect_mcp_service( # /callback — handle OAuth redirect # --------------------------------------------------------------------------- + @router.get("/auth/mcp/{service}/connector/callback") async def mcp_oauth_callback( service: str, @@ -271,7 +287,9 @@ async def mcp_oauth_callback( except Exception: pass return _frontend_redirect( - space_id, error=f"{service}_mcp_oauth_denied", service=service, + space_id, + error=f"{service}_mcp_oauth_denied", + service=service, ) if not code: @@ -337,9 +355,7 @@ async def mcp_oauth_callback( expires_at = None if expires_in: - expires_at = datetime.now(UTC) + timedelta( - seconds=int(expires_in) - ) + expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in)) connector_config = { "server_config": { @@ -349,10 +365,14 @@ async def mcp_oauth_callback( "mcp_service": svc_key, "mcp_oauth": { "client_id": client_id, - "client_secret": enc.encrypt_token(client_secret) if client_secret else "", + "client_secret": enc.encrypt_token(client_secret) + if client_secret + else "", "token_endpoint": token_endpoint, "access_token": enc.encrypt_token(access_token), - "refresh_token": enc.encrypt_token(refresh_token) if refresh_token else None, + "refresh_token": enc.encrypt_token(refresh_token) + if refresh_token + else None, "expires_at": expires_at.isoformat() if expires_at else None, "scope": scope, }, @@ -361,15 +381,27 @@ async def mcp_oauth_callback( account_meta = await _fetch_account_metadata(svc_key, access_token, token_json) if account_meta: - _SAFE_META_KEYS = {"display_name", "team_id", "team_name", "user_id", "user_email", - "workspace_id", "workspace_name", "organization_name", - "organization_url_key", "cloud_id", "site_name", "base_url"} + safe_meta_keys = { + "display_name", + "team_id", + "team_name", + "user_id", + "user_email", + "workspace_id", + "workspace_name", + "organization_name", + "organization_url_key", + "cloud_id", + "site_name", + "base_url", + } for k, v in account_meta.items(): - if k in _SAFE_META_KEYS: + if k in safe_meta_keys: connector_config[k] = v logger.info( "Stored account metadata for %s: display_name=%s", - svc_key, account_meta.get("display_name", ""), + svc_key, + account_meta.get("display_name", ""), ) # ---- Re-auth path ---- @@ -400,15 +432,24 @@ async def mcp_oauth_callback( logger.info( "Re-authenticated %s MCP connector %s for user %s", - svc.name, db_connector.id, user_id, + svc.name, + db_connector.id, + user_id, ) reauth_return_url = data.get("return_url") - if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"): + if ( + reauth_return_url + and reauth_return_url.startswith("/") + and not reauth_return_url.startswith("//") + ): return RedirectResponse( url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}" ) return _frontend_redirect( - space_id, success=True, connector_id=db_connector.id, service=service, + space_id, + success=True, + connector_id=db_connector.id, + service=service, ) # ---- New connector path ---- @@ -436,24 +477,34 @@ async def mcp_oauth_callback( except IntegrityError as e: await session.rollback() raise HTTPException( - status_code=409, detail="A connector for this service already exists.", + status_code=409, + detail="A connector for this service already exists.", ) from e _invalidate_cache(space_id) logger.info( "Created %s MCP connector %s for user %s in space %s", - svc.name, new_connector.id, user_id, space_id, + svc.name, + new_connector.id, + user_id, + space_id, ) return _frontend_redirect( - space_id, success=True, connector_id=new_connector.id, service=service, + space_id, + success=True, + connector_id=new_connector.id, + service=service, ) except HTTPException: raise except Exception as e: logger.error( - "Failed to complete %s MCP OAuth: %s", service, e, exc_info=True, + "Failed to complete %s MCP OAuth: %s", + service, + e, + exc_info=True, ) raise HTTPException( status_code=500, @@ -465,6 +516,7 @@ async def mcp_oauth_callback( # /reauth — re-authenticate an existing MCP connector # --------------------------------------------------------------------------- + @router.get("/auth/mcp/{service}/connector/reauth") async def reauth_mcp_service( service: str, @@ -491,7 +543,8 @@ async def reauth_mcp_service( ) if not result.scalars().first(): raise HTTPException( - status_code=404, detail="Connector not found or access denied", + status_code=404, + detail="Connector not found or access denied", ) try: @@ -501,9 +554,12 @@ async def reauth_mcp_service( ) metadata = await discover_oauth_metadata( - svc.mcp_url, origin_override=svc.oauth_discovery_origin, + svc.mcp_url, + origin_override=svc.oauth_discovery_origin, + ) + auth_endpoint = svc.auth_endpoint_override or metadata.get( + "authorization_endpoint" ) - auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint") token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint") registration_endpoint = metadata.get("registration_endpoint") @@ -545,7 +601,9 @@ async def reauth_mcp_service( "service": service, "code_verifier": verifier, "mcp_client_id": client_id, - "mcp_client_secret": enc.encrypt_token(client_secret) if client_secret else "", + "mcp_client_secret": enc.encrypt_token(client_secret) + if client_secret + else "", "mcp_token_endpoint": token_endpoint, "mcp_url": svc.mcp_url, "connector_id": connector_id, @@ -554,7 +612,9 @@ async def reauth_mcp_service( extra["return_url"] = return_url state = _get_state_manager().generate_secure_state( - space_id, user.id, **extra, + space_id, + user.id, + **extra, ) auth_params: dict[str, str] = { @@ -572,7 +632,9 @@ async def reauth_mcp_service( logger.info( "Initiating %s MCP re-auth for user %s, connector %s", - svc.name, user.id, connector_id, + svc.name, + user.id, + connector_id, ) return {"auth_url": auth_url} @@ -580,7 +642,10 @@ async def reauth_mcp_service( raise except Exception as e: logger.error( - "Failed to initiate %s MCP re-auth: %s", service, e, exc_info=True, + "Failed to initiate %s MCP re-auth: %s", + service, + e, + exc_info=True, ) raise HTTPException( status_code=500, @@ -592,6 +657,7 @@ async def reauth_mcp_service( # Helpers # --------------------------------------------------------------------------- + def _invalidate_cache(space_id: int) -> None: try: from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 85a8658ec..091e87737 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -24,9 +24,9 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.filesystem_selection import ( ClientPlatform, - LocalFilesystemMount, FilesystemMode, FilesystemSelection, + LocalFilesystemMount, ) from app.config import config from app.db import ( diff --git a/surfsense_backend/app/routes/oauth_connector_base.py b/surfsense_backend/app/routes/oauth_connector_base.py index 0638e8f34..5b75d8519 100644 --- a/surfsense_backend/app/routes/oauth_connector_base.py +++ b/surfsense_backend/app/routes/oauth_connector_base.py @@ -9,6 +9,7 @@ Call ``build_router()`` to get a FastAPI ``APIRouter`` with ``/connector/add``, from __future__ import annotations import base64 +import contextlib import logging from datetime import UTC, datetime, timedelta from typing import Any @@ -41,7 +42,6 @@ logger = logging.getLogger(__name__) class OAuthConnectorRoute: - def __init__( self, *, @@ -244,10 +244,8 @@ class OAuthConnectorRoute: if resp.status_code != 200: detail = resp.text - try: + with contextlib.suppress(Exception): detail = resp.json().get("error_description", detail) - except Exception: - pass raise HTTPException( status_code=400, detail=f"Token exchange failed: {detail}" ) @@ -430,7 +428,11 @@ class OAuthConnectorRoute: state_mgr = oauth._get_state_manager() extra: dict[str, Any] = {"connector_id": connector_id} - if return_url and return_url.startswith("/") and not return_url.startswith("//"): + if ( + return_url + and return_url.startswith("/") + and not return_url.startswith("//") + ): extra["return_url"] = return_url auth_params: dict[str, str] = { @@ -450,9 +452,7 @@ class OAuthConnectorRoute: auth_params.update(oauth.extra_auth_params) - state_encoded = state_mgr.generate_secure_state( - space_id, user.id, **extra - ) + state_encoded = state_mgr.generate_secure_state(space_id, user.id, **extra) auth_params["state"] = state_encoded auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}" @@ -489,9 +489,7 @@ class OAuthConnectorRoute: status_code=400, detail="Missing authorization code" ) if not state: - raise HTTPException( - status_code=400, detail="Missing state parameter" - ) + raise HTTPException(status_code=400, detail="Missing state parameter") state_mgr = oauth._get_state_manager() try: @@ -552,7 +550,11 @@ class OAuthConnectorRoute: db_connector.id, user_id, ) - if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"): + if ( + reauth_return_url + and reauth_return_url.startswith("/") + and not reauth_return_url.startswith("//") + ): return RedirectResponse( url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}" ) @@ -603,7 +605,8 @@ class OAuthConnectorRoute: except IntegrityError as e: await session.rollback() raise HTTPException( - status_code=409, detail="A connector for this service already exists." + status_code=409, + detail="A connector for this service already exists.", ) from e logger.info( diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index d42a7fa1a..9037d275a 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -3092,7 +3092,7 @@ async def trust_mcp_tool( select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, SearchSourceConnector.user_id == user.id, - cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601 + cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), ) ) connector = result.scalars().first() @@ -3147,7 +3147,7 @@ async def untrust_mcp_tool( select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, SearchSourceConnector.user_id == user.id, - cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601 + cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), ) ) connector = result.scalars().first() diff --git a/surfsense_backend/app/services/mcp_oauth/discovery.py b/surfsense_backend/app/services/mcp_oauth/discovery.py index b0f3fef2a..dc21443bc 100644 --- a/surfsense_backend/app/services/mcp_oauth/discovery.py +++ b/surfsense_backend/app/services/mcp_oauth/discovery.py @@ -55,7 +55,9 @@ async def register_client( async with httpx.AsyncClient(follow_redirects=True) as client: resp = await client.post( - registration_endpoint, json=payload, timeout=timeout, + registration_endpoint, + json=payload, + timeout=timeout, ) resp.raise_for_status() return resp.json() diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py index 49bc74d3d..835d70184 100644 --- a/surfsense_backend/app/services/mcp_oauth/registry.py +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -70,12 +70,14 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { "createJiraIssue", "editJiraIssue", ], - readonly_tools=frozenset({ - "getAccessibleAtlassianResources", - "searchJiraIssuesUsingJql", - "getVisibleJiraProjects", - "getJiraProjectIssueTypesMetadata", - }), + readonly_tools=frozenset( + { + "getAccessibleAtlassianResources", + "searchJiraIssuesUsingJql", + "getVisibleJiraProjects", + "getJiraProjectIssueTypesMetadata", + } + ), account_metadata_keys=["cloud_id", "site_name", "base_url"], ), "clickup": MCPServiceConfig( @@ -99,15 +101,23 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { auth_endpoint_override="https://slack.com/oauth/v2_user/authorize", token_endpoint_override="https://slack.com/api/oauth.v2.user.access", scopes=[ - "search:read.public", "search:read.private", "search:read.mpim", "search:read.im", - "channels:history", "groups:history", "mpim:history", "im:history", + "search:read.public", + "search:read.private", + "search:read.mpim", + "search:read.im", + "channels:history", + "groups:history", + "mpim:history", + "im:history", ], allowed_tools=[ "slack_search_channels", "slack_read_channel", "slack_read_thread", ], - readonly_tools=frozenset({"slack_search_channels", "slack_read_channel", "slack_read_thread"}), + readonly_tools=frozenset( + {"slack_search_channels", "slack_read_channel", "slack_read_thread"} + ), # TODO: oauth.v2.user.access only returns team.id, not team.name. # To populate team_name, either add "team:read" scope and call # GET /api/team.info during OAuth callback, or switch to oauth.v2.access. @@ -127,7 +137,9 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { "list_tables_for_base", "list_records_for_table", ], - readonly_tools=frozenset({"list_bases", "list_tables_for_base", "list_records_for_table"}), + readonly_tools=frozenset( + {"list_bases", "list_tables_for_base", "list_records_for_table"} + ), account_metadata_keys=["user_id", "user_email"], ), } @@ -136,20 +148,22 @@ _CONNECTOR_TYPE_TO_SERVICE: dict[str, MCPServiceConfig] = { svc.connector_type: svc for svc in MCP_SERVICES.values() } -LIVE_CONNECTOR_TYPES: frozenset[SearchSourceConnectorType] = frozenset({ - SearchSourceConnectorType.SLACK_CONNECTOR, - SearchSourceConnectorType.TEAMS_CONNECTOR, - SearchSourceConnectorType.LINEAR_CONNECTOR, - SearchSourceConnectorType.JIRA_CONNECTOR, - SearchSourceConnectorType.CLICKUP_CONNECTOR, - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.AIRTABLE_CONNECTOR, - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - SearchSourceConnectorType.DISCORD_CONNECTOR, - SearchSourceConnectorType.LUMA_CONNECTOR, -}) +LIVE_CONNECTOR_TYPES: frozenset[SearchSourceConnectorType] = frozenset( + { + SearchSourceConnectorType.SLACK_CONNECTOR, + SearchSourceConnectorType.TEAMS_CONNECTOR, + SearchSourceConnectorType.LINEAR_CONNECTOR, + SearchSourceConnectorType.JIRA_CONNECTOR, + SearchSourceConnectorType.CLICKUP_CONNECTOR, + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.AIRTABLE_CONNECTOR, + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, + SearchSourceConnectorType.DISCORD_CONNECTOR, + SearchSourceConnectorType.LUMA_CONNECTOR, + } +) def get_service(key: str) -> MCPServiceConfig | None: diff --git a/surfsense_backend/app/services/obsidian_plugin_indexer.py b/surfsense_backend/app/services/obsidian_plugin_indexer.py index 8fbdad269..0fc4f30f4 100644 --- a/surfsense_backend/app/services/obsidian_plugin_indexer.py +++ b/surfsense_backend/app/services/obsidian_plugin_indexer.py @@ -156,7 +156,9 @@ async def _extract_binary_attachment_markdown( try: raw_bytes = base64.b64decode(payload.binary_base64, validate=True) except Exception: - logger.warning("obsidian attachment payload had invalid base64: %s", payload.path) + logger.warning( + "obsidian attachment payload had invalid base64: %s", payload.path + ) return "", {"attachment_extraction_status": "invalid_binary_payload"} suffix = f".{payload.extension.lstrip('.')}" @@ -180,7 +182,10 @@ async def _extract_binary_attachment_markdown( return result.markdown_content, metadata except Exception as exc: logger.warning( - "obsidian attachment ETL failed for %s: %s", payload.path, exc, exc_info=True + "obsidian attachment ETL failed for %s: %s", + payload.path, + exc, + exc_info=True, ) return "", { "attachment_extraction_status": "etl_failed", diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 5a6117808..7239c57a5 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -31,7 +31,6 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer from app.agents.new_chat.filesystem_selection import FilesystemSelection -from app.config import config from app.agents.new_chat.llm_config import ( AgentConfig, create_chat_litellm_from_agent_config, @@ -182,9 +181,9 @@ def _tool_output_has_error(tool_output: Any) -> bool: if tool_output.get("error"): return True result = tool_output.get("result") - if isinstance(result, str) and result.strip().lower().startswith("error:"): - return True - return False + return bool( + isinstance(result, str) and result.strip().lower().startswith("error:") + ) if isinstance(tool_output, str): return tool_output.strip().lower().startswith("error:") return False @@ -230,7 +229,9 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None: "stage": stage, "request_id": result.request_id or "unknown", "turn_id": result.turn_id or "unknown", - "chat_id": result.turn_id.split(":", 1)[0] if ":" in result.turn_id else "unknown", + "chat_id": result.turn_id.split(":", 1)[0] + if ":" in result.turn_id + else "unknown", "filesystem_mode": result.filesystem_mode, "client_platform": result.client_platform, "intent_detected": result.intent_detected, @@ -242,7 +243,9 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None: "commit_gate_reason": result.commit_gate_reason or None, } payload.update(extra) - _perf_log.info("[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False)) + _perf_log.info( + "[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False) + ) async def _stream_agent_events( @@ -1289,7 +1292,8 @@ async def _stream_agent_events( result.intent_detected = intent_value if ( isinstance(intent_value, str) - and intent_value in ( + and intent_value + in ( "chat_only", "file_write", "file_read", @@ -1308,18 +1312,17 @@ async def _stream_agent_events( result.commit_gate_passed, result.commit_gate_reason = ( _evaluate_file_contract_outcome(result) ) - if not result.commit_gate_passed: - if _contract_enforcement_active(result): - gate_notice = ( - "I could not complete the requested file write because no successful " - "write_file/edit_file operation was confirmed." - ) - gate_text_id = streaming_service.generate_text_id() - yield streaming_service.format_text_start(gate_text_id) - yield streaming_service.format_text_delta(gate_text_id, gate_notice) - yield streaming_service.format_text_end(gate_text_id) - yield streaming_service.format_terminal_info(gate_notice, "error") - accumulated_text = gate_notice + if not result.commit_gate_passed and _contract_enforcement_active(result): + gate_notice = ( + "I could not complete the requested file write because no successful " + "write_file/edit_file operation was confirmed." + ) + gate_text_id = streaming_service.generate_text_id() + yield streaming_service.format_text_start(gate_text_id) + yield streaming_service.format_text_delta(gate_text_id, gate_notice) + yield streaming_service.format_text_end(gate_text_id) + yield streaming_service.format_terminal_info(gate_notice, "error") + accumulated_text = gate_notice else: result.commit_gate_passed = True result.commit_gate_reason = "" diff --git a/surfsense_backend/app/utils/async_retry.py b/surfsense_backend/app/utils/async_retry.py index c3bdd5386..a56f6550a 100644 --- a/surfsense_backend/app/utils/async_retry.py +++ b/surfsense_backend/app/utils/async_retry.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib import logging from collections.abc import Callable from typing import TypeVar @@ -32,9 +33,7 @@ F = TypeVar("F", bound=Callable) def _is_retryable(exc: BaseException) -> bool: if isinstance(exc, ConnectorError): return exc.retryable - if isinstance(exc, (httpx.TimeoutException, httpx.ConnectError)): - return True - return False + return bool(isinstance(exc, (httpx.TimeoutException, httpx.ConnectError))) def build_retry( @@ -86,10 +85,8 @@ def raise_for_status( retry_after_raw = response.headers.get("Retry-After") retry_after: float | None = None if retry_after_raw: - try: + with contextlib.suppress(ValueError, TypeError): retry_after = float(retry_after_raw) - except (ValueError, TypeError): - pass raise ConnectorRateLimitError( f"{service} rate limited (429)", service=service, diff --git a/surfsense_backend/app/utils/connector_naming.py b/surfsense_backend/app/utils/connector_naming.py index 889bf1464..99c8243a5 100644 --- a/surfsense_backend/app/utils/connector_naming.py +++ b/surfsense_backend/app/utils/connector_naming.py @@ -233,7 +233,10 @@ async def generate_unique_connector_name( if identifier: name = f"{base} - {identifier}" return await ensure_unique_connector_name( - session, name, search_space_id, user_id, + session, + name, + search_space_id, + user_id, ) count = await count_connectors_of_type( diff --git a/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py b/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py index 41779a570..22f6c6de5 100644 --- a/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py +++ b/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py @@ -499,7 +499,9 @@ class TestWireContractSmoke: "app.routes.obsidian_plugin_routes.upsert_note", new=AsyncMock(return_value=fake_doc), ) as upsert_mock, - patch("app.routes.obsidian_plugin_routes._queue_obsidian_attachment") as queue_mock, + patch( + "app.routes.obsidian_plugin_routes._queue_obsidian_attachment" + ) as queue_mock, ): sync_resp = await obsidian_sync( SyncBatchRequest( @@ -548,7 +550,9 @@ class TestWireContractSmoke: "app.routes.obsidian_plugin_routes.upsert_note", new=AsyncMock(return_value=fake_doc), ), - patch("app.routes.obsidian_plugin_routes._queue_obsidian_attachment") as queue_mock, + patch( + "app.routes.obsidian_plugin_routes._queue_obsidian_attachment" + ) as queue_mock, ): sync_resp = await obsidian_sync( SyncBatchRequest( @@ -600,7 +604,9 @@ class TestWireContractSmoke: "app.routes.obsidian_plugin_routes.upsert_note", new=AsyncMock(return_value=fake_doc), ), - patch("app.routes.obsidian_plugin_routes._queue_obsidian_attachment") as queue_mock, + patch( + "app.routes.obsidian_plugin_routes._queue_obsidian_attachment" + ) as queue_mock, ): sync_resp = await obsidian_sync( SyncBatchRequest( @@ -619,7 +625,5 @@ class TestWireContractSmoke: items_by_path = {it.path: it for it in sync_resp.items} assert items_by_path["ok.md"].status == "ok" assert items_by_path["image.png"].status == "error" - assert "does not match extension" in ( - items_by_path["image.png"].error or "" - ) + assert "does not match extension" in (items_by_path["image.png"].error or "") queue_mock.assert_not_called() diff --git a/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py b/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py index 673331b0a..7fd3fe4a7 100644 --- a/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py +++ b/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py @@ -45,9 +45,7 @@ async def test_file_write_intent_injects_contract_message(): @pytest.mark.asyncio async def test_non_write_intent_does_not_inject_contract_message(): - llm = _FakeLLM( - '{"intent":"file_read","confidence":0.88,"suggested_filename":null}' - ) + llm = _FakeLLM('{"intent":"file_read","confidence":0.88,"suggested_filename":null}') middleware = FileIntentMiddleware(llm=llm) original_messages = [HumanMessage(content="Read /notes.md")] state = {"messages": original_messages, "turn_id": "abc:def"} @@ -55,7 +53,10 @@ async def test_non_write_intent_does_not_inject_contract_message(): result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] assert result is not None - assert result["file_operation_contract"]["intent"] == FileOperationIntent.FILE_READ.value + assert ( + result["file_operation_contract"]["intent"] + == FileOperationIntent.FILE_READ.value + ) assert "messages" not in result @@ -211,4 +212,3 @@ def test_fallback_path_keeps_posix_style_absolute_path_for_linux_and_macos() -> ) assert resolved == "/var/log/surfsense/notes.md" - diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py b/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py index d00365032..cca15e789 100644 --- a/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py @@ -2,11 +2,11 @@ from pathlib import Path import pytest +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( MultiRootLocalFolderBackend, ) -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware pytestmark = pytest.mark.unit diff --git a/surfsense_backend/tests/unit/test_obsidian_plugin_indexer.py b/surfsense_backend/tests/unit/test_obsidian_plugin_indexer.py index 7ab3c52e0..20795c739 100644 --- a/surfsense_backend/tests/unit/test_obsidian_plugin_indexer.py +++ b/surfsense_backend/tests/unit/test_obsidian_plugin_indexer.py @@ -15,7 +15,6 @@ from app.services.obsidian_plugin_indexer import ( _require_extracted_attachment_content, ) - _FAKE_PNG_B64 = base64.b64encode(b"\x89PNG\r\n\x1a\n").decode("ascii") @@ -102,9 +101,7 @@ async def test_extract_binary_attachment_markdown_uses_etl(monkeypatch) -> None: mime_type="application/pdf", ) - async def _fake_run_etl_extract( # noqa: ANN001 - *, file_path, filename, vision_llm - ): + async def _fake_run_etl_extract(*, file_path, filename, vision_llm): assert filename == "spec.pdf" assert file_path assert vision_llm is None @@ -216,7 +213,7 @@ def test_note_payload_rejects_markdown_with_binary_fields() -> None: def test_require_extracted_attachment_content_rejects_empty_content() -> None: with pytest.raises( - RuntimeError, match="Attachment extraction failed for assets/img.png" + RuntimeError, match=r"Attachment extraction failed for assets/img\.png" ): _require_extracted_attachment_content( content=" ", diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index f4adc3d73..034aa484c 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -45,4 +45,3 @@ def test_contract_enforcement_local_only(): result.filesystem_mode = "cloud" assert not _contract_enforcement_active(result) - diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 06f3bf79f..9f569398e 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -45,8 +45,8 @@ import { } from "@/components/assistant-ui/token-usage-context"; import { useChatSessionStateSync } from "@/hooks/use-chat-session-state"; import { useMessagesSync } from "@/hooks/use-messages-sync"; -import { documentsApiService } from "@/lib/apis/documents-api.service"; import { getAgentFilesystemSelection } from "@/lib/agent-filesystem"; +import { documentsApiService } from "@/lib/apis/documents-api.service"; import { getBearerToken } from "@/lib/auth-utils"; import { convertToThreadMessage } from "@/lib/chat/message-utils"; import { @@ -661,8 +661,7 @@ export default function NewChatPage() { const selection = await getAgentFilesystemSelection(searchSpaceId); if ( selection.filesystem_mode === "desktop_local_folder" && - (!selection.local_filesystem_mounts || - selection.local_filesystem_mounts.length === 0) + (!selection.local_filesystem_mounts || selection.local_filesystem_mounts.length === 0) ) { toast.error("Select a local folder before using Local Folder mode."); return; @@ -842,14 +841,7 @@ export default function NewChatPage() { }); } else { const tcId = `interrupt-${action.name}`; - addToolCall( - contentPartsState, - toolsWithUI, - tcId, - action.name, - action.args, - true - ); + addToolCall(contentPartsState, toolsWithUI, tcId, action.name, action.args, true); updateToolCall(contentPartsState, tcId, { result: { __interrupt__: true, ...interruptData }, }); @@ -1189,14 +1181,7 @@ export default function NewChatPage() { }); } else { const tcId = `interrupt-${action.name}`; - addToolCall( - contentPartsState, - toolsWithUI, - tcId, - action.name, - action.args, - true - ); + addToolCall(contentPartsState, toolsWithUI, tcId, action.name, action.args, true); updateToolCall(contentPartsState, tcId, { result: { __interrupt__: true, diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx index 6207457c4..12a7d00f0 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent.tsx @@ -111,9 +111,7 @@ function HotkeyRow({ } > {recording ? ( - - Press hotkeys... - + Press hotkeys... ) : ( )} @@ -155,7 +153,9 @@ export function DesktopShortcutsContent() { if (!api) { return (
-

Hotkeys are only available in the SurfSense desktop app.

+

+ Hotkeys are only available in the SurfSense desktop app. +

); } @@ -178,28 +178,26 @@ export function DesktopShortcutsContent() { updateShortcut(key, DEFAULT_SHORTCUTS[key]); }; - return ( - shortcutsLoaded ? ( -
-
- {HOTKEY_ROWS.map((row) => ( - updateShortcut(row.key, accel)} - onReset={() => resetShortcut(row.key)} - /> - ))} -
+ return shortcutsLoaded ? ( +
+
+ {HOTKEY_ROWS.map((row) => ( + updateShortcut(row.key, accel)} + onReset={() => resetShortcut(row.key)} + /> + ))}
- ) : ( -
- -
- ) +
+ ) : ( +
+ +
); } diff --git a/surfsense_web/app/desktop/login/page.tsx b/surfsense_web/app/desktop/login/page.tsx index 451143949..c64eb65f8 100644 --- a/surfsense_web/app/desktop/login/page.tsx +++ b/surfsense_web/app/desktop/login/page.tsx @@ -24,7 +24,12 @@ const isGoogleAuth = AUTH_TYPE === "GOOGLE"; type ShortcutKey = "generalAssist" | "quickAsk" | "autocomplete"; type ShortcutMap = typeof DEFAULT_SHORTCUTS; -const HOTKEY_ROWS: Array<{ key: ShortcutKey; label: string; description: string; icon: React.ElementType }> = [ +const HOTKEY_ROWS: Array<{ + key: ShortcutKey; + label: string; + description: string; + icon: React.ElementType; +}> = [ { key: "generalAssist", label: "General Assist", @@ -369,7 +374,9 @@ export default function DesktopLoginPage() { )} diff --git a/surfsense_web/components/assistant-ui/connector-popup.tsx b/surfsense_web/components/assistant-ui/connector-popup.tsx index 66333a9ef..32943142a 100644 --- a/surfsense_web/components/assistant-ui/connector-popup.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup.tsx @@ -123,9 +123,9 @@ export const ConnectorIndicator = forwardRef ) : viewingMCPList ? ( - handleDisconnectFromList(connector, () => refreshConnectors())} - onAddAccount={handleAddNewMCPFromList} - addButtonText="Add New MCP Server" - /> + + handleDisconnectFromList(connector, () => refreshConnectors()) + } + onAddAccount={handleAddNewMCPFromList} + addButtonText="Add New MCP Server" + /> ) : viewingAccountsType ? ( - handleDisconnectFromList(connector, () => refreshConnectors())} - onAddAccount={() => { + + handleDisconnectFromList(connector, () => refreshConnectors()) + } + onAddAccount={() => { // Check both OAUTH_CONNECTORS and COMPOSIO_CONNECTORS const oauthConnector = OAUTH_CONNECTORS.find( diff --git a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx index fc9812240..d9a740af2 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx @@ -213,13 +213,13 @@ export const MCPConnectForm: FC = ({ onSubmit, isSubmitting }) className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80" > {isTesting ? ( - <> - - Testing Connection... - - ) : ( - "Test Connection" - )} + <> + + Testing Connection... + + ) : ( + "Test Connection" + )}
diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx index d6f60e824..97b5de675 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx @@ -218,13 +218,13 @@ export const MCPConfig: FC = ({ connector, onConfigChange, onNam className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80" > {isTesting ? ( - <> - - Testing Connection... - - ) : ( - "Test Connection" - )} + <> + + Testing Connection... + + ) : ( + "Test Connection" + )}
diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx index e96ddfd29..06ce21dae 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx @@ -18,9 +18,9 @@ export const TeamsConfig: FC = () => {

Microsoft Teams Access

- Your agent can search and read messages from Teams channels you have access to, - and send messages on your behalf. Make sure you're a member of the teams - you want to interact with. + Your agent can search and read messages from Teams channels you have access to, and send + messages on your behalf. Make sure you're a member of the teams you want to interact + with.

diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx index b2b40dfd6..c104f140a 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx @@ -16,7 +16,7 @@ import { DateRangeSelector } from "../../components/date-range-selector"; import { PeriodicSyncConfig } from "../../components/periodic-sync-config"; import { SummaryConfig } from "../../components/summary-config"; import { VisionLLMConfig } from "../../components/vision-llm-config"; -import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../../constants/connector-constants"; +import { getReauthEndpoint, LIVE_CONNECTOR_TYPES } from "../../constants/connector-constants"; import { getConnectorDisplayName } from "../../tabs/all-connectors-tab"; import { MCPServiceConfig } from "../components/mcp-service-config"; import { getConnectorConfigComponent } from "../index"; @@ -314,8 +314,7 @@ export const ConnectorEditView: FC = ({ {connector.is_indexable && (() => { - const isGoogleDrive = - connector.connector_type === "GOOGLE_DRIVE_CONNECTOR"; + const isGoogleDrive = connector.connector_type === "GOOGLE_DRIVE_CONNECTOR"; const isComposioGoogleDrive = connector.connector_type === "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"; const requiresFolderSelection = isGoogleDrive || isComposioGoogleDrive; @@ -327,8 +326,7 @@ export const ConnectorEditView: FC = ({ (connector.config?.selected_files as | Array<{ id: string; name: string }> | undefined) || []; - const hasItemsSelected = - selectedFolders.length > 0 || selectedFiles.length > 0; + const hasItemsSelected = selectedFolders.length > 0 || selectedFiles.length > 0; const isDisabled = requiresFolderSelection && !hasItemsSelected; return ( @@ -380,8 +378,8 @@ export const ConnectorEditView: FC = ({ {/* Fixed Footer - Action buttons */}
- {showDisconnectConfirm ? ( -
+ {showDisconnectConfirm ? ( +
{isLive ? "Your agent will lose access to this service." diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/indexing-configuration-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/indexing-configuration-view.tsx index 690333523..982b0be11 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/indexing-configuration-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/indexing-configuration-view.tsx @@ -12,7 +12,10 @@ import { DateRangeSelector } from "../../components/date-range-selector"; import { PeriodicSyncConfig } from "../../components/periodic-sync-config"; import { SummaryConfig } from "../../components/summary-config"; import { VisionLLMConfig } from "../../components/vision-llm-config"; -import { LIVE_CONNECTOR_TYPES, type IndexingConfigState } from "../../constants/connector-constants"; +import { + type IndexingConfigState, + LIVE_CONNECTOR_TYPES, +} from "../../constants/connector-constants"; import { getConnectorDisplayName } from "../../tabs/all-connectors-tab"; import { getConnectorConfigComponent } from "../index"; diff --git a/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx b/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx index fe9aab14f..755086ba5 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx @@ -9,7 +9,11 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { getDocumentTypeLabel } from "@/lib/documents/document-type-labels"; import { cn } from "@/lib/utils"; -import { COMPOSIO_CONNECTORS, LIVE_CONNECTOR_TYPES, OAUTH_CONNECTORS } from "../constants/connector-constants"; +import { + COMPOSIO_CONNECTORS, + LIVE_CONNECTOR_TYPES, + OAUTH_CONNECTORS, +} from "../constants/connector-constants"; import { getDocumentCountForConnector } from "../utils/connector-document-mapping"; import { getConnectorDisplayName } from "./all-connectors-tab"; diff --git a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx index b3c087599..8aee7e005 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx @@ -13,7 +13,7 @@ import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { authenticatedFetch } from "@/lib/auth-utils"; import { formatRelativeDate } from "@/lib/format-date"; import { cn } from "@/lib/utils"; -import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../constants/connector-constants"; +import { getReauthEndpoint, LIVE_CONNECTOR_TYPES } from "../constants/connector-constants"; import { useConnectorStatus } from "../hooks/use-connector-status"; import { getConnectorDisplayName } from "../tabs/all-connectors-tab"; @@ -182,11 +182,14 @@ export const ConnectorAccountsListView: FC = ({
) : (
- {typeConnectors.map((connector) => { - const isIndexing = indexingConnectorIds.has(connector.id); - const connectorReauthEndpoint = getReauthEndpoint(connector); - const isAuthExpired = !!connectorReauthEndpoint && connector.config?.auth_expired === true; - const isLive = LIVE_CONNECTOR_TYPES.has(connector.connector_type) || Boolean(connector.config?.server_config); + {typeConnectors.map((connector) => { + const isIndexing = indexingConnectorIds.has(connector.id); + const connectorReauthEndpoint = getReauthEndpoint(connector); + const isAuthExpired = + !!connectorReauthEndpoint && connector.config?.auth_expired === true; + const isLive = + LIVE_CONNECTOR_TYPES.has(connector.connector_type) || + Boolean(connector.config?.server_config); return (
= ({

) : null}
- {isAuthExpired ? ( - - ) : isLive && onDisconnect ? ( - confirmDisconnectId === connector.id ? ( -
+ {isAuthExpired ? ( + + ) : isLive && onDisconnect ? ( + confirmDisconnectId === connector.id ? ( +
+ + +
+ ) : ( - -
+ ) ) : ( - ) - ) : ( - - )} + )}
); })} diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index 2707e8956..8bb228580 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -20,7 +20,6 @@ import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/image"; import "katex/dist/katex.min.css"; import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation"; -import { useElectronAPI } from "@/hooks/use-platform"; import { Skeleton } from "@/components/ui/skeleton"; import { Table, @@ -30,6 +29,7 @@ import { TableHeader, TableRow, } from "@/components/ui/table"; +import { useElectronAPI } from "@/hooks/use-platform"; import { cn } from "@/lib/utils"; function MarkdownCodeBlockSkeleton() { @@ -493,10 +493,7 @@ const defaultComponents = memoizeMarkdownComponents({ const mounts = (await electronAPI.getAgentFilesystemMounts( resolvedSearchSpaceId )) as AgentFilesystemMount[]; - resolvedLocalPath = normalizeLocalVirtualPathForEditor( - inlineValue, - mounts - ); + resolvedLocalPath = normalizeLocalVirtualPathForEditor(inlineValue, mounts); } catch { // Fall back to the raw inline path if mount lookup fails. } diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 2fa980d27..3b69ae6e0 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -248,7 +248,15 @@ export function EditorPanelContent({ doFetch().catch(() => {}); return () => controller.abort(); - }, [documentId, electronAPI, isLocalFileMode, localFilePath, resolveLocalVirtualPath, searchSpaceId, title]); + }, [ + documentId, + electronAPI, + isLocalFileMode, + localFilePath, + resolveLocalVirtualPath, + searchSpaceId, + title, + ]); useEffect(() => { return () => { @@ -282,69 +290,77 @@ export function EditorPanelContent({ } }, [editorDoc?.source_markdown]); - const handleSave = useCallback(async (_options?: { silent?: boolean }) => { - setSaving(true); - try { - if (isLocalFileMode) { - if (!localFilePath) { - throw new Error("Missing local file path"); + const handleSave = useCallback( + async (_options?: { silent?: boolean }) => { + setSaving(true); + try { + if (isLocalFileMode) { + if (!localFilePath) { + throw new Error("Missing local file path"); + } + if (!electronAPI?.writeAgentLocalFileText) { + throw new Error("Local file editor is available only in desktop mode."); + } + const resolvedLocalPath = await resolveLocalVirtualPath(localFilePath); + const contentToSave = markdownRef.current; + const writeResult = await electronAPI.writeAgentLocalFileText( + resolvedLocalPath, + contentToSave, + searchSpaceId + ); + if (!writeResult.ok) { + throw new Error(writeResult.error || "Failed to save local file"); + } + setEditorDoc((prev) => (prev ? { ...prev, source_markdown: contentToSave } : prev)); + setEditedMarkdown(markdownRef.current === contentToSave ? null : markdownRef.current); + return true; } - if (!electronAPI?.writeAgentLocalFileText) { - throw new Error("Local file editor is available only in desktop mode."); + if (!searchSpaceId || !documentId) { + throw new Error("Missing document context"); } - const resolvedLocalPath = await resolveLocalVirtualPath(localFilePath); - const contentToSave = markdownRef.current; - const writeResult = await electronAPI.writeAgentLocalFileText( - resolvedLocalPath, - contentToSave, - searchSpaceId + const token = getBearerToken(); + if (!token) { + toast.error("Please login to save"); + redirectToLogin(); + return; + } + const response = await authenticatedFetch( + `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/save`, + { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ source_markdown: markdownRef.current }), + } ); - if (!writeResult.ok) { - throw new Error(writeResult.error || "Failed to save local file"); + + if (!response.ok) { + const errorData = await response + .json() + .catch(() => ({ detail: "Failed to save document" })); + throw new Error(errorData.detail || "Failed to save document"); } - setEditorDoc((prev) => - prev ? { ...prev, source_markdown: contentToSave } : prev - ); - setEditedMarkdown(markdownRef.current === contentToSave ? null : markdownRef.current); + + setEditorDoc((prev) => (prev ? { ...prev, source_markdown: markdownRef.current } : prev)); + setEditedMarkdown(null); + toast.success("Document saved! Reindexing in background..."); return true; + } catch (err) { + console.error("Error saving document:", err); + toast.error(err instanceof Error ? err.message : "Failed to save document"); + return false; + } finally { + setSaving(false); } - if (!searchSpaceId || !documentId) { - throw new Error("Missing document context"); - } - const token = getBearerToken(); - if (!token) { - toast.error("Please login to save"); - redirectToLogin(); - return; - } - const response = await authenticatedFetch( - `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/save`, - { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ source_markdown: markdownRef.current }), - } - ); - - if (!response.ok) { - const errorData = await response - .json() - .catch(() => ({ detail: "Failed to save document" })); - throw new Error(errorData.detail || "Failed to save document"); - } - - setEditorDoc((prev) => (prev ? { ...prev, source_markdown: markdownRef.current } : prev)); - setEditedMarkdown(null); - toast.success("Document saved! Reindexing in background..."); - return true; - } catch (err) { - console.error("Error saving document:", err); - toast.error(err instanceof Error ? err.message : "Failed to save document"); - return false; - } finally { - setSaving(false); - } - }, [documentId, electronAPI, isLocalFileMode, localFilePath, resolveLocalVirtualPath, searchSpaceId]); + }, + [ + documentId, + electronAPI, + isLocalFileMode, + localFilePath, + resolveLocalVirtualPath, + searchSpaceId, + ] + ); const isEditableType = editorDoc ? (editorRenderMode === "source_code" || @@ -594,9 +610,7 @@ export function EditorPanelContent({ } }} > - + Download .md @@ -626,7 +640,7 @@ export function EditorPanelContent({
) : isEditableType ? ( ; } diff --git a/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx b/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx index bdda0263d..346fe0378 100644 --- a/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx +++ b/surfsense_web/components/editor/plugins/fixed-toolbar-kit.tsx @@ -1,7 +1,6 @@ "use client"; -import { createPlatePlugin } from "platejs/react"; -import { useEditorReadOnly } from "platejs/react"; +import { createPlatePlugin, useEditorReadOnly } from "platejs/react"; import { useEditorSave } from "@/components/editor/editor-save-context"; import { FixedToolbar } from "@/components/ui/fixed-toolbar"; diff --git a/surfsense_web/components/editor/source-code-editor.tsx b/surfsense_web/components/editor/source-code-editor.tsx index dd4b3bd8e..9102dffe9 100644 --- a/surfsense_web/components/editor/source-code-editor.tsx +++ b/surfsense_web/components/editor/source-code-editor.tsx @@ -1,8 +1,8 @@ "use client"; import dynamic from "next/dynamic"; -import { useEffect, useRef } from "react"; import { useTheme } from "next-themes"; +import { useEffect, useRef } from "react"; import { Spinner } from "@/components/ui/spinner"; const MonacoEditor = dynamic(() => import("@monaco-editor/react"), { diff --git a/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx b/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx index c26cc9b23..04bae010c 100644 --- a/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx +++ b/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx @@ -72,9 +72,7 @@ export function RightPanelExpandButton() { const reportOpen = reportState.isOpen && !!reportState.reportId; const editorOpen = editorState.isOpen && - (editorState.kind === "document" - ? !!editorState.documentId - : !!editorState.localFilePath); + (editorState.kind === "document" ? !!editorState.documentId : !!editorState.localFilePath); const hitlEditOpen = hitlEditState.isOpen && !!hitlEditState.onSave; const hasContent = documentsOpen || reportOpen || editorOpen || hitlEditOpen; @@ -116,9 +114,7 @@ export function RightPanel({ documentsPanel }: RightPanelProps) { const reportOpen = reportState.isOpen && !!reportState.reportId; const editorOpen = editorState.isOpen && - (editorState.kind === "document" - ? !!editorState.documentId - : !!editorState.localFilePath); + (editorState.kind === "document" ? !!editorState.documentId : !!editorState.localFilePath); const hitlEditOpen = hitlEditState.isOpen && !!hitlEditState.onSave; useEffect(() => { diff --git a/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx b/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx index dd7520d24..cd8fca331 100644 --- a/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx @@ -1,11 +1,9 @@ "use client"; -import { Folder, FolderPlus, Search, X } from "lucide-react"; import { useAtom } from "jotai"; +import { Folder, FolderPlus, Search, X } from "lucide-react"; import { useCallback, useMemo, useRef, useState } from "react"; import { localExpandedFolderKeysAtom } from "@/atoms/documents/folder.atoms"; -import { Input } from "@/components/ui/input"; -import { Separator } from "@/components/ui/separator"; import { DropdownMenu, DropdownMenuContent, @@ -14,6 +12,8 @@ import { DropdownMenuSeparator, DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; +import { Input } from "@/components/ui/input"; +import { Separator } from "@/components/ui/separator"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { useDebouncedValue } from "@/hooks/use-debounced-value"; import { LocalFilesystemBrowser } from "./LocalFilesystemBrowser"; diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index b9c174d71..0a147f7b7 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -71,7 +71,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { DocumentTypeEnum } from "@/contracts/types/document.types"; import { useDebouncedValue } from "@/hooks/use-debounced-value"; import { useMediaQuery } from "@/hooks/use-media-query"; -import { usePlatform, useElectronAPI } from "@/hooks/use-platform"; +import { useElectronAPI, usePlatform } from "@/hooks/use-platform"; import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { foldersApiService } from "@/lib/apis/folders-api.service"; @@ -208,7 +208,8 @@ function AuthenticatedDocumentsSidebarBase({ const [watchedFolderIds, setWatchedFolderIds] = useState>(new Set()); const [folderWatchOpen, setFolderWatchOpen] = useAtom(folderWatchDialogOpenAtom); const [watchInitialFolder, setWatchInitialFolder] = useAtom(folderWatchInitialFolderAtom); - const isElectron = desktopFeaturesEnabled && typeof window !== "undefined" && !!window.electronAPI; + const isElectron = + desktopFeaturesEnabled && typeof window !== "undefined" && !!window.electronAPI; useEffect(() => { if (!electronAPI?.getAgentFilesystemSettings) return; @@ -250,10 +251,13 @@ function AuthenticatedDocumentsSidebarBase({ .filter((rootPath, index, allPaths) => allPaths.indexOf(rootPath) === index) .slice(0, MAX_LOCAL_FILESYSTEM_ROOTS); if (nextLocalRootPaths.length === localRootPaths.length) return; - const updated = await electronAPI.setAgentFilesystemSettings({ - mode: "desktop_local_folder", - localRootPaths: nextLocalRootPaths, - }, searchSpaceId); + const updated = await electronAPI.setAgentFilesystemSettings( + { + mode: "desktop_local_folder", + localRootPaths: nextLocalRootPaths, + }, + searchSpaceId + ); setFilesystemSettings(updated); }, [electronAPI, localRootPaths, searchSpaceId] @@ -282,10 +286,13 @@ function AuthenticatedDocumentsSidebarBase({ const handleRemoveFilesystemRoot = useCallback( async (rootPathToRemove: string) => { if (!electronAPI?.setAgentFilesystemSettings) return; - const updated = await electronAPI.setAgentFilesystemSettings({ - mode: "desktop_local_folder", - localRootPaths: localRootPaths.filter((rootPath) => rootPath !== rootPathToRemove), - }, searchSpaceId); + const updated = await electronAPI.setAgentFilesystemSettings( + { + mode: "desktop_local_folder", + localRootPaths: localRootPaths.filter((rootPath) => rootPath !== rootPathToRemove), + }, + searchSpaceId + ); setFilesystemSettings(updated); }, [electronAPI, localRootPaths, searchSpaceId] @@ -293,19 +300,25 @@ function AuthenticatedDocumentsSidebarBase({ const handleClearFilesystemRoots = useCallback(async () => { if (!electronAPI?.setAgentFilesystemSettings) return; - const updated = await electronAPI.setAgentFilesystemSettings({ - mode: "desktop_local_folder", - localRootPaths: [], - }, searchSpaceId); + const updated = await electronAPI.setAgentFilesystemSettings( + { + mode: "desktop_local_folder", + localRootPaths: [], + }, + searchSpaceId + ); setFilesystemSettings(updated); }, [electronAPI, searchSpaceId]); const handleFilesystemTabChange = useCallback( async (tab: "cloud" | "local") => { if (!electronAPI?.setAgentFilesystemSettings) return; - const updated = await electronAPI.setAgentFilesystemSettings({ - mode: tab === "cloud" ? "cloud" : "desktop_local_folder", - }, searchSpaceId); + const updated = await electronAPI.setAgentFilesystemSettings( + { + mode: tab === "cloud" ? "cloud" : "desktop_local_folder", + }, + searchSpaceId + ); setFilesystemSettings(updated); }, [electronAPI, searchSpaceId] @@ -552,7 +565,9 @@ function AuthenticatedDocumentsSidebarBase({ if (!electronAPI) return; const watchedFolders = (await electronAPI.getWatchedFolders()) as WatchedFolderEntry[]; - const matched = watchedFolders.find((wf: WatchedFolderEntry) => wf.rootFolderId === folder.id); + const matched = watchedFolders.find( + (wf: WatchedFolderEntry) => wf.rootFolderId === folder.id + ); if (!matched) { toast.error("This folder is not being watched"); return; @@ -582,7 +597,9 @@ function AuthenticatedDocumentsSidebarBase({ if (!electronAPI) return; const watchedFolders = (await electronAPI.getWatchedFolders()) as WatchedFolderEntry[]; - const matched = watchedFolders.find((wf: WatchedFolderEntry) => wf.rootFolderId === folder.id); + const matched = watchedFolders.find( + (wf: WatchedFolderEntry) => wf.rootFolderId === folder.id + ); if (!matched) { toast.error("This folder is not being watched"); return; @@ -1015,7 +1032,8 @@ function AuthenticatedDocumentsSidebarBase({ }, [open, onOpenChange, isMobile, setRightPanelCollapsed]); const showFilesystemTabs = !isMobile && !!electronAPI && !!filesystemSettings; - const currentFilesystemTab = filesystemSettings?.mode === "desktop_local_folder" ? "local" : "cloud"; + const currentFilesystemTab = + filesystemSettings?.mode === "desktop_local_folder" ? "local" : "cloud"; const showCloudSkeleton = currentFilesystemTab === "cloud" && (zeroFoldersResult.type !== "complete" || zeroAllDocsResult.type !== "complete"); @@ -1331,8 +1349,8 @@ function AuthenticatedDocumentsSidebarBase({ Trust this workspace? - Local mode can read and edit files inside the folders you select. Continue only if - you trust this workspace and its contents. + Local mode can read and edit files inside the folders you select. Continue only if you + trust this workspace and its contents. {pendingLocalPath && ( diff --git a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx index 6bfb1d3f1..19c47d605 100644 --- a/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx +++ b/surfsense_web/components/layout/ui/sidebar/LocalFilesystemBrowser.tsx @@ -141,7 +141,9 @@ export function LocalFilesystemBrowser({ }: LocalFilesystemBrowserProps) { const electronAPI = useElectronAPI(); const [rootStateMap, setRootStateMap] = useState>({}); - const [internalExpandedFolderKeys, setInternalExpandedFolderKeys] = useState>(new Set()); + const [internalExpandedFolderKeys, setInternalExpandedFolderKeys] = useState>( + new Set() + ); const [mountByRootKey, setMountByRootKey] = useState>(new Map()); const [mountStatus, setMountStatus] = useState("idle"); const [mountRefreshInFlight, setMountRefreshInFlight] = useState(false); @@ -188,10 +190,7 @@ export function LocalFilesystemBrowser({ } for (const { rootKey } of rootsToReload) { const nonce = reloadNonceByRoot[rootKey] ?? 0; - lastLoadedSignatureByRootRef.current.set( - rootKey, - `${searchSpaceId}:${rootKey}:${nonce}` - ); + lastLoadedSignatureByRootRef.current.set(rootKey, `${searchSpaceId}:${rootKey}:${nonce}`); } let cancelled = false; @@ -257,35 +256,37 @@ export function LocalFilesystemBrowser({ return; } - const unsubscribe = electronAPI.onAgentFilesystemTreeDirty((event: { - searchSpaceId: number | null; - reason: "watcher_event" | "safety_poll"; - rootPath: string; - changedPath: string | null; - timestamp: number; - }) => { - if ((event.searchSpaceId ?? null) !== (searchSpaceId ?? null)) { - return; + const unsubscribe = electronAPI.onAgentFilesystemTreeDirty( + (event: { + searchSpaceId: number | null; + reason: "watcher_event" | "safety_poll"; + rootPath: string; + changedPath: string | null; + timestamp: number; + }) => { + if ((event.searchSpaceId ?? null) !== (searchSpaceId ?? null)) { + return; + } + const eventRootKey = normalizeRootPathForLookup(event.rootPath, isWindowsPlatform); + const knownRootKeys = new Set( + rootPaths.map((rootPath) => normalizeRootPathForLookup(rootPath, isWindowsPlatform)) + ); + if (!knownRootKeys.has(eventRootKey)) { + setReloadNonceByRoot((prev) => { + const next = { ...prev }; + for (const rootKey of knownRootKeys) { + next[rootKey] = (prev[rootKey] ?? 0) + 1; + } + return next; + }); + return; + } + setReloadNonceByRoot((prev) => ({ + ...prev, + [eventRootKey]: (prev[eventRootKey] ?? 0) + 1, + })); } - const eventRootKey = normalizeRootPathForLookup(event.rootPath, isWindowsPlatform); - const knownRootKeys = new Set( - rootPaths.map((rootPath) => normalizeRootPathForLookup(rootPath, isWindowsPlatform)) - ); - if (!knownRootKeys.has(eventRootKey)) { - setReloadNonceByRoot((prev) => { - const next = { ...prev }; - for (const rootKey of knownRootKeys) { - next[rootKey] = (prev[rootKey] ?? 0) + 1; - } - return next; - }); - return; - } - setReloadNonceByRoot((prev) => ({ - ...prev, - [eventRootKey]: (prev[eventRootKey] ?? 0) + 1, - })); - }); + ); void electronAPI.startAgentFilesystemTreeWatch({ searchSpaceId, rootPaths, @@ -378,22 +379,25 @@ export function LocalFilesystemBrowser({ }); }, [rootPaths, rootStateMap, searchQuery]); - const toggleFolder = useCallback((folderKey: string) => { - const update = (prev: Set) => { - const next = new Set(prev); - if (next.has(folderKey)) { - next.delete(folderKey); - } else { - next.add(folderKey); + const toggleFolder = useCallback( + (folderKey: string) => { + const update = (prev: Set) => { + const next = new Set(prev); + if (next.has(folderKey)) { + next.delete(folderKey); + } else { + next.add(folderKey); + } + return next; + }; + if (onExpandedFolderKeysChange) { + onExpandedFolderKeysChange(update(effectiveExpandedFolderKeys)); + return; } - return next; - }; - if (onExpandedFolderKeysChange) { - onExpandedFolderKeysChange(update(effectiveExpandedFolderKeys)); - return; - } - setInternalExpandedFolderKeys(update); - }, [effectiveExpandedFolderKeys, onExpandedFolderKeysChange]); + setInternalExpandedFolderKeys(update); + }, + [effectiveExpandedFolderKeys, onExpandedFolderKeysChange] + ); const renderFolder = useCallback( (folder: LocalFolderNode, depth: number, mount: string) => { @@ -436,9 +440,7 @@ export function LocalFilesystemBrowser({ : undefined } className={`flex h-8 w-full items-center gap-1.5 rounded-md px-2 text-left text-sm transition-colors ${ - isOpenable - ? "hover:bg-muted/60" - : "cursor-not-allowed opacity-60" + isOpenable ? "hover:bg-muted/60" : "cursor-not-allowed opacity-60" }`} style={{ paddingInlineStart: `${(depth + 1) * 12 + 22}px` }} title={ @@ -528,7 +530,10 @@ export function LocalFilesystemBrowser({ } if (state.error) { return ( -
+

Failed to load local folder

{state.error}

diff --git a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx index 77668a93d..ac5463873 100644 --- a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx +++ b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx @@ -308,9 +308,7 @@ export function DocumentTabContent({ documentId, searchSpaceId, title }: Documen } }} > - + Download .md diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 3f5a5fa8c..9fe9dd8da 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -8,9 +8,9 @@ import { ChevronLeft, ChevronRight, ChevronUp, - Pencil, ImageIcon, Layers, + Pencil, Plus, ScanEye, Search, @@ -741,9 +741,7 @@ export function ModelSelector({
{!isMobile && ( @@ -769,9 +767,7 @@ export function ModelSelector({
diff --git a/surfsense_web/components/report-panel/report-panel.tsx b/surfsense_web/components/report-panel/report-panel.tsx index ede63d902..621cf13ce 100644 --- a/surfsense_web/components/report-panel/report-panel.tsx +++ b/surfsense_web/components/report-panel/report-panel.tsx @@ -398,7 +398,8 @@ export function ReportPanelContent({ ); - const editingActions = showReportEditingTier && + const editingActions = + showReportEditingTier && !isReadOnly && (isEditing ? ( <> diff --git a/surfsense_web/components/settings/agent-model-manager.tsx b/surfsense_web/components/settings/agent-model-manager.tsx index 988befdd0..a0b700c2d 100644 --- a/surfsense_web/components/settings/agent-model-manager.tsx +++ b/surfsense_web/components/settings/agent-model-manager.tsx @@ -1,15 +1,7 @@ "use client"; import { useAtomValue } from "jotai"; -import { - AlertCircle, - Dot, - FileText, - Info, - Pencil, - RefreshCw, - Trash2, -} from "lucide-react"; +import { AlertCircle, Dot, FileText, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; import { useMemo, useState } from "react"; import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; import { deleteNewLLMConfigMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; diff --git a/surfsense_web/components/settings/roles-manager.tsx b/surfsense_web/components/settings/roles-manager.tsx index e7dadc20f..335cfc8a9 100644 --- a/surfsense_web/components/settings/roles-manager.tsx +++ b/surfsense_web/components/settings/roles-manager.tsx @@ -5,10 +5,8 @@ import { useAtomValue } from "jotai"; import { Bot, ChevronRight, - ScanEye, - Pencil, - FileText, Earth, + FileText, Image, Logs, type LucideIcon, @@ -16,11 +14,13 @@ import { MessageSquare, Mic, MoreHorizontal, - Unplug, + Pencil, + ScanEye, Settings, Shield, SlidersHorizontal, Trash2, + Unplug, Users, Video, } from "lucide-react"; @@ -462,9 +462,19 @@ function RolesContent({ return (
+ {/* biome-ignore lint/a11y/useSemanticElements: row contains nested interactive elements (DropdownMenu); using a )} - {sidebarDocs.length > 0 && ( - - )}
{!hasModelConfigured && (
diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index 00cc2d4ef..3c5a64b0e 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -24,7 +24,6 @@ import type React from "react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { - sidebarMentionEventAtom, sidebarSelectedDocumentsAtom, } from "@/atoms/chat/mentioned-documents.atom"; import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms"; @@ -416,7 +415,6 @@ function AuthenticatedDocumentsSidebarBase({ const { mutateAsync: deleteDocumentMutation } = useAtomValue(deleteDocumentMutationAtom); const [sidebarDocs, setSidebarDocs] = useAtom(sidebarSelectedDocumentsAtom); - const setSidebarMentionEvent = useSetAtom(sidebarMentionEventAtom); const mentionedDocIds = useMemo(() => new Set(sidebarDocs.map((d) => d.id)), [sidebarDocs]); // Folder state @@ -864,17 +862,6 @@ function AuthenticatedDocumentsSidebarBase({ const key = `${doc.document_type}:${doc.id}`; if (isMentioned) { setSidebarDocs((prev) => prev.filter((d) => `${d.document_type}:${d.id}` !== key)); - setSidebarMentionEvent({ - kind: "remove", - docs: [ - { - id: doc.id, - title: doc.title, - document_type: doc.document_type as DocumentTypeEnum, - }, - ], - nonce: Date.now(), - }); } else { setSidebarDocs((prev) => { if (prev.some((d) => `${d.document_type}:${d.id}` === key)) return prev; @@ -883,20 +870,9 @@ function AuthenticatedDocumentsSidebarBase({ { id: doc.id, title: doc.title, document_type: doc.document_type as DocumentTypeEnum }, ]; }); - setSidebarMentionEvent({ - kind: "add", - docs: [ - { - id: doc.id, - title: doc.title, - document_type: doc.document_type as DocumentTypeEnum, - }, - ], - nonce: Date.now(), - }); } }, - [setSidebarDocs, setSidebarMentionEvent] + [setSidebarDocs] ); const handleToggleFolderSelect = useCallback( @@ -918,14 +894,6 @@ function AuthenticatedDocumentsSidebarBase({ if (subtreeDocs.length === 0) return; if (selectAll) { - const existingKeys = new Set(sidebarDocs.map((d) => `${d.document_type}:${d.id}`)); - const docsToAdd = subtreeDocs - .filter((d) => !existingKeys.has(`${d.document_type}:${d.id}`)) - .map((d) => ({ - id: d.id, - title: d.title, - document_type: d.document_type as DocumentTypeEnum, - })); setSidebarDocs((prev) => { const existingDocKeys = new Set(prev.map((d) => `${d.document_type}:${d.id}`)); const newDocs = subtreeDocs @@ -937,35 +905,14 @@ function AuthenticatedDocumentsSidebarBase({ })); return newDocs.length > 0 ? [...prev, ...newDocs] : prev; }); - if (docsToAdd.length > 0) { - setSidebarMentionEvent({ - kind: "add", - docs: docsToAdd, - nonce: Date.now(), - }); - } } else { const keysToRemove = new Set(subtreeDocs.map((d) => `${d.document_type}:${d.id}`)); - const docsToRemove = sidebarDocs - .filter((d) => keysToRemove.has(`${d.document_type}:${d.id}`)) - .map((d) => ({ - id: d.id, - title: d.title, - document_type: d.document_type as DocumentTypeEnum, - })); setSidebarDocs((prev) => prev.filter((d) => !keysToRemove.has(`${d.document_type}:${d.id}`)) ); - if (docsToRemove.length > 0) { - setSidebarMentionEvent({ - kind: "remove", - docs: docsToRemove, - nonce: Date.now(), - }); - } } }, - [treeDocuments, foldersByParent, sidebarDocs, setSidebarDocs, setSidebarMentionEvent] + [treeDocuments, foldersByParent, setSidebarDocs] ); const searchFilteredDocuments = useMemo(() => { @@ -1626,7 +1573,6 @@ function AnonymousDocumentsSidebar({ const [search, setSearch] = useState(""); const [sidebarDocs, setSidebarDocs] = useAtom(sidebarSelectedDocumentsAtom); - const setSidebarMentionEvent = useSetAtom(sidebarMentionEventAtom); const mentionedDocIds = useMemo(() => new Set(sidebarDocs.map((d) => d.id)), [sidebarDocs]); const handleToggleChatMention = useCallback( @@ -1634,17 +1580,6 @@ function AnonymousDocumentsSidebar({ const key = `${doc.document_type}:${doc.id}`; if (isMentioned) { setSidebarDocs((prev) => prev.filter((d) => `${d.document_type}:${d.id}` !== key)); - setSidebarMentionEvent({ - kind: "remove", - docs: [ - { - id: doc.id, - title: doc.title, - document_type: doc.document_type as DocumentTypeEnum, - }, - ], - nonce: Date.now(), - }); } else { setSidebarDocs((prev) => { if (prev.some((d) => `${d.document_type}:${d.id}` === key)) return prev; @@ -1653,20 +1588,9 @@ function AnonymousDocumentsSidebar({ { id: doc.id, title: doc.title, document_type: doc.document_type as DocumentTypeEnum }, ]; }); - setSidebarMentionEvent({ - kind: "add", - docs: [ - { - id: doc.id, - title: doc.title, - document_type: doc.document_type as DocumentTypeEnum, - }, - ], - nonce: Date.now(), - }); } }, - [setSidebarDocs, setSidebarMentionEvent] + [setSidebarDocs] ); const uploadedDoc = anonMode.isAnonymous ? anonMode.uploadedDoc : null; From 294c719965f9e83867ec8831994bb0ae67caac29 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Tue, 28 Apr 2026 18:36:49 +0530 Subject: [PATCH 213/299] feat(mentions): implement user message rendering with mention chips for referenced documents --- .../components/assistant-ui/user-message.tsx | 94 +++++++++++++++---- 1 file changed, 78 insertions(+), 16 deletions(-) diff --git a/surfsense_web/components/assistant-ui/user-message.tsx b/surfsense_web/components/assistant-ui/user-message.tsx index 86863a501..fb7212119 100644 --- a/surfsense_web/components/assistant-ui/user-message.tsx +++ b/surfsense_web/components/assistant-ui/user-message.tsx @@ -1,11 +1,12 @@ import { ActionBarPrimitive, AuiIf, MessagePrimitive, useAuiState } from "@assistant-ui/react"; import { useAtomValue } from "jotai"; -import { CheckIcon, CopyIcon, FileText, Pencil } from "lucide-react"; +import { CheckIcon, CopyIcon, Pencil } from "lucide-react"; import Image from "next/image"; import { type FC, useState } from "react"; import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; +import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; interface AuthorMetadata { displayName: string | null; @@ -48,6 +49,19 @@ const UserAvatar: FC = ({ displayName, avatarUrl }) => { export const UserMessage: FC = () => { const messageId = useAuiState(({ message }) => message?.id); + const messageText = useAuiState(({ message }) => + (message?.content ?? []) + .map((part) => + typeof part === "object" && + part !== null && + "type" in part && + (part as { type?: string }).type === "text" && + "text" in part + ? String((part as { text?: string }).text ?? "") + : "" + ) + .join("") + ); const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined; const metadata = useAuiState(({ message }) => message?.metadata); @@ -63,22 +77,12 @@ export const UserMessage: FC = () => {
- {mentionedDocs && mentionedDocs.length > 0 && ( -
- {mentionedDocs?.map((doc) => ( - - - {doc.title} - - ))} -
- )}
- + {mentionedDocs && mentionedDocs.length > 0 ? ( + + ) : ( + + )}
@@ -95,6 +99,64 @@ export const UserMessage: FC = () => { ); }; +const UserMessageWithMentionChips: FC<{ + text: string; + mentionedDocs: { id: number; title: string; document_type: string }[]; +}> = ({ text, mentionedDocs }) => { + type Segment = + | { type: "text"; value: string; start: number } + | { type: "mention"; doc: { id: number; title: string; document_type: string }; start: number }; + + const tokens = mentionedDocs + .map((doc) => ({ doc, token: `@${doc.title}` })) + .sort((a, b) => b.token.length - a.token.length); + + const segments: Segment[] = []; + let i = 0; + let buffer = ""; + let bufferStart = 0; + while (i < text.length) { + const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i)); + if (tokenMatch) { + if (buffer) { + segments.push({ type: "text", value: buffer, start: bufferStart }); + buffer = ""; + } + segments.push({ type: "mention", doc: tokenMatch.doc, start: i }); + i += tokenMatch.token.length; + bufferStart = i; + continue; + } + if (!buffer) bufferStart = i; + buffer += text[i]; + i += 1; + } + if (buffer) { + segments.push({ type: "text", value: buffer, start: bufferStart }); + } + + return ( + + {segments.map((segment) => + segment.type === "text" ? ( + {segment.value} + ) : ( + + + {getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")} + + {segment.doc.title} + + ) + )} + + ); +}; + const UserActionBar: FC = () => { const isThreadRunning = useAuiState(({ thread }) => thread.isRunning); From 282510f93ce16bf7b76779046960baa44ed7d9fb Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Tue, 28 Apr 2026 18:47:57 +0530 Subject: [PATCH 214/299] feat(mentions): add syncEditorState function to manage editor state and mentioned documents --- .../assistant-ui/inline-mention-editor.tsx | 57 ++++++++++--------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx index 81d6cbd77..e75a840c0 100644 --- a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx +++ b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx @@ -236,6 +236,19 @@ export const InlineMentionEditor = forwardRef) => { + const docs = docsOverride + ? Array.from(docsOverride.values()) + : Array.from(mentionedDocs.values()); + const text = getText(); + const empty = text.length === 0 && docs.length === 0; + setIsEmpty(empty); + onChange?.(text, docs); + }, + [getText, mentionedDocs, onChange] + ); + // Create a chip element for a document const createChipElement = useCallback( (doc: MentionedDocument): HTMLSpanElement => { @@ -275,6 +288,7 @@ export const InlineMentionEditor = forwardRef { const next = new Map(prev); next.delete(docKey); + syncEditorState(next); return next; }); onDocumentRemove?.(doc.id, doc.document_type); @@ -319,7 +333,7 @@ export const InlineMentionEditor = forwardRef new Map(prev).set(docKey, mentionDoc)); + const nextDocs = new Map(mentionedDocs); + nextDocs.set(docKey, mentionDoc); // Find and remove the @query text const selection = window.getSelection(); @@ -436,25 +452,16 @@ export const InlineMentionEditor = forwardRef { - onChange(getText(), getMentionedDocuments()); - }, 0); - } + syncEditorState(nextDocs); }, [ createChipElement, focusAtEnd, - getText, - getMentionedDocuments, isSelectionInsideEditor, - onChange, + mentionedDocs, rememberSelection, restoreRememberedSelection, + syncEditorState, ] ); @@ -462,22 +469,21 @@ export const InlineMentionEditor = forwardRef { if (editorRef.current) { editorRef.current.innerHTML = ""; - setIsEmpty(true); - setMentionedDocs(new Map()); + const emptyDocs = new Map(); + setMentionedDocs(emptyDocs); + syncEditorState(emptyDocs); } - }, []); + }, [syncEditorState]); // Replace editor content with plain text and place cursor at end const setText = useCallback( (text: string) => { if (!editorRef.current) return; editorRef.current.innerText = text; - const empty = text.length === 0; - setIsEmpty(empty); - onChange?.(text, Array.from(mentionedDocs.values())); + syncEditorState(); focusAtEnd(); }, - [focusAtEnd, onChange, mentionedDocs] + [focusAtEnd, syncEditorState] ); const setDocumentChipStatus = useCallback( @@ -538,14 +544,11 @@ export const InlineMentionEditor = forwardRef { const next = new Map(prev); next.delete(chipKey); + syncEditorState(next); return next; }); - - const text = getText(); - const empty = text.length === 0 && mentionedDocs.size <= 1; - setIsEmpty(empty); }, - [getText, mentionedDocs.size] + [syncEditorState] ); // Expose methods via ref @@ -697,6 +700,7 @@ export const InlineMentionEditor = forwardRef { const next = new Map(prev); next.delete(chipKey); + syncEditorState(next); return next; }); // Notify parent that a document was removed @@ -734,6 +738,7 @@ export const InlineMentionEditor = forwardRef { const next = new Map(prev); next.delete(chipKey); + syncEditorState(next); return next; }); // Notify parent that a document was removed @@ -745,7 +750,7 @@ export const InlineMentionEditor = forwardRef Date: Tue, 28 Apr 2026 09:22:19 -0700 Subject: [PATCH 215/299] feat: updated agent harness --- surfsense_backend/.env.example | 39 + .../versions/130_add_agent_action_log.py | 94 ++ .../versions/131_add_document_revisions.py | 119 ++ .../132_add_agent_permission_rules.py | 82 ++ .../app/agents/new_chat/chat_deepagent.py | 459 +++++++- .../app/agents/new_chat/errors.py | 95 ++ .../app/agents/new_chat/feature_flags.py | 188 +++ .../agents/new_chat/middleware/__init__.py | 42 + .../agents/new_chat/middleware/action_log.py | 294 +++++ .../agents/new_chat/middleware/busy_mutex.py | 231 ++++ .../agents/new_chat/middleware/compaction.py | 253 +++++ .../new_chat/middleware/context_editing.py | 349 ++++++ .../new_chat/middleware/dedup_tool_calls.py | 123 +- .../agents/new_chat/middleware/doom_loop.py | 228 ++++ .../new_chat/middleware/knowledge_search.py | 79 +- .../new_chat/middleware/noop_injection.py | 133 +++ .../agents/new_chat/middleware/otel_span.py | 202 ++++ .../agents/new_chat/middleware/permission.py | 344 ++++++ .../agents/new_chat/middleware/retry_after.py | 245 ++++ .../new_chat/middleware/safe_summarization.py | 123 -- .../new_chat/middleware/skills_backends.py | 332 ++++++ .../new_chat/middleware/tool_call_repair.py | 190 ++++ .../app/agents/new_chat/permissions.py | 204 ++++ .../app/agents/new_chat/plugin_loader.py | 157 +++ .../app/agents/new_chat/plugins/__init__.py | 6 + .../new_chat/plugins/year_substituter.py | 87 ++ .../app/agents/new_chat/prompts/__init__.py | 7 + .../agents/new_chat/prompts/base/__init__.py | 1 + .../new_chat/prompts/base/agent_private.md | 7 + .../new_chat/prompts/base/agent_team.md | 9 + .../new_chat/prompts/base/citations_off.md | 16 + .../new_chat/prompts/base/citations_on.md | 90 ++ .../prompts/base/kb_only_policy_private.md | 15 + .../prompts/base/kb_only_policy_team.md | 15 + .../prompts/base/memory_protocol_private.md | 6 + .../prompts/base/memory_protocol_team.md | 6 + .../prompts/base/parameter_resolution.md | 39 + .../prompts/base/tool_routing_private.md | 16 + .../prompts/base/tool_routing_team.md | 16 + .../app/agents/new_chat/prompts/composer.py | 359 ++++++ .../new_chat/prompts/examples/__init__.py | 1 + .../prompts/examples/generate_image.md | 12 + .../prompts/examples/generate_podcast.md | 7 + .../prompts/examples/generate_report.md | 13 + .../prompts/examples/generate_resume.md | 19 + .../examples/generate_video_presentation.md | 7 + .../prompts/examples/scrape_webpage.md | 13 + .../prompts/examples/search_surfsense_docs.md | 9 + .../prompts/examples/update_memory_private.md | 16 + .../prompts/examples/update_memory_team.md | 7 + .../new_chat/prompts/examples/web_search.md | 8 + .../new_chat/prompts/providers/__init__.py | 1 + .../new_chat/prompts/providers/anthropic.md | 5 + .../new_chat/prompts/providers/default.md | 1 + .../new_chat/prompts/providers/google.md | 4 + .../prompts/providers/openai_classic.md | 5 + .../prompts/providers/openai_reasoning.md | 5 + .../new_chat/prompts/routing/__init__.py | 1 + .../agents/new_chat/prompts/routing/jira.md | 1 + .../agents/new_chat/prompts/routing/linear.md | 1 + .../agents/new_chat/prompts/routing/slack.md | 1 + .../agents/new_chat/prompts/tools/__init__.py | 1 + .../new_chat/prompts/tools/_preamble.md | 6 + .../new_chat/prompts/tools/generate_image.md | 11 + .../prompts/tools/generate_podcast.md | 15 + .../new_chat/prompts/tools/generate_report.md | 39 + .../new_chat/prompts/tools/generate_resume.md | 30 + .../tools/generate_video_presentation.md | 9 + .../new_chat/prompts/tools/scrape_webpage.md | 30 + .../prompts/tools/search_surfsense_docs.md | 7 + .../prompts/tools/update_memory_private.md | 31 + .../prompts/tools/update_memory_team.md | 26 + .../new_chat/prompts/tools/web_search.md | 18 + .../app/agents/new_chat/skills/__init__.py | 7 + .../new_chat/skills/builtin/__init__.py | 1 + .../skills/builtin/email-drafting/SKILL.md | 25 + .../skills/builtin/kb-research/SKILL.md | 23 + .../skills/builtin/meeting-prep/SKILL.md | 22 + .../skills/builtin/report-writing/SKILL.md | 23 + .../skills/builtin/slack-summary/SKILL.md | 26 + .../app/agents/new_chat/subagents/__init__.py | 26 + .../app/agents/new_chat/subagents/config.py | 427 +++++++ .../app/agents/new_chat/system_prompt.py | 1003 ++--------------- .../app/agents/new_chat/tools/invalid_tool.py | 52 + .../app/agents/new_chat/tools/registry.py | 50 + surfsense_backend/app/db.py | 196 ++++ .../app/observability/__init__.py | 7 + surfsense_backend/app/observability/otel.py | 319 ++++++ surfsense_backend/app/routes/__init__.py | 10 + .../app/routes/agent_action_log_route.py | 186 +++ .../app/routes/agent_flags_route.py | 71 ++ .../app/routes/agent_permissions_route.py | 280 +++++ .../app/routes/agent_revert_route.py | 122 ++ .../app/services/revert_service.py | 279 +++++ surfsense_backend/app/utils/async_retry.py | 2 +- .../tests/integration/harness/__init__.py | 146 +++ .../harness/test_scripted_harness.py | 53 + .../tests/unit/agents/__init__.py | 1 + .../tests/unit/agents/new_chat/__init__.py | 1 + .../unit/agents/new_chat/prompts/__init__.py | 1 + .../agents/new_chat/prompts/test_composer.py | 201 ++++ .../unit/agents/new_chat/test_action_log.py | 311 +++++ .../unit/agents/new_chat/test_busy_mutex.py | 90 ++ .../unit/agents/new_chat/test_compaction.py | 107 ++ .../agents/new_chat/test_context_editing.py | 107 ++ .../agents/new_chat/test_dedup_tool_calls.py | 132 +++ .../test_default_permissions_layering.py | 128 +++ .../unit/agents/new_chat/test_doom_loop.py | 99 ++ .../agents/new_chat/test_feature_flags.py | 120 ++ .../agents/new_chat/test_noop_injection.py | 119 ++ .../unit/agents/new_chat/test_otel_span.py | 195 ++++ .../new_chat/test_permission_middleware.py | 116 ++ .../unit/agents/new_chat/test_permissions.py | 111 ++ .../agents/new_chat/test_plugin_loader.py | 187 +++ .../unit/agents/new_chat/test_retry_after.py | 107 ++ .../agents/new_chat/test_skills_backends.py | 242 ++++ .../new_chat/test_specialized_subagents.py | 338 ++++++ .../agents/new_chat/test_tool_call_repair.py | 103 ++ .../middleware/test_dedup_hitl_tool_calls.py | 31 +- .../tests/unit/observability/__init__.py | 1 + .../tests/unit/observability/test_otel.py | 84 ++ .../unit/services/test_revert_service.py | 56 + .../components/AgentPermissionsContent.tsx | 451 ++++++++ .../components/AgentStatusContent.tsx | 309 +++++ .../atoms/agent/action-log-sheet.atom.ts | 19 + .../atoms/agent/agent-flags-query.atom.ts | 17 + .../agent-action-log/action-log-button.tsx | 50 + .../agent-action-log/action-log-item.tsx | 215 ++++ .../agent-action-log/action-log-sheet.tsx | 185 +++ .../components/assistant-ui/markdown-text.tsx | 16 +- .../components/assistant-ui/tool-fallback.tsx | 7 + .../layout/providers/LayoutDataProvider.tsx | 4 + .../components/layout/ui/header/Header.tsx | 2 + surfsense_web/components/markdown-viewer.tsx | 6 +- .../settings/user-settings-dialog.tsx | 28 + .../components/tool-ui/doom-loop-approval.tsx | 187 +++ .../lib/apis/agent-actions-api.service.ts | 64 ++ .../lib/apis/agent-flags-api.service.ts | 40 + .../lib/apis/agent-permissions-api.service.ts | 90 ++ 139 files changed, 12583 insertions(+), 1111 deletions(-) create mode 100644 surfsense_backend/alembic/versions/130_add_agent_action_log.py create mode 100644 surfsense_backend/alembic/versions/131_add_document_revisions.py create mode 100644 surfsense_backend/alembic/versions/132_add_agent_permission_rules.py create mode 100644 surfsense_backend/app/agents/new_chat/errors.py create mode 100644 surfsense_backend/app/agents/new_chat/feature_flags.py create mode 100644 surfsense_backend/app/agents/new_chat/middleware/action_log.py create mode 100644 surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py create mode 100644 surfsense_backend/app/agents/new_chat/middleware/compaction.py create mode 100644 surfsense_backend/app/agents/new_chat/middleware/context_editing.py create mode 100644 surfsense_backend/app/agents/new_chat/middleware/doom_loop.py create mode 100644 surfsense_backend/app/agents/new_chat/middleware/noop_injection.py create mode 100644 surfsense_backend/app/agents/new_chat/middleware/otel_span.py create mode 100644 surfsense_backend/app/agents/new_chat/middleware/permission.py create mode 100644 surfsense_backend/app/agents/new_chat/middleware/retry_after.py delete mode 100644 surfsense_backend/app/agents/new_chat/middleware/safe_summarization.py create mode 100644 surfsense_backend/app/agents/new_chat/middleware/skills_backends.py create mode 100644 surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py create mode 100644 surfsense_backend/app/agents/new_chat/permissions.py create mode 100644 surfsense_backend/app/agents/new_chat/plugin_loader.py create mode 100644 surfsense_backend/app/agents/new_chat/plugins/__init__.py create mode 100644 surfsense_backend/app/agents/new_chat/plugins/year_substituter.py create mode 100644 surfsense_backend/app/agents/new_chat/prompts/__init__.py create mode 100644 surfsense_backend/app/agents/new_chat/prompts/base/__init__.py create mode 100644 surfsense_backend/app/agents/new_chat/prompts/base/agent_private.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/base/agent_team.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/base/citations_off.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/base/citations_on.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_private.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_team.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_private.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_team.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/base/parameter_resolution.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_private.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_team.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/composer.py create mode 100644 surfsense_backend/app/agents/new_chat/prompts/examples/__init__.py create mode 100644 surfsense_backend/app/agents/new_chat/prompts/examples/generate_image.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/examples/generate_podcast.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/examples/generate_report.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/examples/generate_resume.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/examples/generate_video_presentation.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/examples/scrape_webpage.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/examples/search_surfsense_docs.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_private.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_team.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/examples/web_search.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/providers/__init__.py create mode 100644 surfsense_backend/app/agents/new_chat/prompts/providers/anthropic.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/providers/default.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/providers/google.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/providers/openai_classic.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/providers/openai_reasoning.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/routing/__init__.py create mode 100644 surfsense_backend/app/agents/new_chat/prompts/routing/jira.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/routing/linear.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/routing/slack.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/tools/__init__.py create mode 100644 surfsense_backend/app/agents/new_chat/prompts/tools/_preamble.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/tools/generate_image.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/tools/generate_podcast.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/tools/generate_report.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/tools/generate_resume.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/tools/generate_video_presentation.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/tools/scrape_webpage.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/tools/search_surfsense_docs.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_private.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_team.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/tools/web_search.md create mode 100644 surfsense_backend/app/agents/new_chat/skills/__init__.py create mode 100644 surfsense_backend/app/agents/new_chat/skills/builtin/__init__.py create mode 100644 surfsense_backend/app/agents/new_chat/skills/builtin/email-drafting/SKILL.md create mode 100644 surfsense_backend/app/agents/new_chat/skills/builtin/kb-research/SKILL.md create mode 100644 surfsense_backend/app/agents/new_chat/skills/builtin/meeting-prep/SKILL.md create mode 100644 surfsense_backend/app/agents/new_chat/skills/builtin/report-writing/SKILL.md create mode 100644 surfsense_backend/app/agents/new_chat/skills/builtin/slack-summary/SKILL.md create mode 100644 surfsense_backend/app/agents/new_chat/subagents/__init__.py create mode 100644 surfsense_backend/app/agents/new_chat/subagents/config.py create mode 100644 surfsense_backend/app/agents/new_chat/tools/invalid_tool.py create mode 100644 surfsense_backend/app/observability/__init__.py create mode 100644 surfsense_backend/app/observability/otel.py create mode 100644 surfsense_backend/app/routes/agent_action_log_route.py create mode 100644 surfsense_backend/app/routes/agent_flags_route.py create mode 100644 surfsense_backend/app/routes/agent_permissions_route.py create mode 100644 surfsense_backend/app/routes/agent_revert_route.py create mode 100644 surfsense_backend/app/services/revert_service.py create mode 100644 surfsense_backend/tests/integration/harness/__init__.py create mode 100644 surfsense_backend/tests/integration/harness/test_scripted_harness.py create mode 100644 surfsense_backend/tests/unit/agents/__init__.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/__init__.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/prompts/__init__.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_action_log.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_compaction.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_permissions.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_skills_backends.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py create mode 100644 surfsense_backend/tests/unit/observability/__init__.py create mode 100644 surfsense_backend/tests/unit/observability/test_otel.py create mode 100644 surfsense_backend/tests/unit/services/test_revert_service.py create mode 100644 surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentPermissionsContent.tsx create mode 100644 surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx create mode 100644 surfsense_web/atoms/agent/action-log-sheet.atom.ts create mode 100644 surfsense_web/atoms/agent/agent-flags-query.atom.ts create mode 100644 surfsense_web/components/agent-action-log/action-log-button.tsx create mode 100644 surfsense_web/components/agent-action-log/action-log-item.tsx create mode 100644 surfsense_web/components/agent-action-log/action-log-sheet.tsx create mode 100644 surfsense_web/components/tool-ui/doom-loop-approval.tsx create mode 100644 surfsense_web/lib/apis/agent-actions-api.service.ts create mode 100644 surfsense_web/lib/apis/agent-flags-api.service.ts create mode 100644 surfsense_web/lib/apis/agent-permissions-api.service.ts diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index 86bac0aaf..e133a2bc5 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -247,3 +247,42 @@ LANGSMITH_TRACING=true LANGSMITH_ENDPOINT=https://api.smith.langchain.com LANGSMITH_API_KEY=lsv2_pt_..... LANGSMITH_PROJECT=surfsense + + +# ============================================================================= +# OPTIONAL: New-chat agent feature flags (OpenCode-port) +# ============================================================================= +# Master kill-switch — when true, every flag below is forced OFF. +# SURFSENSE_DISABLE_NEW_AGENT_STACK=false + +# Tier 1 — Agent quality +# SURFSENSE_ENABLE_CONTEXT_EDITING=false +# SURFSENSE_ENABLE_COMPACTION_V2=false +# SURFSENSE_ENABLE_RETRY_AFTER=false +# SURFSENSE_ENABLE_MODEL_FALLBACK=false +# SURFSENSE_ENABLE_MODEL_CALL_LIMIT=false +# SURFSENSE_ENABLE_TOOL_CALL_LIMIT=false +# SURFSENSE_ENABLE_TOOL_CALL_REPAIR=false +# SURFSENSE_ENABLE_DOOM_LOOP=false # leave OFF until UI handles permission='doom_loop' + +# Tier 2 — Safety +# SURFSENSE_ENABLE_PERMISSION=false +# SURFSENSE_ENABLE_BUSY_MUTEX=false +# SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call + +# Tier 3b — Observability (also requires OTEL_EXPORTER_OTLP_ENDPOINT) +# SURFSENSE_ENABLE_OTEL=false + +# Tier 4 — Skills + subagents +# SURFSENSE_ENABLE_SKILLS=false +# SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=false +# SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=false + +# Tier 5 — Snapshot / revert +# SURFSENSE_ENABLE_ACTION_LOG=false +# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships + +# Tier 6 — Plugins +# SURFSENSE_ENABLE_PLUGIN_LOADER=false +# Comma-separated allowlist of plugin entry-point names +# SURFSENSE_ALLOWED_PLUGINS=year_substituter diff --git a/surfsense_backend/alembic/versions/130_add_agent_action_log.py b/surfsense_backend/alembic/versions/130_add_agent_action_log.py new file mode 100644 index 000000000..5793988cb --- /dev/null +++ b/surfsense_backend/alembic/versions/130_add_agent_action_log.py @@ -0,0 +1,94 @@ +"""130_add_agent_action_log + +Revision ID: 130 +Revises: 129 +Create Date: 2026-04-28 + +Tier 5.2 in the OpenCode-port plan. Adds the append-only ``agent_action_log`` +table that :class:`ActionLogMiddleware` writes to after every tool call. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +revision: str = "130" +down_revision: str | None = "129" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "agent_action_log", + sa.Column("id", sa.Integer(), primary_key=True, index=True), + sa.Column( + "thread_id", + sa.Integer(), + sa.ForeignKey("new_chat_threads.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("user.id", ondelete="SET NULL"), + nullable=True, + index=True, + ), + sa.Column( + "search_space_id", + sa.Integer(), + sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column("turn_id", sa.String(length=64), nullable=True, index=True), + sa.Column("message_id", sa.String(length=128), nullable=True, index=True), + sa.Column("tool_name", sa.String(length=255), nullable=False, index=True), + sa.Column("args", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("result_id", sa.String(length=255), nullable=True), + sa.Column( + "reversible", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + sa.Column( + "reverse_descriptor", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + sa.Column("error", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column( + "reverse_of", + sa.Integer(), + sa.ForeignKey("agent_action_log.id", ondelete="SET NULL"), + nullable=True, + index=True, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.text("(now() AT TIME ZONE 'utc')"), + index=True, + ), + ) + op.create_index( + "ix_agent_action_log_thread_created", + "agent_action_log", + ["thread_id", "created_at"], + ) + + +def downgrade() -> None: + op.drop_index( + "ix_agent_action_log_thread_created", table_name="agent_action_log" + ) + op.drop_table("agent_action_log") diff --git a/surfsense_backend/alembic/versions/131_add_document_revisions.py b/surfsense_backend/alembic/versions/131_add_document_revisions.py new file mode 100644 index 000000000..46c6991b6 --- /dev/null +++ b/surfsense_backend/alembic/versions/131_add_document_revisions.py @@ -0,0 +1,119 @@ +"""131_add_document_revisions + +Revision ID: 131 +Revises: 130 +Create Date: 2026-04-28 + +Tier 5.1 in the OpenCode-port plan. Adds two snapshot tables: + +* ``document_revisions``: pre-mutation snapshot of NOTE/FILE/EXTENSION docs. +* ``folder_revisions``: pre-mutation snapshot of folder mkdir/move/delete. + +Both are written by :class:`KnowledgeBasePersistenceMiddleware` ahead of +state-changing tool calls and consumed by ``revert_service.revert_action``. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +revision: str = "131" +down_revision: str | None = "130" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "document_revisions", + sa.Column("id", sa.Integer(), primary_key=True, index=True), + sa.Column( + "document_id", + sa.Integer(), + sa.ForeignKey("documents.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column( + "search_space_id", + sa.Integer(), + sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column("content_before", sa.Text(), nullable=True), + sa.Column("title_before", sa.String(), nullable=True), + sa.Column("folder_id_before", sa.Integer(), nullable=True), + sa.Column( + "chunks_before", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.Column( + "metadata_before", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.Column( + "created_by_turn_id", sa.String(length=64), nullable=True, index=True + ), + sa.Column( + "agent_action_id", + sa.Integer(), + sa.ForeignKey("agent_action_log.id", ondelete="SET NULL"), + nullable=True, + index=True, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.text("(now() AT TIME ZONE 'utc')"), + index=True, + ), + ) + + op.create_table( + "folder_revisions", + sa.Column("id", sa.Integer(), primary_key=True, index=True), + sa.Column( + "folder_id", + sa.Integer(), + sa.ForeignKey("folders.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column( + "search_space_id", + sa.Integer(), + sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column("name_before", sa.String(length=255), nullable=True), + sa.Column("parent_id_before", sa.Integer(), nullable=True), + sa.Column("position_before", sa.String(length=50), nullable=True), + sa.Column( + "created_by_turn_id", sa.String(length=64), nullable=True, index=True + ), + sa.Column( + "agent_action_id", + sa.Integer(), + sa.ForeignKey("agent_action_log.id", ondelete="SET NULL"), + nullable=True, + index=True, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.text("(now() AT TIME ZONE 'utc')"), + index=True, + ), + ) + + +def downgrade() -> None: + op.drop_table("folder_revisions") + op.drop_table("document_revisions") diff --git a/surfsense_backend/alembic/versions/132_add_agent_permission_rules.py b/surfsense_backend/alembic/versions/132_add_agent_permission_rules.py new file mode 100644 index 000000000..0e81eacb5 --- /dev/null +++ b/surfsense_backend/alembic/versions/132_add_agent_permission_rules.py @@ -0,0 +1,82 @@ +"""132_add_agent_permission_rules + +Revision ID: 132 +Revises: 131 +Create Date: 2026-04-28 + +Tier 2.1 in the OpenCode-port plan. Adds the persistent ``agent_permission_rules`` +table consumed by :class:`PermissionMiddleware` at agent build time. Rules +can be scoped at search-space (``user_id`` / ``thread_id`` NULL), +user-wide (``user_id`` set, ``thread_id`` NULL), or per-thread +(``thread_id`` set). +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +revision: str = "132" +down_revision: str | None = "131" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "agent_permission_rules", + sa.Column("id", sa.Integer(), primary_key=True, index=True), + sa.Column( + "search_space_id", + sa.Integer(), + sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("user.id", ondelete="CASCADE"), + nullable=True, + index=True, + ), + sa.Column( + "thread_id", + sa.Integer(), + sa.ForeignKey("new_chat_threads.id", ondelete="CASCADE"), + nullable=True, + index=True, + ), + sa.Column("permission", sa.String(length=255), nullable=False), + sa.Column( + "pattern", + sa.String(length=255), + nullable=False, + server_default="*", + ), + sa.Column("action", sa.String(length=16), nullable=False), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.text("(now() AT TIME ZONE 'utc')"), + index=True, + ), + sa.UniqueConstraint( + "search_space_id", + "user_id", + "thread_id", + "permission", + "pattern", + "action", + name="uq_agent_permission_rules_scope", + ), + ) + + +def downgrade() -> None: + op.drop_table("agent_permission_rules") diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 61de1fffa..672570696 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -23,9 +23,16 @@ from deepagents import SubAgent, SubAgentMiddleware, __version__ as deepagents_v from deepagents.backends import StateBackend from deepagents.graph import BASE_AGENT_PROMPT from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware +from deepagents.middleware.skills import SkillsMiddleware from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT from langchain.agents import create_agent -from langchain.agents.middleware import TodoListMiddleware +from langchain.agents.middleware import ( + LLMToolSelectorMiddleware, + ModelCallLimitMiddleware, + ModelFallbackMiddleware, + TodoListMiddleware, + ToolCallLimitMiddleware, +) from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool @@ -33,27 +40,51 @@ from langgraph.types import Checkpointer from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.context import SurfSenseContextSchema +from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags from app.agents.new_chat.filesystem_backends import build_backend_resolver from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import AgentConfig from app.agents.new_chat.middleware import ( + ActionLogMiddleware, AnonymousDocumentMiddleware, + BusyMutexMiddleware, + ClearToolUsesEdit, DedupHITLToolCallsMiddleware, + DoomLoopMiddleware, FileIntentMiddleware, KnowledgeBasePersistenceMiddleware, KnowledgePriorityMiddleware, KnowledgeTreeMiddleware, MemoryInjectionMiddleware, + NoopInjectionMiddleware, + OtelSpanMiddleware, + PermissionMiddleware, + RetryAfterMiddleware, + SpillingContextEditingMiddleware, + SpillToBackendEdit, SurfSenseFilesystemMiddleware, + ToolCallNameRepairMiddleware, + build_skills_backend_factory, + create_surfsense_compaction_middleware, + default_skills_sources, ) -from app.agents.new_chat.middleware.safe_summarization import ( - create_safe_summarization_middleware, +from app.agents.new_chat.permissions import Rule, Ruleset +from app.agents.new_chat.plugin_loader import ( + PluginContext, + load_allowed_plugin_names_from_env, + load_plugin_middlewares, ) +from app.agents.new_chat.subagents import build_specialized_subagents from app.agents.new_chat.system_prompt import ( build_configurable_system_prompt, build_surfsense_system_prompt, ) +from app.agents.new_chat.tools.invalid_tool import ( + INVALID_TOOL_NAME, + invalid_tool, +) from app.agents.new_chat.tools.registry import ( + BUILTIN_TOOLS, build_tools_async, get_connector_gated_tools, ) @@ -321,6 +352,17 @@ async def create_surfsense_deep_agent( disabled_tools=modified_disabled_tools, additional_tools=list(additional_tools) if additional_tools else None, ) + + # Tier 1.6: register `invalid` tool. It is dispatched only when + # ToolCallNameRepairMiddleware rewrites a malformed call. We + # intentionally append it AFTER ``build_tools_async`` so it never + # appears in the system-prompt tool list (which is built from the + # registry, not the bound tool list). + _flags: AgentFeatureFlags = get_flags() + if _flags.enable_tool_call_repair and INVALID_TOOL_NAME not in { + t.name for t in tools + }: + tools = [*list(tools), invalid_tool] _perf_log.info( "[create_agent] build_tools_async in %.3fs (%d tools)", time.perf_counter() - _t0, @@ -397,6 +439,8 @@ async def create_surfsense_deep_agent( available_connectors=available_connectors, available_document_types=available_document_types, mentioned_document_ids=mentioned_document_ids, + max_input_tokens=_max_input_tokens, + flags=_flags, checkpointer=checkpointer, ) _perf_log.info( @@ -411,6 +455,71 @@ async def create_surfsense_deep_agent( return agent +# Tier 1.1: tools whose output is too costly / lossy to discard. Keep +# this conservative — anything listed here is *never* pruned by +# ContextEditingMiddleware. The list is filtered against actually-bound +# tool names so disabled connectors don't show up here. +_PRUNE_PROTECTED_TOOL_NAMES: frozenset[str] = frozenset( + { + "generate_report", + "generate_resume", + "generate_podcast", + "generate_video_presentation", + "generate_image", + # Read-heavy connector reads — recomputing them is expensive + "read_email", + "search_emails", + # The fallback for malformed tool calls — keep its replies visible + "invalid", + } +) + + +def _safe_exclude_tools(tools: Sequence[BaseTool]) -> tuple[str, ...]: + """Return ``exclude_tools`` derived from the actually-bound tool list. + + Filters :data:`_PRUNE_PROTECTED_TOOL_NAMES` against the bound tools + so we never list tools that don't exist (would be a silent no-op). + """ + enabled = {t.name for t in tools} + return tuple(name for name in _PRUNE_PROTECTED_TOOL_NAMES if name in enabled) + + +# Tier 2.1 / cleanup: opencode `Permission.disabled` parity. Replaces the +# legacy binary ``_CONNECTOR_TYPE_TO_SEARCHABLE``-based gating with a +# declarative pass over :data:`BUILTIN_TOOLS`. Each tool that declares a +# ``required_connector`` not present in ``available_connectors`` gets a +# deny rule so any execution attempt short-circuits with permission_denied. +def _synthesize_connector_deny_rules( + *, + available_connectors: list[str] | None, + enabled_tool_names: set[str], +) -> list[Rule]: + """Build deny rules for tools whose required connector is not enabled. + + Source of truth is ``ToolDefinition.required_connector`` in + :data:`BUILTIN_TOOLS`. A tool only gets a deny rule when: + + 1. It is currently bound (``enabled_tool_names``). + 2. It declares a ``required_connector``. + 3. That connector is *not* in ``available_connectors``. + + This expresses the OpenCode ``Permission.disabled`` semantics + declaratively, replacing the substring-heuristic binary gating + that used to consult the hardcoded ``_CONNECTOR_TYPE_TO_SEARCHABLE`` + map. + """ + available = set(available_connectors or []) + deny: list[Rule] = [] + for tool_def in BUILTIN_TOOLS: + if tool_def.name not in enabled_tool_names: + continue + rc = tool_def.required_connector + if rc and rc not in available: + deny.append(Rule(permission=tool_def.name, pattern="*", action="deny")) + return deny + + def _build_compiled_agent_blocking( *, llm: BaseChatModel, @@ -426,6 +535,8 @@ def _build_compiled_agent_blocking( available_connectors: list[str] | None, available_document_types: list[str] | None, mentioned_document_ids: list[int] | None, + max_input_tokens: int | None, + flags: AgentFeatureFlags, checkpointer: Checkpointer, ): """Build the middleware stack and compile the agent graph synchronously. @@ -458,7 +569,7 @@ def _build_compiled_agent_blocking( created_by_id=user_id, thread_id=thread_id, ), - create_safe_summarization_middleware(llm, StateBackend), + create_surfsense_compaction_middleware(llm, StateBackend), PatchToolCallsMiddleware(), AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), ] @@ -470,13 +581,319 @@ def _build_compiled_agent_blocking( "middleware": gp_middleware, } + # Tier 4.3: specialized user-facing subagents (explore, report_writer, + # connector_negotiator). Registered through SubAgentMiddleware alongside + # the general-purpose spec so the parent's `task` tool can address them + # by name. Off by default until the flag flips so existing deployments + # don't see new agent types in the task tool description. + specialized_subagents: list[SubAgent] = [] + if ( + flags.enable_specialized_subagents + and not flags.disable_new_agent_stack + ): + try: + # Specialized subagents share the parent's filesystem + + # todo view so their system prompts (which promise + # ``read_file``, ``ls``, ``grep``, ``glob``, ``write_todos``) + # actually match runtime behavior. Build *fresh* instances + # rather than aliasing the parent's GP middleware to avoid + # subtle state coupling across compiled graphs. + subagent_extra_middleware: list = [ + TodoListMiddleware(), + SurfSenseFilesystemMiddleware( + backend=backend_resolver, + filesystem_mode=filesystem_mode, + search_space_id=search_space_id, + created_by_id=user_id, + thread_id=thread_id, + ), + ] + specialized_subagents = build_specialized_subagents( + tools=tools, + model=llm, + extra_middleware=subagent_extra_middleware, + ) + except Exception as exc: # pragma: no cover - defensive + logging.warning( + "Specialized subagent build failed; running without them: %s", + exc, + ) + specialized_subagents = [] + + subagent_specs: list[SubAgent] = [general_purpose_spec, *specialized_subagents] + # Main agent middleware # Order: AnonDoc -> Tree -> Priority -> FileIntent -> Filesystem -> Persistence -> ... # before_agent hooks run in declared order; later injections sit closer to # the latest human turn. Tree (large + cacheable) is injected earliest so # provider-side prefix caching has more material to hit; FileIntent (most # actionable per-turn contract) is injected closest to the user message. + # + # ``wrap_model_call`` ordering: the FIRST middleware in the list is the + # OUTERMOST wrapper. To ensure prune executes before summarization, + # place ``SpillingContextEditingMiddleware`` before + # ``SurfSenseCompactionMiddleware`` (Tier 1.1 + 1.3). + # Compaction is the canonical token-budget defense after the + # cleanup tier removed ``SafeSummarizationMiddleware``. The Bedrock + # buffer-empty defense is folded into ``SurfSenseCompactionMiddleware``. + summarization_mw = create_surfsense_compaction_middleware(llm, StateBackend) + _ = flags.enable_compaction_v2 # historical flag; retained for telemetry parity + + # Tier 1.1: ContextEditing prune. Trigger at 55% of model_max_input, + # earlier than summarization (~85%). When disabled, no edit runs. + context_edit_mw = None + if ( + flags.enable_context_editing + and not flags.disable_new_agent_stack + and max_input_tokens + ): + spill_edit = SpillToBackendEdit( + trigger=int(max_input_tokens * 0.55), + clear_at_least=int(max_input_tokens * 0.15), + keep=5, + exclude_tools=_safe_exclude_tools(tools), + clear_tool_inputs=True, + ) + clear_edit = ClearToolUsesEdit( + trigger=int(max_input_tokens * 0.55), + clear_at_least=int(max_input_tokens * 0.15), + keep=5, + exclude_tools=_safe_exclude_tools(tools), + clear_tool_inputs=True, + placeholder="[cleared - older tool output trimmed for context]", + ) + context_edit_mw = SpillingContextEditingMiddleware( + edits=[spill_edit, clear_edit], + backend_resolver=backend_resolver, + ) + + # Tier 1.4 / 1.8 / 1.9 / 1.10: built-in retry/fallback/limits. + retry_mw = ( + RetryAfterMiddleware(max_retries=3) + if flags.enable_retry_after and not flags.disable_new_agent_stack + else None + ) + # Fallback chain — primary is the agent's own model; we add cheap + # alternatives. Off by default; only the first call site that + # configures the chain via env should enable it. + fallback_mw: ModelFallbackMiddleware | None = None + if flags.enable_model_fallback and not flags.disable_new_agent_stack: + try: + fallback_mw = ModelFallbackMiddleware( + "openai:gpt-4o-mini", + "anthropic:claude-3-5-haiku-20241022", + ) + except Exception: + logging.warning("ModelFallbackMiddleware init failed; skipping.") + fallback_mw = None + model_call_limit_mw = ( + ModelCallLimitMiddleware( + thread_limit=120, + run_limit=80, + exit_behavior="end", + ) + if flags.enable_model_call_limit and not flags.disable_new_agent_stack + else None + ) + tool_call_limit_mw = ( + ToolCallLimitMiddleware(thread_limit=300, run_limit=80, exit_behavior="continue") + if flags.enable_tool_call_limit and not flags.disable_new_agent_stack + else None + ) + + # Tier 1.5: provider-compat _noop injection. + noop_mw = ( + NoopInjectionMiddleware() + if flags.enable_compaction_v2 and not flags.disable_new_agent_stack + else None + ) + + # Tier 1.7: tool-call name repair (lowercase + invalid fallback). + # + # ``registered_tool_names`` MUST cover every tool the model can legitimately + # call. That includes the bound ``tools`` list AND every tool provided by + # middleware in the stack — ``FilesystemMiddleware`` (read_file, ls, grep, + # glob, edit_file, write_file, execute), ``TodoListMiddleware`` + # (write_todos), ``SubAgentMiddleware`` (task), ``SkillsMiddleware`` (skill + # loaders), etc. If we only inspect ``tools`` here, every call to + # ``read_file`` / ``ls`` / ``grep`` from the model will be rewritten to + # ``invalid`` because the repair middleware doesn't recognize them. The + # built-in deepagents middleware aren't in scope yet at this point of the + # function but they're added unconditionally below, so we hard-code their + # canonical names alongside the dynamic ``tools`` set. + repair_mw = None + if flags.enable_tool_call_repair and not flags.disable_new_agent_stack: + registered_names: set[str] = {t.name for t in tools} + # Tools owned by the standard deepagents middleware stack. + registered_names |= { + "write_todos", + "ls", + "read_file", + "write_file", + "edit_file", + "glob", + "grep", + "execute", + "task", + } + repair_mw = ToolCallNameRepairMiddleware( + registered_tool_names=registered_names, + fuzzy_match_threshold=None, # opencode parity: no fuzzy step + ) + + # Tier 1.11: doom-loop detector. Off by default until UI handles. + doom_loop_mw = ( + DoomLoopMiddleware(threshold=3) + if flags.enable_doom_loop and not flags.disable_new_agent_stack + else None + ) + + # Tier 2.1: PermissionMiddleware. Layers, earliest -> latest (last + # match wins per opencode): + # + # 1. ``surfsense_defaults`` — single ``allow */*`` rule. SurfSense + # already runs per-tool HITL (see ``tools/hitl.py``) for mutating + # connector tools, so we only want PermissionMiddleware to *deny* + # things the user has gated off; the default fallback in + # ``permissions.evaluate`` is ``ask``, which would double-prompt + # on every safe read-only call (``ls``, ``read_file``, ``grep``, + # ``glob``, ``web_search`` …) and, on resume, replay the previous + # reject decision into innocent calls. + # 2. ``connector_synthesized`` — deny rules for tools whose required + # connector is not connected to this space. Overrides #1. + # 3. (future) user-defined rules from ``agent_permission_rules`` table + # via the Agent Permissions UI. Loaded last so they override both. + permission_mw: PermissionMiddleware | None = None + if flags.enable_permission and not flags.disable_new_agent_stack: + synthesized = _synthesize_connector_deny_rules( + available_connectors=available_connectors, + enabled_tool_names={t.name for t in tools}, + ) + permission_mw = PermissionMiddleware( + rulesets=[ + Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", + ), + Ruleset(rules=synthesized, origin="connector_synthesized"), + ], + ) + + # Tier 5.2: ActionLogMiddleware. Off by default until the + # ``agent_action_log`` table is migrated. When enabled, persists one + # row per tool call with optional reverse_descriptor for + # /api/threads/{thread_id}/revert/{action_id}. Sits inside permission + # so denied calls aren't logged as completions. + action_log_mw: ActionLogMiddleware | None = None + if ( + flags.enable_action_log + and not flags.disable_new_agent_stack + and thread_id is not None + ): + try: + tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS} + action_log_mw = ActionLogMiddleware( + thread_id=thread_id, + search_space_id=search_space_id, + user_id=user_id, + tool_definitions=tool_defs_by_name, + ) + except Exception: # pragma: no cover - defensive + logging.warning( + "ActionLogMiddleware init failed; running without it.", + exc_info=True, + ) + action_log_mw = None + + # Tier 2.2: per-thread busy mutex. + busy_mutex_mw: BusyMutexMiddleware | None = ( + BusyMutexMiddleware() + if flags.enable_busy_mutex and not flags.disable_new_agent_stack + else None + ) + + # Tier 3b: OpenTelemetry spans (model.call + tool.call). Lives just + # inside BusyMutex so it spans every retry/fallback attempt of the + # current turn but never wraps a queued/blocked turn. + otel_mw: OtelSpanMiddleware | None = ( + OtelSpanMiddleware() + if flags.enable_otel and not flags.disable_new_agent_stack + else None + ) + + # Tier 6: plugin entry-point loader. Off by default; opt-in via the + # ``SURFSENSE_ENABLE_PLUGIN_LOADER`` flag. The allowlist is read from + # the ``SURFSENSE_ALLOWED_PLUGINS`` env var (comma-separated). A future + # PR can wire it through ``global_llm_config.yaml``. + plugin_middlewares: list[Any] = [] + if flags.enable_plugin_loader and not flags.disable_new_agent_stack: + try: + allowed_names = load_allowed_plugin_names_from_env() + if allowed_names: + plugin_middlewares = load_plugin_middlewares( + PluginContext.build( + search_space_id=search_space_id, + user_id=user_id, + thread_visibility=visibility, + llm=llm, + ), + allowed_plugin_names=allowed_names, + ) + except Exception: # pragma: no cover - defensive + logging.warning( + "Plugin loader failed; continuing without plugins.", + exc_info=True, + ) + plugin_middlewares = [] + + # Tier 4.1: SkillsMiddleware. Loads built-in + space-authored skills + # via a CompositeBackend. Sources are layered: built-in first, space + # last, so a search-space-authored skill of the same name overrides + # the bundled one. + skills_mw: SkillsMiddleware | None = None + if flags.enable_skills and not flags.disable_new_agent_stack: + try: + skills_factory = build_skills_backend_factory( + search_space_id=search_space_id + if filesystem_mode == FilesystemMode.CLOUD + else None, + ) + skills_mw = SkillsMiddleware( + backend=skills_factory, + sources=default_skills_sources(), + ) + except Exception as exc: # pragma: no cover - defensive + logging.warning("SkillsMiddleware init failed; skipping: %s", exc) + skills_mw = None + + # Tier 2.5: LLM-driven tool selection for >30 tools. + selector_mw: LLMToolSelectorMiddleware | None = None + if ( + flags.enable_llm_tool_selector + and not flags.disable_new_agent_stack + and len(tools) > 30 + ): + try: + selector_mw = LLMToolSelectorMiddleware( + model="openai:gpt-4o-mini", + max_tools=12, + always_include=[ + name + for name in ("update_memory", "get_connected_accounts", "scrape_webpage") + if name in {t.name for t in tools} + ], + ) + except Exception: + logging.warning("LLMToolSelectorMiddleware init failed; skipping.") + selector_mw = None + deepagent_middleware = [ + # BusyMutex is OUTERMOST: it must wrap the entire stream so no + # other turn can sneak in while this one is mid-flight. + busy_mutex_mw, + # OTel spans sit just inside BusyMutex so each retry attempt + # gets its own model.call / tool.call span. + otel_mw, TodoListMiddleware(), _memory_middleware, AnonymousDocumentMiddleware( @@ -514,10 +931,40 @@ def _build_compiled_agent_blocking( ) if filesystem_mode == FilesystemMode.CLOUD else None, - SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]), - create_safe_summarization_middleware(llm, StateBackend), + # Tier 4.1: skill loader. Placed before SubAgentMiddleware so + # subagents inherit the same skill metadata (subagent specs reference + # the same source paths via `default_skills_sources()`). + skills_mw, + SubAgentMiddleware(backend=StateBackend, subagents=subagent_specs), + # Tier 2.5: tool selection (only when >30 tools and flag on). + selector_mw, + # Defensive caps, then prune, then summarize. + model_call_limit_mw, + tool_call_limit_mw, + context_edit_mw, + summarization_mw, + # Provider compatibility + retry chain — placed after prune/compact + # so retries happen on the already-trimmed payload. + noop_mw, + retry_mw, + fallback_mw, + # Tool-call repair must run after model emits but before + # permission / dedup / doom-loop interpret the calls. + repair_mw, + # Tier 2.1: deny/ask BEFORE the calls are forwarded to tool nodes. + permission_mw, + doom_loop_mw, + # Tier 5.2: action log sits inside permission so denied calls + # don't appear as completions, and outside dedup so each unique + # tool invocation gets its own row. + action_log_mw, PatchToolCallsMiddleware(), DedupHITLToolCallsMiddleware(agent_tools=list(tools)), + # Tier 6: plugin slot — sits just before AnthropicCache so plugin-side + # transforms see the final tool result and run before any caching + # heuristics. Multiple plugins in declared order; loader filtered by + # the admin allowlist already. + *plugin_middlewares, AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), ] deepagent_middleware = [m for m in deepagent_middleware if m is not None] diff --git a/surfsense_backend/app/agents/new_chat/errors.py b/surfsense_backend/app/agents/new_chat/errors.py new file mode 100644 index 000000000..b7bac4536 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/errors.py @@ -0,0 +1,95 @@ +""" +Typed error taxonomy for the SurfSense agent stack. + +Used by: +- :class:`RetryAfterMiddleware` (Tier 1.4) — its ``retry_on`` callable + consults the error code to decide whether a retry is appropriate. +- :class:`PermissionMiddleware` (Tier 2.1) — emits + ``code="permission_denied"`` errors when a deny rule trips. +- All tools — return :class:`StreamingError` payloads in + ``ToolMessage.additional_kwargs["error"]`` so the model and the + retry/permission layers share a contract. +""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, Field + +ErrorCode = Literal[ + "rate_limit", + "auth", + "tool_validation", + "tool_runtime", + "context_overflow", + "provider", + "permission_denied", + "doom_loop", + "busy", + "cancelled", +] + + +class StreamingError(BaseModel): + """Structured error payload attached to ``ToolMessage.additional_kwargs["error"]``. + + Tools and middleware emit this so retry, permission, and routing + layers can decide what to do without parsing free-form strings. + """ + + code: ErrorCode + retryable: bool = False + suggestion: str | None = None + correlation_id: str | None = None + detail: str | None = Field( + default=None, + description="Free-form additional context. Not surfaced to the model.", + ) + + class Config: + frozen = True + + +class RejectedError(Exception): + """Raised when the user rejects a permission ask without feedback. + + Caught by :class:`PermissionMiddleware`; the agent stops the current + tool fan-out and surfaces a user-facing rejection. + """ + + def __init__(self, *, tool: str | None = None, pattern: str | None = None) -> None: + super().__init__(f"Permission rejected for tool {tool!r}, pattern {pattern!r}") + self.tool = tool + self.pattern = pattern + + +class CorrectedError(Exception): + """Raised when the user rejects a permission ask *with* feedback. + + The :class:`PermissionMiddleware` translates the feedback into a + synthetic ``ToolMessage`` so the model sees the user's correction + and can retry the request differently. + """ + + def __init__(self, feedback: str, *, tool: str | None = None) -> None: + super().__init__(feedback) + self.feedback = feedback + self.tool = tool + + +class BusyError(Exception): + """Raised when a second prompt arrives while the same thread is mid-stream.""" + + def __init__(self, request_id: str | None = None) -> None: + super().__init__("Thread is busy with another request") + self.request_id = request_id + + +__all__ = [ + "BusyError", + "CorrectedError", + "ErrorCode", + "RejectedError", + "StreamingError", +] diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py new file mode 100644 index 000000000..ce0a3b3fa --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/feature_flags.py @@ -0,0 +1,188 @@ +""" +Feature flags for the SurfSense new_chat agent stack. + +These flags control rollout of OpenCode-pattern middleware ported into +SurfSense. They follow a "default-OFF for risky things, default-ON for +safe upgrades, master kill-switch for everything new" model. + +All new middleware checks its flag at agent build time. If the master +kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new +middleware is disabled regardless of its individual flag. This gives +operators a single switch to revert to pre-port behavior. + +Examples +-------- + +Local development (recommended for trying everything except doom-loop / selector): + + SURFSENSE_ENABLE_CONTEXT_EDITING=true + SURFSENSE_ENABLE_COMPACTION_V2=true + SURFSENSE_ENABLE_RETRY_AFTER=true + SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true + SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy + SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships + SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false + +Master kill-switch (overrides everything else): + + SURFSENSE_DISABLE_NEW_AGENT_STACK=true +""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +def _env_bool(name: str, default: bool) -> bool: + """Parse a boolean env var. Accepts ``1``/``true``/``yes``/``on`` (case-insensitive).""" + raw = os.environ.get(name) + if raw is None: + return default + return raw.strip().lower() in ("1", "true", "yes", "on") + + +@dataclass(frozen=True) +class AgentFeatureFlags: + """Resolved feature-flag state for one agent build. + + Constructed via :meth:`from_env`. The dataclass is frozen so it can be + safely shared across coroutines. + """ + + # Master kill-switch — when true, every flag below resolves to False + # regardless of its env value. Used for rapid rollback. + disable_new_agent_stack: bool = False + + # Tier 1 — Agent quality + enable_context_editing: bool = False + enable_compaction_v2: bool = False + enable_retry_after: bool = False + enable_model_fallback: bool = False + enable_model_call_limit: bool = False + enable_tool_call_limit: bool = False + enable_tool_call_repair: bool = False + enable_doom_loop: bool = False # Default OFF until UI handles permission='doom_loop' + + # Tier 2 — Safety + enable_permission: bool = False # Default OFF for first deploy + enable_busy_mutex: bool = False + enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost + + # Tier 4 — Skills + subagents + enable_skills: bool = False + enable_specialized_subagents: bool = False + enable_kb_planner_runnable: bool = False + + # Tier 5 — Snapshot / revert + enable_action_log: bool = False + enable_revert_route: bool = False # Backend ships before UI; route returns 503 until this flips + + # Tier 6 — Plugins + enable_plugin_loader: bool = False + + # Tier 3b — OTel (orthogonal: also requires OTEL_EXPORTER_OTLP_ENDPOINT) + enable_otel: bool = False + + @classmethod + def from_env(cls) -> AgentFeatureFlags: + """Read flags from environment. + + Master kill-switch is evaluated first; when set, all other flags + force to False. + """ + master_off = _env_bool("SURFSENSE_DISABLE_NEW_AGENT_STACK", False) + if master_off: + logger.info( + "SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent " + "middleware is forced OFF for this build." + ) + return cls(disable_new_agent_stack=True) + + return cls( + disable_new_agent_stack=False, + # Tier 1 + enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False), + enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False), + enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False), + enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False), + enable_model_call_limit=_env_bool("SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False), + enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False), + enable_tool_call_repair=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False), + enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False), + # Tier 2 + enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False), + enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False), + enable_llm_tool_selector=_env_bool("SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False), + # Tier 4 + enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False), + enable_specialized_subagents=_env_bool( + "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False + ), + enable_kb_planner_runnable=_env_bool( + "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False + ), + # Tier 5 + enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False), + enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False), + # Tier 6 + enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False), + # Tier 3b + enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False), + ) + + def any_new_middleware_enabled(self) -> bool: + """Return True if any new middleware flag is on.""" + if self.disable_new_agent_stack: + return False + return any( + ( + self.enable_context_editing, + self.enable_compaction_v2, + self.enable_retry_after, + self.enable_model_fallback, + self.enable_model_call_limit, + self.enable_tool_call_limit, + self.enable_tool_call_repair, + self.enable_doom_loop, + self.enable_permission, + self.enable_busy_mutex, + self.enable_llm_tool_selector, + self.enable_skills, + self.enable_specialized_subagents, + self.enable_kb_planner_runnable, + self.enable_action_log, + self.enable_revert_route, + self.enable_plugin_loader, + ) + ) + + +# Module-level cache. Read once at import time so the values are consistent +# across the process lifetime. Use ``reload_for_tests`` to reset in tests. +_FLAGS: AgentFeatureFlags | None = None + + +def get_flags() -> AgentFeatureFlags: + """Return the resolved feature-flag state, caching on first call.""" + global _FLAGS + if _FLAGS is None: + _FLAGS = AgentFeatureFlags.from_env() + return _FLAGS + + +def reload_for_tests() -> AgentFeatureFlags: + """Force a fresh read from env. Tests should call this after monkeypatching env.""" + global _FLAGS + _FLAGS = AgentFeatureFlags.from_env() + return _FLAGS + + +__all__ = [ + "AgentFeatureFlags", + "get_flags", + "reload_for_tests", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/__init__.py b/surfsense_backend/app/agents/new_chat/middleware/__init__.py index e885d9e6b..094c102f8 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/__init__.py +++ b/surfsense_backend/app/agents/new_chat/middleware/__init__.py @@ -1,11 +1,23 @@ """Middleware components for the SurfSense new chat agent.""" +from app.agents.new_chat.middleware.action_log import ActionLogMiddleware from app.agents.new_chat.middleware.anonymous_document import ( AnonymousDocumentMiddleware, ) +from app.agents.new_chat.middleware.busy_mutex import BusyMutexMiddleware +from app.agents.new_chat.middleware.compaction import ( + SurfSenseCompactionMiddleware, + create_surfsense_compaction_middleware, +) +from app.agents.new_chat.middleware.context_editing import ( + ClearToolUsesEdit, + SpillingContextEditingMiddleware, + SpillToBackendEdit, +) from app.agents.new_chat.middleware.dedup_tool_calls import ( DedupHITLToolCallsMiddleware, ) +from app.agents.new_chat.middleware.doom_loop import DoomLoopMiddleware from app.agents.new_chat.middleware.file_intent import ( FileIntentMiddleware, ) @@ -26,16 +38,46 @@ from app.agents.new_chat.middleware.knowledge_tree import ( from app.agents.new_chat.middleware.memory_injection import ( MemoryInjectionMiddleware, ) +from app.agents.new_chat.middleware.noop_injection import NoopInjectionMiddleware +from app.agents.new_chat.middleware.otel_span import OtelSpanMiddleware +from app.agents.new_chat.middleware.permission import PermissionMiddleware +from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware +from app.agents.new_chat.middleware.skills_backends import ( + BuiltinSkillsBackend, + SearchSpaceSkillsBackend, + build_skills_backend_factory, + default_skills_sources, +) +from app.agents.new_chat.middleware.tool_call_repair import ( + ToolCallNameRepairMiddleware, +) __all__ = [ + "ActionLogMiddleware", "AnonymousDocumentMiddleware", + "BuiltinSkillsBackend", + "BusyMutexMiddleware", + "ClearToolUsesEdit", "DedupHITLToolCallsMiddleware", + "DoomLoopMiddleware", "FileIntentMiddleware", "KnowledgeBasePersistenceMiddleware", "KnowledgeBaseSearchMiddleware", "KnowledgePriorityMiddleware", "KnowledgeTreeMiddleware", "MemoryInjectionMiddleware", + "NoopInjectionMiddleware", + "OtelSpanMiddleware", + "PermissionMiddleware", + "RetryAfterMiddleware", + "SearchSpaceSkillsBackend", + "SpillToBackendEdit", + "SpillingContextEditingMiddleware", + "SurfSenseCompactionMiddleware", "SurfSenseFilesystemMiddleware", + "ToolCallNameRepairMiddleware", + "build_skills_backend_factory", "commit_staged_filesystem_state", + "create_surfsense_compaction_middleware", + "default_skills_sources", ] diff --git a/surfsense_backend/app/agents/new_chat/middleware/action_log.py b/surfsense_backend/app/agents/new_chat/middleware/action_log.py new file mode 100644 index 000000000..cf0b57fd4 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/action_log.py @@ -0,0 +1,294 @@ +"""Append-only action-log middleware for the SurfSense agent. + +Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes +a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt +into reversibility by declaring a ``reverse`` callable on their +:class:`~app.agents.new_chat.tools.registry.ToolDefinition`; the rendered +descriptor is persisted in ``reverse_descriptor`` for use by +``/api/threads/{thread_id}/revert/{action_id}``. + +Design points: + +* **Defensive.** Logging never blocks the agent. We catch every exception + on the DB write path and emit a warning; the tool's ``ToolMessage`` + result is always returned untouched. +* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) + + ``result_id`` + ``reverse_descriptor`` are stored. Tool output text + remains in the LangGraph checkpoint / spilled tool-output files. +* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)`` + with the parsed JSON result when the tool's content is a JSON object; + otherwise the raw text is passed. Exceptions in the reverse callable + are swallowed and logged — a failed descriptor render simply means the + action is NOT marked reversible. +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + +from langchain.agents.middleware import AgentMiddleware +from langchain_core.messages import ToolMessage + +from app.agents.new_chat.feature_flags import get_flags +from app.agents.new_chat.tools.registry import ToolDefinition + +if TYPE_CHECKING: # pragma: no cover - type-only + from langchain.agents.middleware.types import ToolCallRequest + from langgraph.types import Command + + +logger = logging.getLogger(__name__) + + +# Cap for the persisted ``args`` JSON to avoid bloating the action log with +# accidentally-huge inputs. Values are truncated and a flag is set in the +# stored payload so consumers can detect truncation. +_MAX_ARGS_PERSIST_BYTES = 32 * 1024 # 32KB + + +class ActionLogMiddleware(AgentMiddleware): + """Persist a row in :class:`AgentActionLog` after every tool call. + + Should be placed near the OUTERMOST end of the tool-call wrapping stack + so that it sees the *final* :class:`ToolMessage` after all retries, + permission checks, and dedup logic have run. In practice that means + placing it just inside :class:`PermissionMiddleware` and outside + :class:`DedupHITLToolCallsMiddleware`. + + The middleware is fully a no-op when: + + * the master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set + (checked via :func:`get_flags`), + * the per-feature flag ``enable_action_log`` is off, or + * persistence raises (defensive: tool-call dispatch always succeeds). + + Args: + thread_id: The current chat thread's primary-key id. Required to + persist a row; if ``None`` the middleware silently no-ops. + search_space_id: Search-space id for cascade-on-delete safety. + user_id: UUID string of the user driving this turn (nullable in + anonymous mode). + tool_definitions: Optional mapping of tool name -> :class:`ToolDefinition` + so the middleware can look up the tool's ``reverse`` callable. + When omitted, no actions are marked reversible. + """ + + tools = () + + def __init__( + self, + *, + thread_id: int | None, + search_space_id: int, + user_id: str | None, + tool_definitions: dict[str, ToolDefinition] | None = None, + ) -> None: + super().__init__() + self._thread_id = thread_id + self._search_space_id = search_space_id + self._user_id = user_id + self._tool_definitions = dict(tool_definitions or {}) + + def _enabled(self) -> bool: + flags = get_flags() + if flags.disable_new_agent_stack: + return False + return bool(flags.enable_action_log) and self._thread_id is not None + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[ + [ToolCallRequest], Awaitable[ToolMessage | Command[Any]] + ], + ) -> ToolMessage | Command[Any]: + if not self._enabled(): + return await handler(request) + + result: ToolMessage | Command[Any] + error_payload: dict[str, Any] | None = None + try: + result = await handler(request) + except Exception as exc: + # Persist the failure too so revert/audit can see it, then + # re-raise so downstream middleware (RetryAfter, etc.) handles it. + error_payload = {"type": type(exc).__name__, "message": str(exc)} + await self._record( + request=request, + result=None, + error_payload=error_payload, + ) + raise + + await self._record(request=request, result=result, error_payload=None) + return result + + async def _record( + self, + *, + request: ToolCallRequest, + result: ToolMessage | Command[Any] | None, + error_payload: dict[str, Any] | None, + ) -> None: + """Persist one ``agent_action_log`` row. Defensive: never raises.""" + try: + from app.db import AgentActionLog, shielded_async_session + + tool_name = _resolve_tool_name(request) + args_payload = _resolve_args_payload(request) + result_id = _resolve_result_id(result) + reverse_descriptor, reversible = self._render_reverse( + tool_name=tool_name, + args=_resolve_args_dict(request), + result=result, + ) + + row = AgentActionLog( + thread_id=self._thread_id, + user_id=self._user_id, + search_space_id=self._search_space_id, + turn_id=_resolve_turn_id(request), + message_id=_resolve_message_id(request), + tool_name=tool_name, + args=args_payload, + result_id=result_id, + reversible=reversible, + reverse_descriptor=reverse_descriptor, + error=error_payload, + ) + async with shielded_async_session() as session: + session.add(row) + await session.commit() + except Exception: + logger.warning( + "ActionLogMiddleware failed to persist action log row", + exc_info=True, + ) + + def _render_reverse( + self, + *, + tool_name: str, + args: dict[str, Any] | None, + result: ToolMessage | Command[Any] | None, + ) -> tuple[dict[str, Any] | None, bool]: + """Run the tool's ``reverse`` callable and return its descriptor. + + Returns a tuple of ``(descriptor_or_None, reversible_bool)``. When + the tool has no ``reverse`` callable, or when the callable raises, + the action is marked non-reversible. + """ + if not result or not isinstance(result, ToolMessage): + return None, False + if args is None: + return None, False + tool_def = self._tool_definitions.get(tool_name) + if tool_def is None or tool_def.reverse is None: + return None, False + try: + parsed_result = _parse_tool_result_content(result) + descriptor = tool_def.reverse(args, parsed_result) + except Exception: + logger.warning( + "Reverse descriptor render failed for tool %s", + tool_name, + exc_info=True, + ) + return None, False + if not isinstance(descriptor, dict): + return None, False + return descriptor, True + + +# --------------------------------------------------------------------------- +# Resolution helpers — defensive against tool_call request shape variation. +# --------------------------------------------------------------------------- + + +def _resolve_tool_name(request: Any) -> str: + try: + tool = getattr(request, "tool", None) + if tool is not None: + name = getattr(tool, "name", None) + if isinstance(name, str) and name: + return name + call = getattr(request, "tool_call", None) or {} + if isinstance(call, dict): + name = call.get("name") + if isinstance(name, str) and name: + return name + except Exception: # pragma: no cover - defensive + pass + return "unknown" + + +def _resolve_args_dict(request: Any) -> dict[str, Any] | None: + try: + call = getattr(request, "tool_call", None) + if not isinstance(call, dict): + return None + args = call.get("args") + if isinstance(args, dict): + return args + return None + except Exception: # pragma: no cover - defensive + return None + + +def _resolve_args_payload(request: Any) -> dict[str, Any] | None: + """Return a JSON-serializable args dict, truncated if too big.""" + args = _resolve_args_dict(request) + if args is None: + return None + try: + encoded = json.dumps(args, default=str) + except Exception: + return {"_repr": repr(args)[:_MAX_ARGS_PERSIST_BYTES]} + if len(encoded) <= _MAX_ARGS_PERSIST_BYTES: + return args + return { + "_truncated": True, + "_size": len(encoded), + "_preview": encoded[:_MAX_ARGS_PERSIST_BYTES], + } + + +def _resolve_turn_id(request: Any) -> str | None: + try: + call = getattr(request, "tool_call", None) or {} + if isinstance(call, dict): + tid = call.get("id") + if isinstance(tid, str): + return tid + except Exception: # pragma: no cover + pass + return None + + +def _resolve_message_id(request: Any) -> str | None: + """Tool-call IDs serve as best-available message correlator at this layer.""" + return _resolve_turn_id(request) + + +def _resolve_result_id(result: Any) -> str | None: + if isinstance(result, ToolMessage): + msg_id = getattr(result, "id", None) + if isinstance(msg_id, str): + return msg_id + return None + + +def _parse_tool_result_content(result: ToolMessage) -> Any: + content = result.content + if isinstance(content, str): + try: + return json.loads(content) + except (json.JSONDecodeError, ValueError): + return content + return content + + +__all__ = ["ActionLogMiddleware"] diff --git a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py new file mode 100644 index 000000000..1d95638d0 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py @@ -0,0 +1,231 @@ +""" +BusyMutexMiddleware — per-thread asyncio lock + cancel token. + +Tier 2.2 in the OpenCode-port plan. Mirrors opencode's +``Stream.scoped(AbortController)`` pattern (single-process, in-memory +lock + cooperative cancellation). For multi-worker deployments a +distributed lock backend (Redis or PostgreSQL advisory locks) is a +phase-2 follow-up. + +What this provides: +- A ``WeakValueDictionary[str, asyncio.Lock]`` keyed by ``thread_id``; + acquiring the lock during ``before_agent`` blocks any concurrent + prompt on the same thread until release. +- A per-thread ``asyncio.Event`` (``cancel_event``) that long-running + tools can poll to abort cooperatively. The event is reset between + turns. Tools should check ``runtime.context.cancel_event.is_set()`` + in tight inner loops. +- A typed :class:`~app.agents.new_chat.errors.BusyError` raised when a + second turn arrives while the lock is held. + +Note: SurfSense's ``stream_new_chat`` is the call site that should +acquire/release. Wiring this as middleware means the contract is +explicit and the lock manager is shared with subagents that compile +their own ``create_agent`` runnables. +""" + +from __future__ import annotations + +import asyncio +import logging +import weakref +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ResponseT, +) +from langgraph.config import get_config +from langgraph.runtime import Runtime + +from app.agents.new_chat.errors import BusyError + +logger = logging.getLogger(__name__) + + +class _ThreadLockManager: + """Process-local registry of per-thread asyncio locks + cancel events.""" + + def __init__(self) -> None: + self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = ( + weakref.WeakValueDictionary() + ) + self._cancel_events: dict[str, asyncio.Event] = {} + + def lock_for(self, thread_id: str) -> asyncio.Lock: + lock = self._locks.get(thread_id) + if lock is None: + lock = asyncio.Lock() + self._locks[thread_id] = lock + return lock + + def cancel_event(self, thread_id: str) -> asyncio.Event: + event = self._cancel_events.get(thread_id) + if event is None: + event = asyncio.Event() + self._cancel_events[thread_id] = event + return event + + def request_cancel(self, thread_id: str) -> bool: + event = self._cancel_events.get(thread_id) + if event is None: + return False + event.set() + return True + + def reset(self, thread_id: str) -> None: + event = self._cancel_events.get(thread_id) + if event is not None: + event.clear() + + +# Module-level singleton — process-local but reused across all agent +# instances built in this process. Subagents created in nested +# ``create_agent`` calls also get this so locks are coherent. +manager = _ThreadLockManager() + + +def get_cancel_event(thread_id: str) -> asyncio.Event: + """Public accessor used by long-running tools to poll cancellation.""" + return manager.cancel_event(thread_id) + + +def request_cancel(thread_id: str) -> bool: + """Trip the cancel event for ``thread_id``. Returns True if found.""" + return manager.request_cancel(thread_id) + + +def reset_cancel(thread_id: str) -> None: + """Reset the cancel event for ``thread_id`` (called between turns).""" + manager.reset(thread_id) + + +class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): + """Block concurrent prompts on the same thread. + + Acquires the thread's lock in ``abefore_agent`` and releases in + ``aafter_agent``. If the lock is held, raises :class:`BusyError` + so the caller can emit a ``surfsense.busy`` SSE event with the + in-flight request id. + + Args: + require_thread_id: When True, raise :class:`BusyError` if no + ``thread_id`` can be resolved from the active + ``RunnableConfig``. Default is False — we treat a missing + thread_id as "this turn has nothing to lock against" and + no-op the mutex. Set True only when you trust the call + site to always provide ``configurable.thread_id`` (e.g. + in production where ``stream_new_chat`` always does). + """ + + def __init__(self, *, require_thread_id: bool = False) -> None: + super().__init__() + self._require_thread_id = require_thread_id + self.tools = [] + # Per-call locks owned by this middleware. We track them as + # an instance attribute so ``aafter_agent`` knows which lock + # to release. + self._held_locks: dict[str, asyncio.Lock] = {} + + @staticmethod + def _thread_id(runtime: Runtime[ContextT]) -> str | None: + """Extract ``thread_id`` from the active LangGraph ``RunnableConfig``. + + ``langgraph.runtime.Runtime`` deliberately does NOT expose ``config``. + The runnable config (where ``configurable.thread_id`` lives) must be + fetched via :func:`langgraph.config.get_config` from inside a node / + middleware. We fall back to ``getattr(runtime, "config", None)`` for + unit tests / legacy runtimes that synthesize a config-bearing stub. + """ + + def _from_dict(cfg: Any) -> str | None: + if not isinstance(cfg, dict): + return None + tid = (cfg.get("configurable") or {}).get("thread_id") + return str(tid) if tid is not None else None + + # Preferred path: real LangGraph runtime context. + try: + tid = _from_dict(get_config()) + except Exception: + tid = None + if tid is not None: + return tid + + # Fallback for tests and any runtime that surfaces a config dict + # directly on the runtime instance. + return _from_dict(getattr(runtime, "config", None)) + + async def abefore_agent( # type: ignore[override] + self, + state: AgentState[Any], + runtime: Runtime[ContextT], + ) -> dict[str, Any] | None: + del state + thread_id = self._thread_id(runtime) + if thread_id is None: + if self._require_thread_id: + raise BusyError("no thread_id configured") + logger.debug( + "BusyMutexMiddleware: no thread_id resolved from RunnableConfig; " + "skipping per-thread lock for this turn." + ) + return None + + lock = manager.lock_for(thread_id) + if lock.locked(): + raise BusyError(request_id=thread_id) + await lock.acquire() + self._held_locks[thread_id] = lock + # Reset the cancel event so this turn starts fresh + reset_cancel(thread_id) + return None + + async def aafter_agent( # type: ignore[override] + self, + state: AgentState[Any], + runtime: Runtime[ContextT], + ) -> dict[str, Any] | None: + del state + thread_id = self._thread_id(runtime) + if thread_id is None: + return None + lock = self._held_locks.pop(thread_id, None) + if lock is not None and lock.locked(): + lock.release() + # Always clear cancel event between turns so a stale signal + # doesn't leak into the next request. + reset_cancel(thread_id) + return None + + # Provide sync no-ops because the middleware base class allows them + def before_agent( # type: ignore[override] + self, state: AgentState[Any], runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: + # Sync path: no asyncio.Lock to acquire. Best we can do is reject + # if anyone else is in flight. + thread_id = self._thread_id(runtime) + if thread_id is None: + if self._require_thread_id: + raise BusyError("no thread_id configured") + return None + lock = manager.lock_for(thread_id) + if lock.locked(): + raise BusyError(request_id=thread_id) + return None + + def after_agent( # type: ignore[override] + self, state: AgentState[Any], runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: + return None + + +__all__ = [ + "BusyMutexMiddleware", + "get_cancel_event", + "manager", + "request_cancel", + "reset_cancel", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/compaction.py b/surfsense_backend/app/agents/new_chat/middleware/compaction.py new file mode 100644 index 000000000..8b02089c9 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/compaction.py @@ -0,0 +1,253 @@ +""" +SurfSense compaction middleware. + +Subclasses :class:`deepagents.middleware.summarization.SummarizationMiddleware` +to add SurfSense-specific behavior: + +1. **Structured summary template** (OpenCode-style ``## Goal / Constraints / + Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``). +2. **Protect SurfSense-specific SystemMessages** so injected hints + (````, ````, ````, + ````, ````, ````, ````) + are *not* summarized away and are kept verbatim in the post-summary + message list. +3. **Sanitize ``content=None``** when feeding messages into ``get_buffer_string`` + (Azure OpenAI / LiteLLM defense — when a provider streams an AIMessage + containing only tool_calls and no text, ``content`` can be ``None`` and + ``get_buffer_string`` crashes iterating over ``None``). This used to live in + ``safe_summarization.py``; folded in here. + +This replaces ``app.agents.new_chat.middleware.safe_summarization``. + +Tier 1.3 in the OpenCode-port plan. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from deepagents.middleware.summarization import ( + SummarizationMiddleware, + compute_summarization_defaults, +) +from langchain_core.messages import SystemMessage + +from app.observability import otel as ot + +if TYPE_CHECKING: + from deepagents.backends.protocol import BACKEND_TYPES + from langchain_core.language_models import BaseChatModel + from langchain_core.messages import AnyMessage + +logger = logging.getLogger(__name__) + +# OpenCode-faithful structured summary template. Mirrors +# ``opencode/packages/opencode/src/session/compaction.ts:40-75``. Kept as a +# module-level constant so unit tests can assert on its sections. +SURFSENSE_SUMMARY_PROMPT = """ +SurfSense Conversation Compaction Assistant + + + +Extract the most important context from the conversation history below into a structured summary that will replace the older messages. + + + +You are running because the conversation has grown beyond the model's input window. The conversation history below will be summarized and replaced with your output. Use the structured template that follows; keep each section concise but comprehensive enough that the agent can resume work without losing context. Each section is a checklist — populate it with relevant content or write "None" if there is nothing to report. + +## Goal +What is the user's primary goal or request? State it in one or two sentences. + +## Constraints +What boundaries must the agent respect (citations rules, visibility scope, allowed tools, user-imposed style, deadlines, deny-listed topics)? + +## Progress +What has the agent already accomplished? List each completed step succinctly. Do not reproduce tool output; just record the conclusion. + +## Key Decisions +What choices were made and why? Include rejected alternatives and the reasoning behind selecting the current path. + +## Next Steps +What specific tasks remain to achieve the goal? Order them by dependency. + +## Critical Context +What facts, IDs, document titles, query keywords, error messages, or partial answers must persist into the next turn? Include verbatim quotes only when the exact wording matters (e.g. a precise filter clause or a literal name). + +## Relevant Files +What documents or paths in the SurfSense knowledge base are in play? Use ``/documents/...`` paths exactly as they appeared in the workspace tree. + + + +Messages to summarize: +{messages} + + +Respond ONLY with the structured summary. Do not include any text before or after. +""" + +# SystemMessage prefixes that must NOT be summarized away. They are +# re-injected on every turn by the corresponding middleware, but the +# compaction step happens *before* re-injection in some paths, so we +# must preserve them verbatim across the cutoff. +PROTECTED_SYSTEM_PREFIXES: tuple[str, ...] = ( + "", # KnowledgePriorityMiddleware + "", # KnowledgeTreeMiddleware + "", # FileIntentMiddleware + "", # MemoryInjectionMiddleware + "", # MemoryInjectionMiddleware + "", # MemoryInjectionMiddleware + "", # MemoryInjectionMiddleware +) + + +def _is_protected_system_message(msg: AnyMessage) -> bool: + """Return True if ``msg`` is a SystemMessage we must not summarize.""" + if not isinstance(msg, SystemMessage): + return False + content = msg.content + if not isinstance(content, str): + return False + stripped = content.lstrip() + return any(stripped.startswith(prefix) for prefix in PROTECTED_SYSTEM_PREFIXES) + + +def _sanitize_message_content(msg: AnyMessage) -> AnyMessage: + """Return ``msg`` with ``content=None`` coerced to ``""``. + + Folds in the historical defense from ``safe_summarization.py`` — + ``get_buffer_string`` reads ``m.text`` which iterates ``self.content``, + so a ``None`` content (Azure OpenAI / LiteLLM streaming a tool-only + AIMessage) explodes. We return a copy with empty string content so + downstream consumers see an empty body without mutating the original. + """ + if getattr(msg, "content", "not-missing") is not None: + return msg + try: + return msg.model_copy(update={"content": ""}) + except AttributeError: + import copy + + new_msg = copy.copy(msg) + try: + new_msg.content = "" + except Exception: + logger.debug( + "Could not sanitize content=None on message of type %s", + type(msg).__name__, + ) + return msg + return new_msg + + +class SurfSenseCompactionMiddleware(SummarizationMiddleware): + """SummarizationMiddleware tuned for SurfSense. + + Notes + ----- + - Overrides :meth:`_partition_messages` so protected SystemMessages + survive into the ``preserved_messages`` half regardless of cutoff. + - Overrides :meth:`_filter_summary_messages` so the buffer-string path + never iterates ``None`` content. + - Inherits everything else (auto-trigger, backend offload, + ``_summarization_event`` plumbing, ``ContextOverflowError`` fallback). + """ + + def _partition_messages( # type: ignore[override] + self, + conversation_messages: list[AnyMessage], + cutoff_index: int, + ) -> tuple[list[AnyMessage], list[AnyMessage]]: + """Split messages but always preserve SurfSense protected SystemMessages. + + Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy + (``opencode/packages/opencode/src/session/compaction.ts``): some + message types are always kept verbatim because they are part of the + agent's working contract, not transient output. + + Also opens a ``compaction.run`` OTel span (no-op when OTel is off) + so dashboards can count compaction events and message-volume + without having to instrument upstream callers. + """ + # Opening a span here is appropriate because partitioning is the + # first call SummarizationMiddleware makes when it has decided to + # summarize; we record the volume and then close as a normal span. + with ot.compaction_span( + reason="auto", + messages_in=len(conversation_messages), + extra={"compaction.cutoff_index": int(cutoff_index)}, + ): + messages_to_summarize, preserved_messages = ( + super()._partition_messages(conversation_messages, cutoff_index) + ) + + protected: list[AnyMessage] = [] + kept_for_summary: list[AnyMessage] = [] + for msg in messages_to_summarize: + if _is_protected_system_message(msg): + protected.append(msg) + else: + kept_for_summary.append(msg) + + # Place protected blocks at the *front* of preserved_messages so + # they keep their original ordering relative to the summary + # HumanMessage that precedes the rest of the preserved tail. + return kept_for_summary, [*protected, *preserved_messages] + + def _filter_summary_messages( # type: ignore[override] + self, messages: list[AnyMessage] + ) -> list[AnyMessage]: + """Filter previous summaries AND sanitize ``content=None``. + + Folds the ``safe_summarization.py`` defense in: when the buffer + builder iterates ``m.text`` over ``None`` it explodes; sanitizing + here covers both the sync and async offload paths. + """ + filtered = super()._filter_summary_messages(messages) + return [_sanitize_message_content(m) for m in filtered] + + +def create_surfsense_compaction_middleware( + model: BaseChatModel, + backend: BACKEND_TYPES, + *, + summary_prompt: str | None = None, + history_path_prefix: str = "/conversation_history", + **overrides: Any, +) -> SurfSenseCompactionMiddleware: + """Build a :class:`SurfSenseCompactionMiddleware` with sensible defaults. + + Pulls profile-aware ``trigger`` / ``keep`` / ``truncate_args_settings`` + via :func:`deepagents.middleware.summarization.compute_summarization_defaults` + so callers get the same behavior as ``create_summarization_middleware`` + plus our overrides. + + Args: + model: Chat model to call for summary generation. + backend: Backend instance or factory for offloading conversation history. + summary_prompt: Optional override; defaults to :data:`SURFSENSE_SUMMARY_PROMPT`. + history_path_prefix: Path prefix for offloaded conversation history. + **overrides: Forwarded to :class:`SurfSenseCompactionMiddleware`. + """ + defaults = compute_summarization_defaults(model) + return SurfSenseCompactionMiddleware( + model=model, + backend=backend, + trigger=overrides.pop("trigger", defaults["trigger"]), + keep=overrides.pop("keep", defaults["keep"]), + trim_tokens_to_summarize=overrides.pop("trim_tokens_to_summarize", None), + truncate_args_settings=overrides.pop( + "truncate_args_settings", defaults["truncate_args_settings"] + ), + summary_prompt=summary_prompt or SURFSENSE_SUMMARY_PROMPT, + history_path_prefix=history_path_prefix, + **overrides, + ) + + +__all__ = [ + "PROTECTED_SYSTEM_PREFIXES", + "SURFSENSE_SUMMARY_PROMPT", + "SurfSenseCompactionMiddleware", + "create_surfsense_compaction_middleware", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/context_editing.py b/surfsense_backend/app/agents/new_chat/middleware/context_editing.py new file mode 100644 index 000000000..93ceab8ee --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/context_editing.py @@ -0,0 +1,349 @@ +""" +SpillToBackendEdit + SpillingContextEditingMiddleware. + +Mirrors OpenCode's spill-to-disk behavior in +``opencode/packages/opencode/src/tool/truncate.ts``. Before +``ClearToolUsesEdit`` rewrites old ``ToolMessage.content`` to a placeholder, +we capture the full original content and write it to the runtime backend +under ``/tool_outputs/{thread_id}/{message_id}.txt``. The placeholder is +upgraded to ``"[cleared — full output at /tool_outputs/.../{id}.txt; ask the +explore subagent to read it]"`` so the agent can recover it on demand. + +Tier 1.2 in the OpenCode-port plan. + +Why this is a middleware subclass instead of a plain ``ContextEdit``: +``ContextEdit.apply`` is sync, but writing to the backend is async. We +capture the spill payloads inside ``apply`` and flush them via +``await backend.aupload_files(...)`` from ``awrap_model_call`` *before* +delegating to the handler, so the explore subagent can always read what +the placeholder advertises. +""" + +from __future__ import annotations + +import logging +import threading +from collections.abc import Awaitable, Callable, Sequence +from copy import deepcopy +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from langchain.agents.middleware.context_editing import ( + ClearToolUsesEdit, + ContextEdit, + ContextEditingMiddleware, + TokenCounter, +) +from langchain_core.messages import ( + AIMessage, + AnyMessage, + BaseMessage, + ToolMessage, +) +from langchain_core.messages.utils import count_tokens_approximately +from langgraph.config import get_config + +if TYPE_CHECKING: + from deepagents.backends.protocol import BackendProtocol + from langchain.agents.middleware.types import ( + ModelRequest, + ModelResponse, + ) + +logger = logging.getLogger(__name__) + +DEFAULT_SPILL_PREFIX = "/tool_outputs" + + +def _build_spill_placeholder(spill_path: str) -> str: + """Build the user-facing placeholder text shown to the model.""" + return ( + f"[cleared — full output at {spill_path}; " + f"ask the explore subagent to read it]" + ) + + +def _get_thread_id_or_session() -> str: + """Best-effort thread_id discovery for the spill path. + + Falls back to a process-stable string if no LangGraph config is + available (e.g. unit tests). The exact value doesn't matter as long + as it's stable within one stream so the placeholder paths line up + with the actual upload path. + """ + try: + config = get_config() + thread_id = config.get("configurable", {}).get("thread_id") + if thread_id is not None: + return str(thread_id) + except RuntimeError: + pass + return "no_thread" + + +@dataclass(slots=True) +class SpillToBackendEdit(ContextEdit): + """Capture-and-replace context edit that spills full tool output to the backend. + + Behaves like :class:`ClearToolUsesEdit` (same trigger / keep / exclude + semantics) **and** records the original ``ToolMessage.content`` in + :attr:`pending_spills` so the wrapping middleware can flush them + before the model call. + + Args: + trigger: Token threshold above which the edit fires. + clear_at_least: Minimum number of tokens to reclaim (best effort). + keep: Number of most-recent ``ToolMessage`` instances to leave + untouched. + exclude_tools: Names of tools whose output is NOT spilled. + clear_tool_inputs: Also clear the originating ``AIMessage.tool_calls`` + args when their pair is cleared. + path_prefix: Path under the backend where spills are written. + Default ``"/tool_outputs"``. + """ + + trigger: int = 100_000 + clear_at_least: int = 0 + keep: int = 3 + clear_tool_inputs: bool = False + exclude_tools: Sequence[str] = () + path_prefix: str = DEFAULT_SPILL_PREFIX + + pending_spills: list[tuple[str, bytes]] = field(default_factory=list) + _lock: threading.Lock = field(default_factory=threading.Lock) + + def drain_pending(self) -> list[tuple[str, bytes]]: + """Return and clear the pending-spill list atomically.""" + with self._lock: + out = list(self.pending_spills) + self.pending_spills.clear() + return out + + def apply( + self, + messages: list[AnyMessage], + *, + count_tokens: TokenCounter, + ) -> None: + """Mirror ``ClearToolUsesEdit.apply`` but capture originals first.""" + tokens = count_tokens(messages) + if tokens <= self.trigger: + return + + candidates = [ + (idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage) + ] + if self.keep >= len(candidates): + return + if self.keep: + candidates = candidates[: -self.keep] + + thread_id = _get_thread_id_or_session() + excluded_tools = set(self.exclude_tools) + + for idx, tool_message in candidates: + if tool_message.response_metadata.get("context_editing", {}).get("cleared"): + continue + + ai_message = next( + (m for m in reversed(messages[:idx]) if isinstance(m, AIMessage)), + None, + ) + if ai_message is None: + continue + + tool_call = next( + ( + call + for call in ai_message.tool_calls + if call.get("id") == tool_message.tool_call_id + ), + None, + ) + if tool_call is None: + continue + + tool_name = tool_message.name or tool_call["name"] + if tool_name in excluded_tools: + continue + + message_id = tool_message.id or tool_message.tool_call_id or "unknown" + spill_path = f"{self.path_prefix}/{thread_id}/{message_id}.txt" + + original = tool_message.content + payload = self._encode_payload(original) + with self._lock: + self.pending_spills.append((spill_path, payload)) + + messages[idx] = tool_message.model_copy( + update={ + "artifact": None, + "content": _build_spill_placeholder(spill_path), + "response_metadata": { + **tool_message.response_metadata, + "context_editing": { + "cleared": True, + "strategy": "spill_to_backend", + "spill_path": spill_path, + }, + }, + } + ) + + if self.clear_tool_inputs: + ai_idx = messages.index(ai_message) + messages[ai_idx] = self._clear_input_args( + ai_message, tool_message.tool_call_id or "" + ) + + if self.clear_at_least > 0: + new_token_count = count_tokens(messages) + cleared_tokens = max(0, tokens - new_token_count) + if cleared_tokens >= self.clear_at_least: + break + + @staticmethod + def _encode_payload(content: Any) -> bytes: + """Serialize ``ToolMessage.content`` to bytes for upload.""" + if isinstance(content, bytes): + return content + if isinstance(content, str): + return content.encode("utf-8") + try: + import json + + return json.dumps(content, default=str).encode("utf-8") + except Exception: + return str(content).encode("utf-8") + + @staticmethod + def _clear_input_args(message: AIMessage, tool_call_id: str) -> AIMessage: + updated_tool_calls: list[dict[str, Any]] = [] + cleared_any = False + for tool_call in message.tool_calls: + updated = dict(tool_call) + if updated.get("id") == tool_call_id: + updated["args"] = {} + cleared_any = True + updated_tool_calls.append(updated) + + metadata = dict(getattr(message, "response_metadata", {})) + if cleared_any: + ctx = dict(metadata.get("context_editing", {})) + ids = set(ctx.get("cleared_tool_inputs", [])) + ids.add(tool_call_id) + ctx["cleared_tool_inputs"] = sorted(ids) + metadata["context_editing"] = ctx + return message.model_copy( + update={ + "tool_calls": updated_tool_calls, + "response_metadata": metadata, + } + ) + + +BackendResolver = "Callable[[Any], BackendProtocol] | BackendProtocol" + + +class SpillingContextEditingMiddleware(ContextEditingMiddleware): + """:class:`ContextEditingMiddleware` that flushes :class:`SpillToBackendEdit` writes. + + Runs the configured edits as the parent does, then flushes any + pending spills via the supplied backend resolver before delegating + to the model handler. Spill failures are logged but never abort the + model call — the placeholder text is already in the message, so the + worst case is the agent gets a placeholder it cannot follow up on. + """ + + def __init__( + self, + *, + edits: Sequence[ContextEdit], + backend_resolver: BackendResolver | None = None, + token_count_method: str = "approximate", + ) -> None: + super().__init__(edits=list(edits), token_count_method=token_count_method) # type: ignore[arg-type] + self._backend_resolver = backend_resolver + + def _resolve_backend(self, request: ModelRequest) -> BackendProtocol | None: + if self._backend_resolver is None: + return None + if callable(self._backend_resolver): + try: + from langchain.tools import ToolRuntime + + tool_runtime = ToolRuntime( + state=getattr(request, "state", {}), + context=getattr(request.runtime, "context", None), + stream_writer=getattr(request.runtime, "stream_writer", None), + store=getattr(request.runtime, "store", None), + config=getattr(request.runtime, "config", None) or {}, + tool_call_id=None, + ) + return self._backend_resolver(tool_runtime) + except Exception: + logger.exception("Failed to resolve spill backend") + return None + return self._backend_resolver # type: ignore[return-value] + + def _collect_pending(self) -> list[tuple[str, bytes]]: + out: list[tuple[str, bytes]] = [] + for edit in self.edits: + if isinstance(edit, SpillToBackendEdit): + out.extend(edit.drain_pending()) + return out + + async def awrap_model_call( # type: ignore[override] + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> Any: + if not request.messages: + return await handler(request) + + if self.token_count_method == "approximate": + + def count_tokens(messages: Sequence[BaseMessage]) -> int: + return count_tokens_approximately(messages) + + else: + system_msg = [request.system_message] if request.system_message else [] + + def count_tokens(messages: Sequence[BaseMessage]) -> int: + return request.model.get_num_tokens_from_messages( + system_msg + list(messages), request.tools + ) + + edited_messages = deepcopy(list(request.messages)) + for edit in self.edits: + edit.apply(edited_messages, count_tokens=count_tokens) + + pending = self._collect_pending() + if pending: + backend = self._resolve_backend(request) + if backend is not None: + try: + await backend.aupload_files(pending) + except Exception: + logger.exception( + "Spill-to-backend upload failed (%d files); placeholders " + "remain in messages but content is unrecoverable", + len(pending), + ) + else: + logger.warning( + "SpillToBackendEdit produced %d pending spills but no backend " + "resolver was configured; content is unrecoverable", + len(pending), + ) + + return await handler(request.override(messages=edited_messages)) + + +__all__ = [ + "DEFAULT_SPILL_PREFIX", + "ClearToolUsesEdit", + "SpillToBackendEdit", + "SpillingContextEditingMiddleware", + "_build_spill_placeholder", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py index 61494ff1a..3aff524fe 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py +++ b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py @@ -2,17 +2,28 @@ When the LLM emits multiple calls to the same HITL tool with the same primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``), -only the first call is kept. Non-HITL tools are never touched. +only the first call is kept. Non-HITL tools are never touched. This runs in the ``after_model`` hook — **before** any tool executes — so the duplicate call is stripped from the AIMessage that gets checkpointed. That means it is also safe across LangGraph ``interrupt()`` boundaries: the removed call will never appear on graph resume. + +Dedup-key resolution order (Tier 2.3 / cleanup in the OpenCode-port plan): + +1. :class:`ToolDefinition.dedup_key` — callable provided by the registry + entry. This is the canonical mechanism after the cleanup-tier removal + of the legacy ``PRIMARY_ARG`` map. +2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg name; + used by MCP / Composio tools whose schemas the registry doesn't see. + +A tool with no resolver from either path simply opts out of dedup. """ from __future__ import annotations import logging +from collections.abc import Callable from typing import Any from langchain.agents.middleware import AgentMiddleware, AgentState @@ -20,81 +31,84 @@ from langgraph.runtime import Runtime logger = logging.getLogger(__name__) -_NATIVE_HITL_TOOL_DEDUP_KEYS: dict[str, str] = { - # Gmail - "send_gmail_email": "subject", - "create_gmail_draft": "subject", - "update_gmail_draft": "draft_subject_or_id", - "trash_gmail_email": "email_subject_or_id", - # Google Calendar - "create_calendar_event": "title", - "update_calendar_event": "event_title_or_id", - "delete_calendar_event": "event_title_or_id", - # Google Drive - "create_google_drive_file": "file_name", - "delete_google_drive_file": "file_name", - # OneDrive - "create_onedrive_file": "file_name", - "delete_onedrive_file": "file_name", - # Dropbox - "create_dropbox_file": "file_name", - "delete_dropbox_file": "file_name", - # Notion - "create_notion_page": "title", - "update_notion_page": "page_title", - "delete_notion_page": "page_title", - # Linear - "create_linear_issue": "title", - "update_linear_issue": "issue_ref", - "delete_linear_issue": "issue_ref", - # Jira - "create_jira_issue": "summary", - "update_jira_issue": "issue_title_or_key", - "delete_jira_issue": "issue_title_or_key", - # Confluence - "create_confluence_page": "title", - "update_confluence_page": "page_title_or_id", - "delete_confluence_page": "page_title_or_id", -} +# Resolver type — given the tool ``args`` dict returns a stable +# string used to dedupe consecutive calls. ``None`` means no dedup. +DedupResolver = Callable[[dict[str, Any]], str] + + +def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver: + """Adapt a string-arg name into a :data:`DedupResolver`. + + Convenience helper used by registry entries that just want to dedupe + on a single arg's lowercased value (the most common case for native + HITL tools like ``send_gmail_email`` keyed on ``subject``). + + Example:: + + ToolDefinition( + name="send_gmail_email", + ..., + dedup_key=wrap_dedup_key_by_arg_name("subject"), + ) + """ + + def _resolver(args: dict[str, Any]) -> str: + return str(args.get(arg_name, "")).lower() + + return _resolver + + +# Backwards-compatible alias for code that imported the original +# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`. +_wrap_string_key = wrap_dedup_key_by_arg_name class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] """Remove duplicate HITL tool calls from a single LLM response. - Only the **first** occurrence of each (tool-name, primary-arg-value) + Only the **first** occurrence of each ``(tool-name, dedup_key)`` pair is kept; subsequent duplicates are silently dropped. - The dedup map is built from two sources: + The dedup-resolver map is built from two sources, in priority order: - 1. A comprehensive list of native HITL tools (hardcoded above). - 2. Any ``StructuredTool`` instances passed via *agent_tools* whose - ``metadata`` contains ``{"hitl": True, "hitl_dedup_key": "..."}``. - This is how MCP tools automatically get dedup support. + 1. ``tool.metadata["dedup_key"]`` — callable provided by the registry's + ``ToolDefinition.dedup_key`` (Tier 2.3). Receives the args dict + and returns a string signature. This is the canonical mechanism + after the cleanup-tier removal of the legacy ``PRIMARY_ARG`` map. + 2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg + name; primarily used by MCP / Composio tools. """ tools = () def __init__(self, *, agent_tools: list[Any] | None = None) -> None: - self._dedup_keys: dict[str, str] = dict(_NATIVE_HITL_TOOL_DEDUP_KEYS) + self._resolvers: dict[str, DedupResolver] = {} + for t in agent_tools or []: meta = getattr(t, "metadata", None) or {} + callable_key = meta.get("dedup_key") + if callable(callable_key): + self._resolvers[t.name] = callable_key + continue if meta.get("hitl") and meta.get("hitl_dedup_key"): - self._dedup_keys[t.name] = meta["hitl_dedup_key"] + self._resolvers[t.name] = wrap_dedup_key_by_arg_name( + meta["hitl_dedup_key"] + ) def after_model( self, state: AgentState, runtime: Runtime[Any] ) -> dict[str, Any] | None: - return self._dedup(state, self._dedup_keys) + return self._dedup(state, self._resolvers) async def aafter_model( self, state: AgentState, runtime: Runtime[Any] ) -> dict[str, Any] | None: - return self._dedup(state, self._dedup_keys) + return self._dedup(state, self._resolvers) @staticmethod def _dedup( state: AgentState, - dedup_keys: dict[str, str], # type: ignore[type-arg] + resolvers: dict[str, DedupResolver], ) -> dict[str, Any] | None: messages = state.get("messages") if not messages: @@ -110,9 +124,16 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] for tc in tool_calls: name = tc.get("name", "") - dedup_key_arg = dedup_keys.get(name) - if dedup_key_arg is not None: - arg_val = str(tc.get("args", {}).get(dedup_key_arg, "")).lower() + resolver = resolvers.get(name) + if resolver is not None: + try: + arg_val = resolver(tc.get("args", {}) or {}) + except Exception: + logger.exception( + "Dedup resolver for tool %s raised; keeping call", name + ) + deduped.append(tc) + continue key = (name, arg_val) if key in seen: logger.info( diff --git a/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py b/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py new file mode 100644 index 000000000..49ac7dfa8 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py @@ -0,0 +1,228 @@ +""" +DoomLoopMiddleware — pattern-based detector for repeated identical tool calls. + +Mirrors ``opencode/packages/opencode/src/session/processor.ts`` doom-loop +behavior. When the same tool with the same arguments is called N times +in a row, the agent has likely entered an infinite loop. We surface this +to the user as an interrupt with ``permission="doom_loop"`` so the UI +can render an "Are you stuck? Continue / cancel?" affordance. + +Tier 1.11 in the OpenCode-port plan. + +This ships **OFF by default** until the frontend explicitly handles +``context.permission == "doom_loop"`` interrupts (the plan flips +``SURFSENSE_ENABLE_DOOM_LOOP=true`` once the UI is ready). + +Wire format: uses SurfSense's existing ``interrupt()`` payload shape +(see ``app/agents/new_chat/tools/hitl.py``): + + { + "type": "permission_ask", + "action": {"tool": , "params": }, + "context": {"permission": "doom_loop", "recent_signatures": [...]}, + } + +so the frontend that already handles HITL prompts can render this with +no changes beyond a string check. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +from collections import deque +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ResponseT, +) +from langchain_core.messages import AIMessage +from langgraph.config import get_config +from langgraph.runtime import Runtime +from langgraph.types import interrupt + +from app.observability import otel as ot + +logger = logging.getLogger(__name__) + + +def _signature(name: str, args: Any) -> str: + """Hash a tool call ``(name, args)`` to a short signature.""" + try: + canonical = json.dumps(args, sort_keys=True, default=str) + except (TypeError, ValueError): + canonical = repr(args) + digest = hashlib.sha1(f"{name}::{canonical}".encode()).hexdigest() + return digest[:16] + + +class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): + """Detect repeated identical tool calls and prompt the user. + + Tracks a sliding window of the most-recent ``threshold`` tool-call + signatures across the live request. When all entries match, raise + a SurfSense-style HITL interrupt with ``permission="doom_loop"``. + + Args: + threshold: How many consecutive identical signatures count as a + doom loop. Default 3 (opencode parity). + """ + + def __init__(self, *, threshold: int = 3) -> None: + super().__init__() + if threshold < 2: + raise ValueError("DoomLoopMiddleware threshold must be >= 2") + self._threshold = threshold + self.tools = [] + # Per-thread sliding windows. We can't put this in graph state + # without state-schema gymnastics; for one process-lifetime it's + # fine to keep an in-memory map keyed by thread_id. + self._windows: dict[str, deque[str]] = {} + + @staticmethod + def _thread_id_from_runtime(runtime: Runtime[ContextT]) -> str: + """Resolve the thread id for sliding-window keying. + + Prefer LangGraph's ``get_config()`` (the only way to read + ``RunnableConfig`` inside a node — :class:`Runtime` does NOT carry + a ``config`` attribute). Fall back to ``runtime.config`` for unit + tests that synthesize a config-bearing stub. Default + ``"no_thread"`` is intentionally only used when both lookups fail + — it would collapse all threads into one window so we keep the + debug log loud. + """ + + def _from_dict(cfg: Any) -> str | None: + if not isinstance(cfg, dict): + return None + tid = (cfg.get("configurable") or {}).get("thread_id") + return str(tid) if tid is not None else None + + try: + tid = _from_dict(get_config()) + except Exception: + tid = None + if tid is not None: + return tid + + tid = _from_dict(getattr(runtime, "config", None)) + if tid is not None: + return tid + + logger.debug( + "DoomLoopMiddleware: no thread_id resolved from RunnableConfig; " + "falling back to shared 'no_thread' window." + ) + return "no_thread" + + def _window(self, thread_id: str) -> deque[str]: + win = self._windows.get(thread_id) + if win is None: + win = deque(maxlen=self._threshold) + self._windows[thread_id] = win + return win + + def _detect( + self, message: AIMessage, runtime: Runtime[ContextT] + ) -> tuple[bool, list[str], dict[str, Any] | None]: + if not message.tool_calls: + return False, [], None + + thread_id = self._thread_id_from_runtime(runtime) + window = self._window(thread_id) + + triggered_call: dict[str, Any] | None = None + for call in message.tool_calls: + name = call.get("name") if isinstance(call, dict) else getattr(call, "name", None) + args = call.get("args") if isinstance(call, dict) else getattr(call, "args", {}) + if not isinstance(name, str): + continue + sig = _signature(name, args) + window.append(sig) + if ( + len(window) >= self._threshold + and len(set(window)) == 1 + ): + triggered_call = {"name": name, "params": args or {}} + break + + if triggered_call is None: + return False, list(window), None + return True, list(window), triggered_call + + def after_model( # type: ignore[override] + self, + state: AgentState[ResponseT], + runtime: Runtime[ContextT], + ) -> dict[str, Any] | None: + messages = state.get("messages") or [] + if not messages: + return None + last = messages[-1] + if not isinstance(last, AIMessage): + return None + + triggered, signatures, action = self._detect(last, runtime) + if not triggered: + return None + + logger.warning( + "Doom loop detected: tool %s called %d times in a row (sig=%s)", + action["name"] if action else "", + self._threshold, + signatures[-1] if signatures else "", + ) + + # Tier 3b: interrupt.raised span with permission=doom_loop attribute + # so dashboards can break out doom-loop interrupts from regular + # permission asks via the ``interrupt.permission`` attribute. + with ot.interrupt_span( + interrupt_type="permission_ask", + extra={ + "interrupt.permission": "doom_loop", + "interrupt.threshold": self._threshold, + "interrupt.tool": (action or {}).get("tool", ""), + }, + ): + decision = interrupt( + { + "type": "permission_ask", + "action": action or {"tool": "", "params": {}}, + "context": { + "permission": "doom_loop", + "recent_signatures": signatures, + "threshold": self._threshold, + }, + } + ) + + # Reset window so the next decision (continue/cancel) starts fresh. + thread_id = self._thread_id_from_runtime(runtime) + self._windows.pop(thread_id, None) + + # Decision shape mirrors ``tools/hitl.py``: {"decision_type": "..."} + # If the user cancelled, jump to end. Otherwise return ``None`` so the + # tool call proceeds. The frontend's exact reply names may differ — + # we tolerate any shape that contains a string with "reject"/"cancel". + if isinstance(decision, dict): + kind = str(decision.get("decision_type") or decision.get("type") or "").lower() + if "reject" in kind or "cancel" in kind: + return {"jump_to": "end"} + return None + + async def aafter_model( # type: ignore[override] + self, + state: AgentState[ResponseT], + runtime: Runtime[ContextT], + ) -> dict[str, Any] | None: + return self.after_model(state, runtime) + + +__all__ = [ + "DoomLoopMiddleware", + "_signature", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py index edd8c7af1..f39870df6 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py @@ -31,14 +31,17 @@ from collections.abc import Sequence from datetime import UTC, datetime from typing import Any +from langchain.agents import create_agent from langchain.agents.middleware import AgentMiddleware, AgentState from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_core.runnables import Runnable from langgraph.runtime import Runtime from litellm import token_counter from pydantic import BaseModel, Field, ValidationError from sqlalchemy import select +from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.filesystem_selection import FilesystemMode from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState from app.agents.new_chat.path_resolver import ( @@ -589,6 +592,53 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] self.available_document_types = available_document_types self.top_k = top_k self.mentioned_document_ids = mentioned_document_ids or [] + # Tier 4.2: build the kb-planner private Runnable ONCE here so we + # don't pay the create_agent compile cost (50–200ms) on every turn. + # Disabled by default behind ``enable_kb_planner_runnable``; when off + # the planner falls back to the legacy ``self.llm.ainvoke`` path. + self._planner: Runnable | None = None + self._planner_compile_failed = False + + def _build_kb_planner_runnable(self) -> Runnable | None: + """Compile the kb-planner private :class:`Runnable` once. + + Returns ``None`` when the feature flag is disabled, when the LLM is + unavailable, or when ``create_agent`` raises (we fall back to the + legacy ``self.llm.ainvoke`` path in that case). Compilation happens + lazily on first call, then memoized via ``self._planner``. + + The compiled agent is constructed without tools — the planner's + contract is "answer with structured JSON" — but with ``RetryAfter`` + + the OpenCode-port retry/limit middleware so it shares the parent + agent's resilience guarantees. + """ + if self._planner is not None or self._planner_compile_failed: + return self._planner + if self.llm is None: + return None + flags = get_flags() + if ( + not flags.enable_kb_planner_runnable + or flags.disable_new_agent_stack + ): + return None + + from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware + + try: + self._planner = create_agent( + self.llm, + tools=[], + middleware=[RetryAfterMiddleware(max_retries=2)], + ) + except Exception as exc: # pragma: no cover - defensive + logger.warning( + "kb-planner Runnable compile failed; falling back to llm.ainvoke: %s", + exc, + ) + self._planner_compile_failed = True + self._planner = None + return self._planner async def _plan_search_inputs( self, @@ -611,11 +661,32 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] loop = asyncio.get_running_loop() t0 = loop.time() + # Tier 4.2: prefer the compiled-once planner Runnable when enabled; + # otherwise fall back to ``self.llm.ainvoke``. The ``surfsense:internal`` + # tag is preserved on both paths so ``_stream_agent_events`` still + # suppresses the planner's intermediate events from the UI. + planner = self._build_kb_planner_runnable() try: - response = await self.llm.ainvoke( - [HumanMessage(content=prompt)], - config={"tags": ["surfsense:internal"]}, - ) + if planner is not None: + planner_state = await planner.ainvoke( + {"messages": [HumanMessage(content=prompt)]}, + config={"tags": ["surfsense:internal"]}, + ) + response_messages = ( + planner_state.get("messages", []) + if isinstance(planner_state, dict) + else [] + ) + response = ( + response_messages[-1] + if response_messages + else AIMessage(content="") + ) + else: + response = await self.llm.ainvoke( + [HumanMessage(content=prompt)], + config={"tags": ["surfsense:internal"]}, + ) plan = _parse_kb_search_plan_response(_extract_text_from_message(response)) optimized_query = ( re.sub(r"\s+", " ", plan.optimized_query).strip() or user_text diff --git a/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py b/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py new file mode 100644 index 000000000..f16084892 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py @@ -0,0 +1,133 @@ +""" +``_noop`` provider-compatibility tool + injection middleware. + +OpenCode injects a ``_noop`` tool for LiteLLM/Bedrock/Copilot when the +model call has empty tools but message history includes prior +``tool_calls`` — some providers 400 in that shape (see +``opencode/packages/opencode/src/session/llm.ts:209-228``). SurfSense uses +LiteLLM, and the compaction summarize call (no tools, history full of +tool calls) hits this. Tier 1.5 in the OpenCode-port plan. + +Operation: a :class:`NoopInjectionMiddleware` ``wrap_model_call`` checks +if the request has zero tools but the last AI message in history includes +``tool_calls``. If yes, it injects the ``_noop`` tool only — never globally, +mirroring opencode's gating exactly. The :func:`noop_tool` returns empty +content when called (which it should never be in practice). +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ModelRequest, + ModelResponse, + ResponseT, +) +from langchain_core.messages import AIMessage +from langchain_core.tools import tool + +logger = logging.getLogger(__name__) + +NOOP_TOOL_NAME = "_noop" +NOOP_TOOL_DESCRIPTION = ( + "Do not call this tool. It exists only for API compatibility." +) + + +@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION) +def noop_tool() -> str: + """Return empty content. Never expected to be called.""" + return "" + + +# Provider markers that benefit from ``_noop`` injection. These match +# opencode's gating list. We also accept any string containing one of +# these substrings (so e.g. ``litellm`` matches ``ChatLiteLLM``). +_NOOP_NEEDED_PROVIDERS: tuple[str, ...] = ( + "litellm", + "bedrock", + "copilot", +) + + +def _provider_needs_noop(model: Any) -> bool: + """Heuristic: does this model's provider need the _noop injection?""" + try: + ls_params = model._get_ls_params() + provider = str(ls_params.get("ls_provider", "")).lower() + except Exception: + provider = "" + + if not provider: + cls_name = type(model).__name__.lower() + provider = cls_name + + return any(needle in provider for needle in _NOOP_NEEDED_PROVIDERS) + + +def _last_ai_has_tool_calls(messages: list[Any]) -> bool: + for msg in reversed(messages): + if isinstance(msg, AIMessage): + return bool(msg.tool_calls) + return False + + +class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): + """Inject the ``_noop`` tool only when the provider would otherwise 400. + + The check fires per model call, not at agent build time, because the + summarization path generates a no-tool subcall at runtime. The + extra tool is appended to ``request.tools`` as an instance — the + actual ``langchain_core.tools.BaseTool`` is bound on every call site + that creates the agent. + """ + + def __init__(self, *, noop_tool_instance: Any | None = None) -> None: + super().__init__() + self._noop_tool = noop_tool_instance or noop_tool + self.tools = [] + + def _should_inject(self, request: ModelRequest[ContextT]) -> bool: + if request.tools: + return False + if not _last_ai_has_tool_calls(request.messages): + return False + return _provider_needs_noop(request.model) + + def _augmented(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]: + return request.override(tools=[self._noop_tool]) + + def wrap_model_call( # type: ignore[override] + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> Any: + if self._should_inject(request): + logger.debug("Injecting _noop tool for provider compatibility") + return handler(self._augmented(request)) + return handler(request) + + async def awrap_model_call( # type: ignore[override] + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], + ) -> Any: + if self._should_inject(request): + logger.debug("Injecting _noop tool for provider compatibility") + return await handler(self._augmented(request)) + return await handler(request) + + +__all__ = [ + "NOOP_TOOL_DESCRIPTION", + "NOOP_TOOL_NAME", + "NoopInjectionMiddleware", + "_provider_needs_noop", + "noop_tool", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/otel_span.py b/surfsense_backend/app/agents/new_chat/middleware/otel_span.py new file mode 100644 index 000000000..5585cf7a2 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/otel_span.py @@ -0,0 +1,202 @@ +""" +OpenTelemetry span middleware for the SurfSense ``new_chat`` agent. + +Wraps both ``model.call`` (LLM invocations) and ``tool.call`` (tool +executions) with OTel spans, attaching low-cardinality span names and +high-cardinality identifiers as attributes (per the Tier 3b plan). + +This middleware is intentionally a thin adapter over +:mod:`app.observability.otel`; when OTel is not configured all spans +collapse to no-ops and the wrapper adds <1µs overhead per call. When +OTel **is** configured (``OTEL_EXPORTER_OTLP_ENDPOINT`` set), every +model and tool call gets a span with the standard attributes the +plan's dashboards expect. +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + +from langchain.agents.middleware import AgentMiddleware +from langchain_core.messages import AIMessage, ToolMessage + +from app.observability import otel as ot + +if TYPE_CHECKING: # pragma: no cover — type-only + from langchain.agents.middleware.types import ( + ModelRequest, + ModelResponse, + ToolCallRequest, + ) + from langgraph.types import Command + +logger = logging.getLogger(__name__) + + +class OtelSpanMiddleware(AgentMiddleware): + """Emit ``model.call`` and ``tool.call`` OTel spans for every invocation. + + Should be placed near the **outer** end of the middleware list so + that the spans encompass retry/fallback wrapper effects (i.e. ``N`` + model.call spans for ``N`` retry attempts) but inside any concurrency/ + auth gate. Empirically this means **between** ``BusyMutex`` and + ``RetryAfter``. + """ + + def __init__(self, *, instrumentation_name: str = "surfsense.new_chat") -> None: + super().__init__() + self._instrumentation_name = instrumentation_name + + # ------------------------------------------------------------------ + # Model call spans + # ------------------------------------------------------------------ + + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[ + [ModelRequest], Awaitable[ModelResponse | AIMessage | Any] + ], + ) -> ModelResponse | AIMessage | Any: + if not ot.is_enabled(): + return await handler(request) + + model_id, provider = _resolve_model_attrs(request) + with ot.model_call_span(model_id=model_id, provider=provider) as sp: + try: + result = await handler(request) + except Exception: + # span context manager records + re-raises + raise + else: + _annotate_model_response(sp, result) + return result + + # ------------------------------------------------------------------ + # Tool call spans + # ------------------------------------------------------------------ + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[ + [ToolCallRequest], Awaitable[ToolMessage | Command[Any]] + ], + ) -> ToolMessage | Command[Any]: + if not ot.is_enabled(): + return await handler(request) + + tool_name = _resolve_tool_name(request) + input_size = _resolve_input_size(request) + + with ot.tool_call_span(tool_name, input_size=input_size) as sp: + result = await handler(request) + _annotate_tool_result(sp, result) + return result + + +# --------------------------------------------------------------------------- +# Attribute helpers (kept defensive; we never want OTel bookkeeping to break +# a real model/tool call). +# --------------------------------------------------------------------------- + + +def _resolve_model_attrs(request: Any) -> tuple[str | None, str | None]: + """Extract ``model.id`` and ``model.provider`` from a ``ModelRequest``.""" + model_id: str | None = None + provider: str | None = None + try: + model = getattr(request, "model", None) + if model is None: + return None, None + # langchain BaseChatModel exposes a few different identifiers + for attr in ("model_name", "model", "model_id"): + value = getattr(model, attr, None) + if value: + model_id = str(value) + break + # provider sometimes lives on ``_llm_type`` (legacy) or ``provider`` + for attr in ("provider", "_llm_type"): + value = getattr(model, attr, None) + if value: + provider = str(value) + break + except Exception: # pragma: no cover — defensive + pass + return model_id, provider + + +def _resolve_tool_name(request: Any) -> str: + try: + tool = getattr(request, "tool", None) + if tool is not None: + name = getattr(tool, "name", None) + if isinstance(name, str) and name: + return name + # Fall back to the tool_call dict + call = getattr(request, "tool_call", None) or {} + name = call.get("name") if isinstance(call, dict) else None + if isinstance(name, str) and name: + return name + except Exception: # pragma: no cover — defensive + pass + return "unknown" + + +def _resolve_input_size(request: Any) -> int | None: + try: + call = getattr(request, "tool_call", None) + if not isinstance(call, dict) or not call: + return None + args = call.get("args") + if args is None: + return None + return len(repr(args)) + except Exception: # pragma: no cover — defensive + return None + + +def _annotate_model_response(span: Any, result: Any) -> None: + """Best-effort: attach prompt/completion token counts when available.""" + try: + # ModelResponse may be a dataclass with .result containing AIMessage + msg: Any + if isinstance(result, AIMessage): + msg = result + else: + inner = getattr(result, "result", None) + msg = inner[-1] if isinstance(inner, list) and inner else inner + if msg is None: + return + usage = getattr(msg, "usage_metadata", None) or {} + if isinstance(usage, dict): + if (n := usage.get("input_tokens")) is not None: + span.set_attribute("tokens.prompt", int(n)) + if (n := usage.get("output_tokens")) is not None: + span.set_attribute("tokens.completion", int(n)) + if (n := usage.get("total_tokens")) is not None: + span.set_attribute("tokens.total", int(n)) + tool_calls = getattr(msg, "tool_calls", None) or [] + span.set_attribute("model.tool_calls", len(tool_calls)) + except Exception: # pragma: no cover — defensive + pass + + +def _annotate_tool_result(span: Any, result: Any) -> None: + try: + if isinstance(result, ToolMessage): + content = result.content if isinstance(result.content, str) else repr(result.content) + span.set_attribute("tool.output.size", len(content)) + status = getattr(result, "status", None) + if isinstance(status, str): + span.set_attribute("tool.status", status) + kwargs = getattr(result, "additional_kwargs", None) or {} + if isinstance(kwargs, dict) and kwargs.get("error"): + span.set_attribute("tool.error", True) + except Exception: # pragma: no cover — defensive + pass + + +__all__ = ["OtelSpanMiddleware"] diff --git a/surfsense_backend/app/agents/new_chat/middleware/permission.py b/surfsense_backend/app/agents/new_chat/middleware/permission.py new file mode 100644 index 000000000..f59e70bc0 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/permission.py @@ -0,0 +1,344 @@ +""" +PermissionMiddleware — pattern-based allow/deny/ask with HITL fallback. + +Mirrors ``opencode/packages/opencode/src/permission/index.ts`` but uses +SurfSense's existing ``interrupt({type, action, context})`` payload shape +(see ``app/agents/new_chat/tools/hitl.py``) so the frontend keeps +working unchanged. Tier 2.1 in the OpenCode-port plan. + +Operation: +1. ``aafter_model`` inspects the latest ``AIMessage.tool_calls``. +2. For each call, the middleware builds a list of ``patterns`` (the + tool name plus any tool-specific patterns from the resolver). It + evaluates each pattern against the layered rulesets and aggregates + the results: ``deny`` > ``ask`` > ``allow``. +3. On ``deny``: replaces the call with a synthetic ``ToolMessage`` + containing a :class:`StreamingError`. +4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply + shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``. + - ``once``: proceed. + - ``always``: also persist allow rules for ``request.always`` patterns. + - ``reject`` w/o feedback: raise :class:`RejectedError`. + - ``reject`` w/ feedback: raise :class:`CorrectedError`. +5. On ``allow``: proceed unchanged. + +The middleware also performs a *pre-model* tool-filter step (the +``before_model`` hook) so globally denied tools are stripped from the +exposed tool list before the model gets to see them. This is +opencode's ``Permission.disabled`` equivalent and dramatically reduces +the chance the model emits a deny-only call. +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, +) +from langchain_core.messages import AIMessage, ToolMessage +from langgraph.runtime import Runtime +from langgraph.types import interrupt + +from app.agents.new_chat.errors import ( + CorrectedError, + RejectedError, + StreamingError, +) +from app.agents.new_chat.permissions import ( + Rule, + Ruleset, + aggregate_action, + evaluate_many, +) +from app.observability import otel as ot + +logger = logging.getLogger(__name__) + + +# Mapping ``tool_name -> resolver`` that converts ``args`` to a list of +# patterns to evaluate. The first pattern is conventionally the bare +# tool name; later entries narrow down to specific resources. +PatternResolver = Callable[[dict[str, Any]], list[str]] + + +def _default_pattern_resolver(name: str) -> PatternResolver: + def _resolve(args: dict[str, Any]) -> list[str]: + # Bare name covers the default catch-all; primary-arg fallbacks + # are best added per-tool by callers. + del args + return [name] + + return _resolve + + +class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] + """Allow/deny/ask layer over the agent's tool calls. + + Args: + rulesets: Layered rulesets to evaluate. Earlier entries are + overridden by later ones (last-match-wins). Typical layering: + ``defaults < global < space < thread < runtime_approved``. + pattern_resolvers: Optional per-tool callables that return a list + of patterns to evaluate. When a tool isn't listed, the bare + tool name is used as the only pattern. + runtime_ruleset: Mutable :class:`Ruleset` that the middleware + extends in-place when the user replies ``"always"`` to an + ask interrupt. Reused across all calls in the same agent + instance so newly-allowed rules apply to subsequent calls. + always_emit_interrupt_payload: If True, every ask uses the + SurfSense interrupt wire format (default). Set False to + disable interrupts and treat ``ask`` as ``deny`` for + non-interactive deployments. + """ + + tools = () + + def __init__( + self, + *, + rulesets: list[Ruleset] | None = None, + pattern_resolvers: dict[str, PatternResolver] | None = None, + runtime_ruleset: Ruleset | None = None, + always_emit_interrupt_payload: bool = True, + ) -> None: + super().__init__() + self._static_rulesets: list[Ruleset] = list(rulesets or []) + self._pattern_resolvers: dict[str, PatternResolver] = dict( + pattern_resolvers or {} + ) + self._runtime_ruleset: Ruleset = runtime_ruleset or Ruleset( + origin="runtime_approved" + ) + self._emit_interrupt = always_emit_interrupt_payload + + # ------------------------------------------------------------------ + # Tool-filter step (opencode `Permission.disabled` equivalent) + # ------------------------------------------------------------------ + + def _globally_denied(self, tool_name: str) -> bool: + """Return True if a deny rule with no narrowing pattern matches.""" + rules = evaluate_many(tool_name, ["*"], *self._all_rulesets()) + return aggregate_action(rules) == "deny" + + def _all_rulesets(self) -> list[Ruleset]: + return [*self._static_rulesets, self._runtime_ruleset] + + # NOTE: ``before_model`` filtering of the tools list is left to the + # agent factory. This middleware only blocks at execution time — and + # only via the rule-evaluator path, not by mutating ``request.tools``. + # Mutating ``request.tools`` per-call would invalidate provider + # prompt-cache prefixes (see Operational risks: prompt-cache regression). + + # ------------------------------------------------------------------ + # Tool-call evaluation + # ------------------------------------------------------------------ + + def _resolve_patterns(self, tool_name: str, args: dict[str, Any]) -> list[str]: + resolver = self._pattern_resolvers.get( + tool_name, _default_pattern_resolver(tool_name) + ) + try: + patterns = resolver(args or {}) + except Exception: + logger.exception("Pattern resolver for %s raised; using bare name", tool_name) + patterns = [tool_name] + if not patterns: + patterns = [tool_name] + return patterns + + def _evaluate( + self, tool_name: str, args: dict[str, Any] + ) -> tuple[str, list[str], list[Rule]]: + patterns = self._resolve_patterns(tool_name, args) + rules = evaluate_many(tool_name, patterns, *self._all_rulesets()) + action = aggregate_action(rules) + return action, patterns, rules + + # ------------------------------------------------------------------ + # HITL ask flow — SurfSense wire format + # ------------------------------------------------------------------ + + def _raise_interrupt( + self, + *, + tool_name: str, + args: dict[str, Any], + patterns: list[str], + rules: list[Rule], + ) -> dict[str, Any]: + """Block on user approval via SurfSense's ``interrupt`` shape.""" + if not self._emit_interrupt: + return {"decision_type": "reject"} + + # ``params`` (NOT ``args``) is what SurfSense's streaming + # normalizer forwards. Other fields move into ``context``. + payload = { + "type": "permission_ask", + "action": {"tool": tool_name, "params": args or {}}, + "context": { + "patterns": patterns, + "rules": [ + { + "permission": r.permission, + "pattern": r.pattern, + "action": r.action, + } + for r in rules + ], + # Rules of thumb for the frontend: surface the patterns + # the user can promote to "always" with a single reply. + "always": patterns, + }, + } + # Tier 3b: permission.asked + interrupt.raised spans (no-op when + # OTel is disabled). Both fire here so dashboards can correlate + # "we asked X" with "interrupt was actually delivered". + with ot.permission_asked_span( + permission=tool_name, + pattern=patterns[0] if patterns else None, + extra={"permission.patterns": list(patterns)}, + ), ot.interrupt_span(interrupt_type="permission_ask"): + decision = interrupt(payload) + if isinstance(decision, dict): + return decision + # Tolerate a plain string reply ("once", "always", "reject") + if isinstance(decision, str): + return {"decision_type": decision} + return {"decision_type": "reject"} + + def _persist_always( + self, tool_name: str, patterns: list[str] + ) -> None: + """Promote ``always`` reply into runtime allow rules. + + Persistence to ``agent_permission_rules`` is done by the + streaming layer (``stream_new_chat``) once it observes the + ``always`` reply — the middleware just keeps an in-memory + copy so subsequent calls in the same stream see the rule. + """ + for pattern in patterns: + self._runtime_ruleset.rules.append( + Rule(permission=tool_name, pattern=pattern, action="allow") + ) + + # ------------------------------------------------------------------ + # Synthesizing deny -> ToolMessage + # ------------------------------------------------------------------ + + @staticmethod + def _deny_message( + tool_call: dict[str, Any], + rule: Rule, + ) -> ToolMessage: + err = StreamingError( + code="permission_denied", + retryable=False, + suggestion=( + f"rule permission={rule.permission!r} pattern={rule.pattern!r} " + f"blocked this call" + ), + ) + return ToolMessage( + content=( + f"Permission denied: rule {rule.permission}/{rule.pattern} " + f"blocked tool {tool_call.get('name')!r}." + ), + tool_call_id=tool_call.get("id") or "", + name=tool_call.get("name"), + status="error", + additional_kwargs={"error": err.model_dump()}, + ) + + # ------------------------------------------------------------------ + # The hook: aafter_model + # ------------------------------------------------------------------ + + def _process( + self, + state: AgentState, + runtime: Runtime[Any], + ) -> dict[str, Any] | None: + del runtime # unused + messages = state.get("messages") or [] + if not messages: + return None + last = messages[-1] + if not isinstance(last, AIMessage) or not last.tool_calls: + return None + + deny_messages: list[ToolMessage] = [] + kept_calls: list[dict[str, Any]] = [] + any_change = False + + for raw in last.tool_calls: + call = dict(raw) if isinstance(raw, dict) else { + "name": getattr(raw, "name", None), + "args": getattr(raw, "args", {}), + "id": getattr(raw, "id", None), + "type": "tool_call", + } + name = call.get("name") or "" + args = call.get("args") or {} + action, patterns, rules = self._evaluate(name, args) + + if action == "deny": + # Find the deny rule for the suggestion text + deny_rule = next((r for r in rules if r.action == "deny"), rules[0]) + deny_messages.append(self._deny_message(call, deny_rule)) + any_change = True + continue + + if action == "ask": + decision = self._raise_interrupt( + tool_name=name, args=args, patterns=patterns, rules=rules + ) + kind = str(decision.get("decision_type") or "reject").lower() + if kind == "once": + kept_calls.append(call) + elif kind == "always": + self._persist_always(name, patterns) + kept_calls.append(call) + elif kind == "reject": + feedback = decision.get("feedback") + if isinstance(feedback, str) and feedback.strip(): + raise CorrectedError(feedback, tool=name) + raise RejectedError(tool=name, pattern=patterns[0] if patterns else None) + else: + logger.warning( + "Unknown permission decision %r; treating as reject", kind + ) + raise RejectedError(tool=name) + continue + + # allow + kept_calls.append(call) + + if not any_change and len(kept_calls) == len(last.tool_calls): + return None + + updated = last.model_copy(update={"tool_calls": kept_calls}) + result_messages: list[Any] = [updated] + if deny_messages: + result_messages.extend(deny_messages) + return {"messages": result_messages} + + def after_model( # type: ignore[override] + self, state: AgentState, runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: + return self._process(state, runtime) + + async def aafter_model( # type: ignore[override] + self, state: AgentState, runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: + return self._process(state, runtime) + + +__all__ = [ + "PatternResolver", + "PermissionMiddleware", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/retry_after.py b/surfsense_backend/app/agents/new_chat/middleware/retry_after.py new file mode 100644 index 000000000..82da6a97c --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/retry_after.py @@ -0,0 +1,245 @@ +""" +RetryAfterMiddleware — Header-aware retry with custom backoff and SSE eventing. + +Why standalone instead of subclassing ``ModelRetryMiddleware``: the upstream +class calls module-level ``calculate_delay`` inline (no overridable +``_calculate_delay`` hook), so a subclass cannot inject Retry-After header +delays without rewriting the loop. Tier 1.4 in the OpenCode-port plan. + +Behaviour: +- Extracts ``Retry-After`` / ``retry-after-ms`` from + ``litellm.exceptions.RateLimitError.response.headers`` (or any exception + exposing a similar shape). +- Sleeps ``max(exponential_backoff, header_delay)`` between retries. +- Returns ``False`` from ``retry_on`` for ``ContextWindowExceededError`` / + ``ContextOverflowError`` so :class:`SurfSenseCompactionMiddleware` (or + the LangChain summarization fallback path) handles those instead. +- Emits ``surfsense.retrying`` via ``adispatch_custom_event`` on each retry + so ``stream_new_chat`` can forward it to clients as an SSE event. +""" + +from __future__ import annotations + +import asyncio +import logging +import random +import re +import time +from collections.abc import Awaitable, Callable +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ModelRequest, + ModelResponse, + ResponseT, +) +from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event +from langchain_core.messages import AIMessage + +logger = logging.getLogger(__name__) + +# Names of exception classes for which a retry would not help — context +# overflow needs compaction, auth needs human intervention, etc. Detected +# by class-name substring so we don't have to import LiteLLM/Anthropic +# here (which would tie this module to optional deps). +_NON_RETRYABLE_NAME_HINTS: tuple[str, ...] = ( + "ContextWindowExceeded", + "ContextOverflow", + "AuthenticationError", + "InvalidRequestError", + "PermissionDenied", + "InvalidApiKey", + "ContextLimit", +) + + +def _is_non_retryable(exc: BaseException) -> bool: + name = type(exc).__name__ + return any(hint in name for hint in _NON_RETRYABLE_NAME_HINTS) + + +def _extract_retry_after_seconds(exc: BaseException) -> float | None: + """Return seconds-to-wait suggested by the provider, if any. + + Looks at ``exc.response.headers`` or ``exc.headers`` for the standard + HTTP ``Retry-After`` header (in seconds) or its millisecond cousin + ``retry-after-ms`` (sometimes used by Anthropic / OpenAI). Falls back + to a regex on the exception message for shapes like + ``"Please retry after 30s"``. + """ + headers: dict[str, Any] | None = None + response = getattr(exc, "response", None) + if response is not None: + headers = getattr(response, "headers", None) + if headers is None: + headers = getattr(exc, "headers", None) + + if isinstance(headers, dict): + # Normalize keys to lowercase for case-insensitive matching + norm = {str(k).lower(): v for k, v in headers.items()} + ms = norm.get("retry-after-ms") + if ms is not None: + try: + return float(ms) / 1000.0 + except (TypeError, ValueError): + pass + seconds = norm.get("retry-after") + if seconds is not None: + try: + return float(seconds) + except (TypeError, ValueError): + pass + + # Last resort: scan the message for "retry after Xs" or "X seconds" + msg = str(exc) + match = re.search(r"retry\s+after\s+([0-9]+(?:\.[0-9]+)?)", msg, re.IGNORECASE) + if match: + try: + return float(match.group(1)) + except ValueError: + return None + return None + + +def _exponential_delay( + attempt: int, + *, + initial_delay: float, + backoff_factor: float, + max_delay: float, + jitter: bool, +) -> float: + """Compute an exponential-backoff delay with optional ±25% jitter.""" + delay = initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay + delay = min(delay, max_delay) + if jitter and delay > 0: + delay *= 1 + random.uniform(-0.25, 0.25) + return max(delay, 0.0) + + +class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): + """Retry middleware that honors provider-issued Retry-After hints. + + Drop-in replacement for :class:`langchain.agents.middleware.ModelRetryMiddleware` + when working with LiteLLM/Anthropic/OpenAI providers that surface + rate-limit hints in headers. Always emits ``surfsense.retrying`` SSE + events so the UI can show a friendly "rate limited, retrying in Xs" + indicator. + + Args: + max_retries: Maximum retries after the initial attempt (default 3). + initial_delay: Initial backoff delay in seconds. + backoff_factor: Exponential growth factor for backoff. + max_delay: Cap on per-attempt delay in seconds. + jitter: Whether to add ±25% jitter. + retry_on: Optional callable that returns True for retryable + exceptions. The default retries everything except known + non-retryable classes (context overflow, auth, etc.). + """ + + def __init__( + self, + *, + max_retries: int = 3, + initial_delay: float = 1.0, + backoff_factor: float = 2.0, + max_delay: float = 60.0, + jitter: bool = True, + retry_on: Callable[[BaseException], bool] | None = None, + ) -> None: + super().__init__() + self.max_retries = max_retries + self.initial_delay = initial_delay + self.backoff_factor = backoff_factor + self.max_delay = max_delay + self.jitter = jitter + self._retry_on: Callable[[BaseException], bool] = retry_on or ( + lambda exc: not _is_non_retryable(exc) + ) + + def _should_retry(self, exc: BaseException) -> bool: + try: + return bool(self._retry_on(exc)) + except Exception: + logger.exception("retry_on callable raised; defaulting to False") + return False + + def _delay_for_attempt(self, attempt: int, exc: BaseException) -> float: + backoff = _exponential_delay( + attempt, + initial_delay=self.initial_delay, + backoff_factor=self.backoff_factor, + max_delay=self.max_delay, + jitter=self.jitter, + ) + header = _extract_retry_after_seconds(exc) or 0.0 + return max(backoff, header) + + def wrap_model_call( # type: ignore[override] + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> ModelResponse[ResponseT] | AIMessage: + for attempt in range(self.max_retries + 1): + try: + return handler(request) + except Exception as exc: + if not self._should_retry(exc) or attempt >= self.max_retries: + raise + delay = self._delay_for_attempt(attempt, exc) + try: + dispatch_custom_event( + "surfsense.retrying", + { + "attempt": attempt + 1, + "max_retries": self.max_retries, + "delay_ms": int(delay * 1000), + "reason": type(exc).__name__, + }, + ) + except Exception: + logger.debug("dispatch_custom_event failed; suppressed", exc_info=True) + if delay > 0: + time.sleep(delay) + # Unreachable + raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution") + + async def awrap_model_call( # type: ignore[override] + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], + ) -> ModelResponse[ResponseT] | AIMessage: + for attempt in range(self.max_retries + 1): + try: + return await handler(request) + except Exception as exc: + if not self._should_retry(exc) or attempt >= self.max_retries: + raise + delay = self._delay_for_attempt(attempt, exc) + try: + await adispatch_custom_event( + "surfsense.retrying", + { + "attempt": attempt + 1, + "max_retries": self.max_retries, + "delay_ms": int(delay * 1000), + "reason": type(exc).__name__, + }, + ) + except Exception: + logger.debug( + "adispatch_custom_event failed; suppressed", exc_info=True + ) + if delay > 0: + await asyncio.sleep(delay) + raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution") + + +__all__ = [ + "RetryAfterMiddleware", + "_extract_retry_after_seconds", + "_is_non_retryable", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/safe_summarization.py b/surfsense_backend/app/agents/new_chat/middleware/safe_summarization.py deleted file mode 100644 index 4ddcf334f..000000000 --- a/surfsense_backend/app/agents/new_chat/middleware/safe_summarization.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Safe wrapper around deepagents' SummarizationMiddleware. - -Upstream issue --------------- -`deepagents.middleware.summarization.SummarizationMiddleware._aoffload_to_backend` -(and its sync counterpart) call -``get_buffer_string(filtered_messages)`` before writing the evicted history -to the backend file. In recent ``langchain-core`` versions, ``get_buffer_string`` -accesses ``m.text`` which iterates ``self.content`` — this raises -``TypeError: 'NoneType' object is not iterable`` whenever an ``AIMessage`` -has ``content=None`` (common when a model returns *only* tool_calls, seen -frequently with Azure OpenAI ``gpt-5.x`` responses streamed through -LiteLLM). - -The exception aborts the whole agent turn, so the user just sees "Error during -chat" with no assistant response. - -Fix ---- -We subclass ``SummarizationMiddleware`` and override -``_filter_summary_messages`` — the only call site that feeds messages into -``get_buffer_string`` — to return *copies* of messages whose ``content`` is -``None`` with ``content=""``. The originals flowing through the rest of the -agent state are untouched. - -We also expose a drop-in ``create_safe_summarization_middleware`` factory -that mirrors ``deepagents.middleware.summarization.create_summarization_middleware`` -but instantiates our safe subclass. -""" - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING - -from deepagents.middleware.summarization import ( - SummarizationMiddleware, - compute_summarization_defaults, -) - -if TYPE_CHECKING: - from deepagents.backends.protocol import BACKEND_TYPES - from langchain_core.language_models import BaseChatModel - from langchain_core.messages import AnyMessage - -logger = logging.getLogger(__name__) - - -def _sanitize_message_content(msg: AnyMessage) -> AnyMessage: - """Return ``msg`` with ``content`` coerced to a non-``None`` value. - - ``get_buffer_string`` reads ``m.text`` which iterates ``self.content``; - when a provider streams back an ``AIMessage`` with only tool_calls and - no text, ``content`` can be ``None`` and the iteration explodes. We - replace ``None`` with an empty string so downstream consumers that only - care about text see an empty body. - - The original message is left untouched — we return a copy via - pydantic's ``model_copy`` when available, otherwise we fall back to - re-setting the attribute on a shallow copy. - """ - - if getattr(msg, "content", "not-missing") is not None: - return msg - - try: - return msg.model_copy(update={"content": ""}) - except AttributeError: - import copy - - new_msg = copy.copy(msg) - try: - new_msg.content = "" - except Exception: # pragma: no cover - defensive - logger.debug( - "Could not sanitize content=None on message of type %s", - type(msg).__name__, - ) - return msg - return new_msg - - -class SafeSummarizationMiddleware(SummarizationMiddleware): - """`SummarizationMiddleware` that tolerates messages with ``content=None``. - - Only ``_filter_summary_messages`` is overridden — this is the single - helper invoked by both the sync and async offload paths immediately - before ``get_buffer_string``. Normalising here means we get coverage - for both without having to copy the (long, rapidly-changing) offload - implementations from upstream. - """ - - def _filter_summary_messages(self, messages: list[AnyMessage]) -> list[AnyMessage]: - filtered = super()._filter_summary_messages(messages) - return [_sanitize_message_content(m) for m in filtered] - - -def create_safe_summarization_middleware( - model: BaseChatModel, - backend: BACKEND_TYPES, -) -> SafeSummarizationMiddleware: - """Drop-in replacement for ``create_summarization_middleware``. - - Mirrors the defaults computed by ``deepagents`` but returns our - ``SafeSummarizationMiddleware`` subclass so the - ``content=None`` crash in ``get_buffer_string`` is avoided. - """ - - defaults = compute_summarization_defaults(model) - return SafeSummarizationMiddleware( - model=model, - backend=backend, - trigger=defaults["trigger"], - keep=defaults["keep"], - trim_tokens_to_summarize=None, - truncate_args_settings=defaults["truncate_args_settings"], - ) - - -__all__ = [ - "SafeSummarizationMiddleware", - "create_safe_summarization_middleware", -] diff --git a/surfsense_backend/app/agents/new_chat/middleware/skills_backends.py b/surfsense_backend/app/agents/new_chat/middleware/skills_backends.py new file mode 100644 index 000000000..4c3791c87 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/skills_backends.py @@ -0,0 +1,332 @@ +"""Skills backends for SurfSense. + +Implements two minimal :class:`deepagents.backends.protocol.BackendProtocol` +subclasses tailored for use with :class:`deepagents.middleware.skills.SkillsMiddleware`. + +The middleware only needs four methods to load skills from a backend: + +* ``ls_info`` / ``als_info`` — list directories under a source path. +* ``download_files`` / ``adownload_files`` — fetch ``SKILL.md`` bytes. + +Other ``BackendProtocol`` methods (``read``/``write``/``edit``/``grep_raw`` …) +default to ``NotImplementedError`` from the base class. They are never reached +by the skills middleware because skill content is rendered into the system +prompt at agent build time, not edited at runtime. + +Two backends are provided: + +* :class:`BuiltinSkillsBackend` — disk-backed read of bundled skills from + ``app/agents/new_chat/skills/builtin/``. +* :class:`SearchSpaceSkillsBackend` — a thin read-only wrapper over + :class:`KBPostgresBackend` that filters notes under the privileged folder + ``/documents/_skills/``. + +Both backends are intentionally read-only: skill authoring happens out of band +(via filesystem or a search-space-admin route), so we never expose +``write`` / ``edit`` / ``upload_files``. The base class' ``NotImplementedError`` +gives a clean failure mode if anything tries. +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from dataclasses import replace +from pathlib import Path +from typing import TYPE_CHECKING + +from deepagents.backends.composite import CompositeBackend +from deepagents.backends.protocol import ( + BackendProtocol, + FileDownloadResponse, + FileInfo, +) +from deepagents.backends.state import StateBackend + +if TYPE_CHECKING: + from langchain.tools import ToolRuntime + + from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend + +logger = logging.getLogger(__name__) + + +# Limit per Agent Skills spec; matches deepagents.middleware.skills.MAX_SKILL_FILE_SIZE. +_MAX_SKILL_FILE_SIZE = 10 * 1024 * 1024 + + +def _default_builtin_root() -> Path: + """Return the absolute path to the bundled builtin skills directory. + + Located at ``app/agents/new_chat/skills/builtin/`` relative to this module. + """ + return (Path(__file__).resolve().parent.parent / "skills" / "builtin").resolve() + + +class BuiltinSkillsBackend(BackendProtocol): + """Read-only disk-backed skills source. + + Maps a virtual ``/skills/builtin/`` namespace onto a directory on local disk, + where each skill is its own subdirectory containing a ``SKILL.md`` file:: + + //SKILL.md + + The middleware calls :meth:`als_info` with the source path and expects a + ``list[FileInfo]`` whose ``is_dir=True`` entries are descended into. Then it + calls :meth:`adownload_files` with the synthesized ``SKILL.md`` paths and + parses YAML frontmatter from the returned ``content`` bytes. + + Mounting under :class:`~deepagents.backends.composite.CompositeBackend` at + prefix ``/skills/builtin/`` means the middleware can issue paths like + ``/skills/builtin/kb-research/SKILL.md`` which the composite strips down to + ``/kb-research/SKILL.md`` before forwarding here. We treat any leading + slash as anchoring at :attr:`root`. + """ + + def __init__(self, root: Path | str | None = None) -> None: + self.root: Path = Path(root).resolve() if root else _default_builtin_root() + if not self.root.exists(): + logger.info( + "BuiltinSkillsBackend root %s does not exist; skills will be empty.", + self.root, + ) + + def _resolve(self, path: str) -> Path: + """Resolve a virtual posix path under :attr:`root`, refusing escapes.""" + bare = path.lstrip("/") + candidate = (self.root / bare).resolve() if bare else self.root + # Refuse symlink/.. traversal that escapes the root. + try: + candidate.relative_to(self.root) + except ValueError as exc: + raise ValueError(f"path {path!r} escapes builtin skills root") from exc + return candidate + + def ls_info(self, path: str) -> list[FileInfo]: + try: + target = self._resolve(path) + except ValueError as exc: + logger.warning("BuiltinSkillsBackend.ls_info refused: %s", exc) + return [] + if not target.exists() or not target.is_dir(): + return [] + + infos: list[FileInfo] = [] + # Build virtual paths anchored at "/" because CompositeBackend already + # stripped the route prefix before calling us. + target_virtual = "/" if target == self.root else ( + "/" + str(target.relative_to(self.root)).replace("\\", "/") + ) + for child in sorted(target.iterdir()): + child_virtual = ( + target_virtual.rstrip("/") + "/" + child.name + if target_virtual != "/" + else "/" + child.name + ) + info: FileInfo = { + "path": child_virtual, + "is_dir": child.is_dir(), + } + if child.is_file(): + try: + info["size"] = child.stat().st_size + except OSError: # pragma: no cover - defensive + pass + infos.append(info) + return infos + + def download_files(self, paths: list[str]) -> list[FileDownloadResponse]: + responses: list[FileDownloadResponse] = [] + for p in paths: + try: + target = self._resolve(p) + except ValueError: + responses.append(FileDownloadResponse(path=p, error="invalid_path")) + continue + if not target.exists(): + responses.append(FileDownloadResponse(path=p, error="file_not_found")) + continue + if target.is_dir(): + responses.append(FileDownloadResponse(path=p, error="is_directory")) + continue + try: + # Hard cap to avoid loading rogue mega-files into memory. + size = target.stat().st_size + if size > _MAX_SKILL_FILE_SIZE: + logger.warning( + "Builtin skill file %s exceeds %d bytes; truncating.", + target, + _MAX_SKILL_FILE_SIZE, + ) + with target.open("rb") as fh: + content = fh.read(_MAX_SKILL_FILE_SIZE) + else: + content = target.read_bytes() + except PermissionError: + responses.append(FileDownloadResponse(path=p, error="permission_denied")) + continue + except OSError as exc: # pragma: no cover - defensive + logger.warning("Builtin skill read failed %s: %s", target, exc) + responses.append(FileDownloadResponse(path=p, error="file_not_found")) + continue + responses.append(FileDownloadResponse(path=p, content=content, error=None)) + return responses + + +class SearchSpaceSkillsBackend(BackendProtocol): + """Read-only view of search-space-authored skills. + + Wraps a :class:`KBPostgresBackend` and only ever reads under the privileged + folder ``/documents/_skills/`` (configurable). The folder is intended to be + writable only by search-space admins; this backend never writes. + + The skills middleware expects a layout like:: + + ///SKILL.md + + But the KB stores documents like ``/documents/_skills//SKILL.md``. + We expose the inner namespace by remapping each path. When mounted under + :class:`CompositeBackend` at prefix ``/skills/space/`` the paths the + middleware sees become ``/skills/space//SKILL.md``; the composite + strips ``/skills/space/`` and hands us ``//SKILL.md``, which we + rewrite to ``/documents/_skills//SKILL.md`` before forwarding to the + KB. + + No new database table is needed: the privileged folder convention is + enforced server-side outside of this class. We intentionally swallow any + write/edit attempts (the base class raises ``NotImplementedError``). + """ + + DEFAULT_KB_ROOT: str = "/documents/_skills" + + def __init__( + self, + kb_backend: KBPostgresBackend, + *, + kb_root: str = DEFAULT_KB_ROOT, + ) -> None: + self._kb = kb_backend + # Normalize trailing slash off so we can join cleanly. + self._kb_root = kb_root.rstrip("/") or "/" + + def _to_kb(self, path: str) -> str: + """Rewrite a virtual path into the underlying KB namespace.""" + bare = path.lstrip("/") + if not bare: + return self._kb_root + return f"{self._kb_root}/{bare}" + + def _from_kb(self, kb_path: str) -> str: + """Rewrite a KB path back into our virtual namespace.""" + if not kb_path.startswith(self._kb_root): + return kb_path # pragma: no cover - defensive + rel = kb_path[len(self._kb_root) :] + return rel if rel.startswith("/") else "/" + rel + + def ls_info(self, path: str) -> list[FileInfo]: + # KBPostgresBackend exposes only the async API meaningfully; the sync + # path falls back to ``asyncio.to_thread(...)`` in the base class. We + # keep this stub to satisfy abstract resolution; the middleware calls + # ``als_info``. + raise NotImplementedError("SearchSpaceSkillsBackend is async-only") + + async def als_info(self, path: str) -> list[FileInfo]: + kb_path = self._to_kb(path) + try: + infos = await self._kb.als_info(kb_path) + except Exception as exc: # pragma: no cover - defensive + logger.warning("SearchSpaceSkillsBackend.als_info failed: %s", exc) + return [] + remapped: list[FileInfo] = [] + for info in infos: + kb_p = info.get("path", "") + if not kb_p.startswith(self._kb_root): + continue + remapped.append({**info, "path": self._from_kb(kb_p)}) + return remapped + + def download_files(self, paths: list[str]) -> list[FileDownloadResponse]: + raise NotImplementedError("SearchSpaceSkillsBackend is async-only") + + async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]: + kb_paths = [self._to_kb(p) for p in paths] + responses = await self._kb.adownload_files(kb_paths) + # Re-map response paths back to the virtual namespace so the middleware + # correlates them to the input list correctly. + remapped: list[FileDownloadResponse] = [] + for original, resp in zip(paths, responses, strict=True): + remapped.append(replace(resp, path=original)) + return remapped + + +SKILLS_BUILTIN_PREFIX = "/skills/builtin/" +SKILLS_SPACE_PREFIX = "/skills/space/" + + +def build_skills_backend_factory( + *, + builtin_root: Path | str | None = None, + search_space_id: int | None = None, +) -> Callable[[ToolRuntime], BackendProtocol]: + """Return a runtime-aware factory for the skills :class:`CompositeBackend`. + + When ``search_space_id`` is provided the composite includes a + :class:`SearchSpaceSkillsBackend` route at ``/skills/space/`` over a fresh + per-runtime :class:`KBPostgresBackend`, mirroring how + :func:`build_backend_resolver` constructs the main filesystem backend. + + When ``search_space_id`` is ``None`` (e.g., desktop-local mode or unit + tests) only the bundled :class:`BuiltinSkillsBackend` is exposed. + + Returning a factory rather than a fixed instance is intentional: the + underlying KB backend depends on per-call ``ToolRuntime`` state + (``staged_dirs``, ``files`` cache, runtime config), so a single shared + instance cannot serve multiple concurrent agent runs. + """ + builtin = BuiltinSkillsBackend(builtin_root) + + if search_space_id is None: + def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol: + # Default StateBackend is intentionally inert: any path outside the + # ``/skills/builtin/`` route resolves to an empty per-runtime state + # so the SkillsMiddleware can iterate sources without raising. + return CompositeBackend( + default=StateBackend(runtime), + routes={SKILLS_BUILTIN_PREFIX: builtin}, + ) + return _factory_builtin_only + + def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol: + # Imported lazily to avoid a hard dependency at module import time: + # ``KBPostgresBackend`` pulls in DB models, which are unnecessary for + # the unit-tested builtin path. + from app.agents.new_chat.middleware.kb_postgres_backend import ( + KBPostgresBackend, + ) + + kb = KBPostgresBackend(search_space_id, runtime) + space = SearchSpaceSkillsBackend(kb) + return CompositeBackend( + default=StateBackend(runtime), + routes={ + SKILLS_BUILTIN_PREFIX: builtin, + SKILLS_SPACE_PREFIX: space, + }, + ) + + return _factory_with_space + + +def default_skills_sources() -> list[str]: + """Return the canonical source list for SkillsMiddleware (built-in then space).""" + return [SKILLS_BUILTIN_PREFIX, SKILLS_SPACE_PREFIX] + + +__all__ = [ + "SKILLS_BUILTIN_PREFIX", + "SKILLS_SPACE_PREFIX", + "BuiltinSkillsBackend", + "SearchSpaceSkillsBackend", + "build_skills_backend_factory", + "default_skills_sources", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py b/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py new file mode 100644 index 000000000..6c3bc674d --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py @@ -0,0 +1,190 @@ +""" +ToolCallNameRepairMiddleware — two-stage tool-name repair. + +Mirrors ``opencode/packages/opencode/src/session/llm.ts:339-358`` plus +``opencode/packages/opencode/src/tool/invalid.ts``. Tier 1.7 in the +OpenCode-port plan. + +Operation: +1. **Stage 1 — lowercase repair:** if a tool call's ``name`` is not in + the registry but ``name.lower()`` is, rewrite in place. Catches + models that emit ``Search`` instead of ``search``. +2. **Stage 2 — invalid fallback:** if still unmatched, rewrite the call + to ``invalid`` with ``args={"tool": original_name, "error": }`` + so the registered :func:`invalid_tool` returns the error to the model + for self-correction. + +Distinct from :class:`deepagents.middleware.PatchToolCallsMiddleware`, +which patches *dangling* tool calls (no matching ToolMessage) — that +class does not handle the wrong-name case at all. +""" + +from __future__ import annotations + +import difflib +import logging +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ResponseT, +) +from langchain_core.messages import AIMessage +from langgraph.runtime import Runtime + +from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME + +logger = logging.getLogger(__name__) + + +def _coerce_existing_tool_call(call: Any) -> dict[str, Any]: + """Normalize a tool call entry to a mutable dict.""" + if isinstance(call, dict): + return dict(call) + return { + "name": getattr(call, "name", None), + "args": getattr(call, "args", {}), + "id": getattr(call, "id", None), + "type": "tool_call", + } + + +class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): + """Two-stage tool-name repair on the most recent ``AIMessage``. + + Args: + registered_tool_names: Set of canonically-registered tool names. + ``invalid`` should be in this set so the fallback dispatches. + fuzzy_match_threshold: Optional ``difflib`` ratio (0–1) for the + fuzzy-match step that runs *between* lowercase and invalid. + Set to ``None`` to disable fuzzy matching (opencode parity). + """ + + def __init__( + self, + *, + registered_tool_names: set[str], + fuzzy_match_threshold: float | None = 0.85, + ) -> None: + super().__init__() + self._registered = set(registered_tool_names) + self._registered_lower = {name.lower(): name for name in self._registered} + self._fuzzy_threshold = fuzzy_match_threshold + self.tools = [] + + def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]: + """Allow runtime overrides to expand the set (e.g. dynamic MCP tools).""" + ctx_tools = getattr(runtime.context, "registered_tool_names", None) + if isinstance(ctx_tools, (set, frozenset)): + return self._registered | set(ctx_tools) + if isinstance(ctx_tools, (list, tuple)): + return self._registered | set(ctx_tools) + return self._registered + + def _repair_one( + self, + call: dict[str, Any], + registered: set[str], + ) -> dict[str, Any]: + name = call.get("name") + if not isinstance(name, str): + return call + + if name in registered: + return call + + # Stage 1 — lowercase + lowered = name.lower() + if lowered in registered: + call["name"] = lowered + metadata = dict(call.get("response_metadata") or {}) + metadata.setdefault("repair", "lowercase") + call["response_metadata"] = metadata + return call + + # Optional fuzzy step (off by default for opencode parity) + if self._fuzzy_threshold is not None: + close = difflib.get_close_matches( + name, registered, n=1, cutoff=self._fuzzy_threshold + ) + if close: + call["name"] = close[0] + metadata = dict(call.get("response_metadata") or {}) + metadata.setdefault("repair", f"fuzzy:{name}->{close[0]}") + call["response_metadata"] = metadata + return call + + # Stage 2 — invalid fallback + if INVALID_TOOL_NAME in registered: + original_args = call.get("args") or {} + error_msg = ( + f"Tool name '{name}' is not registered. " + f"Original arguments were: {original_args!r}." + ) + call["name"] = INVALID_TOOL_NAME + call["args"] = {"tool": name, "error": error_msg} + metadata = dict(call.get("response_metadata") or {}) + metadata.setdefault("repair", f"invalid_fallback:{name}") + call["response_metadata"] = metadata + else: + logger.warning( + "Could not repair unknown tool call %r; 'invalid' tool not registered", + name, + ) + return call + + def _maybe_repair( + self, + message: AIMessage, + registered: set[str], + ) -> AIMessage | None: + if not message.tool_calls: + return None + + new_calls: list[dict[str, Any]] = [] + any_changed = False + for raw in message.tool_calls: + call = _coerce_existing_tool_call(raw) + before = (call.get("name"), call.get("args")) + repaired = self._repair_one(call, registered) + after = (repaired.get("name"), repaired.get("args")) + if before != after: + any_changed = True + new_calls.append(repaired) + + if not any_changed: + return None + + return message.model_copy(update={"tool_calls": new_calls}) + + def after_model( # type: ignore[override] + self, + state: AgentState[ResponseT], + runtime: Runtime[ContextT], + ) -> dict[str, Any] | None: + messages = state.get("messages") or [] + if not messages: + return None + last = messages[-1] + if not isinstance(last, AIMessage): + return None + + registered = self._registered_for_runtime(runtime) + repaired = self._maybe_repair(last, registered) + if repaired is None: + return None + return {"messages": [repaired]} + + async def aafter_model( # type: ignore[override] + self, + state: AgentState[ResponseT], + runtime: Runtime[ContextT], + ) -> dict[str, Any] | None: + return self.after_model(state, runtime) + + +__all__ = [ + "ToolCallNameRepairMiddleware", +] diff --git a/surfsense_backend/app/agents/new_chat/permissions.py b/surfsense_backend/app/agents/new_chat/permissions.py new file mode 100644 index 000000000..50a0cfbdc --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/permissions.py @@ -0,0 +1,204 @@ +""" +Wildcard pattern matching + rule evaluation for the SurfSense permission system. + +Mirrors ``opencode/packages/opencode/src/permission/evaluate.ts`` and +``opencode/packages/opencode/src/util/wildcard.ts`` precisely: + +- ``Wildcard.match`` matches both the ``permission`` and the ``pattern`` + fields of a rule against the requested ``(permission, pattern)`` pair. + ``*`` matches any segment, ``**`` matches across separators. +- The evaluator runs ``findLast`` over the **flattened** list of rules + from all rulesets — last matching rule wins. +- The default fallback is ``ask`` (NOT deny), matching opencode. +- Multi-pattern requests AND together: if ANY pattern resolves to + ``deny``, the whole request is denied; if ANY needs ``ask``, an + interrupt is raised; only when all patterns ``allow`` does the + request proceed. + +Tier 2.1 in the OpenCode-port plan. +""" + +from __future__ import annotations + +import re +from collections.abc import Iterable +from dataclasses import dataclass, field +from typing import Literal + +RuleAction = Literal["allow", "deny", "ask"] + + +@dataclass(frozen=True) +class Rule: + """A single permission rule. + + Attributes: + permission: A wildcard-matched permission identifier + (e.g. ``"edit"``, ``"linear_*"``, ``"mcp:*"``, + ``"doom_loop"``). Anchored at start AND end of the input. + pattern: A wildcard-matched pattern over the request payload + (e.g. ``"/documents/secrets/**"``, ``"page_id=123"``, + ``"*"``). Anchored at start AND end. + action: One of ``"allow"`` / ``"deny"`` / ``"ask"``. + """ + + permission: str + pattern: str + action: RuleAction + + +@dataclass +class Ruleset: + """A list of rules with an associated origin used for debugging.""" + + rules: list[Rule] = field(default_factory=list) + origin: str = "unknown" # e.g. "defaults", "global", "space", "thread", "runtime" + + +# ----------------------------------------------------------------------------- +# Wildcard matcher +# ----------------------------------------------------------------------------- + + +_GLOB_TOKEN = re.compile(r"\*\*|\*|[^*]+") + + +def _wildcard_to_regex(pattern: str) -> re.Pattern[str]: + """Translate an opencode-style wildcard pattern to a compiled regex. + + Rules: + - ``**`` matches any sequence of any characters (including separators). + - ``*`` matches any sequence of characters that does **not** include + the path separator ``/`` — same as glob. + - All other characters match literally. + - The pattern is anchored at both ends (``^...$``). + """ + parts: list[str] = ["^"] + for token in _GLOB_TOKEN.findall(pattern): + if token == "**": + parts.append(r".*") + elif token == "*": + parts.append(r"[^/]*") + else: + parts.append(re.escape(token)) + parts.append("$") + return re.compile("".join(parts)) + + +_REGEX_CACHE: dict[str, re.Pattern[str]] = {} + + +def wildcard_match(value: str, pattern: str) -> bool: + """Return True if ``value`` matches the wildcard ``pattern``. + + Special case: a bare ``"*"`` pattern matches any value, including + those containing ``/`` separators. This mirrors opencode's + ``Wildcard.match`` short-circuit and matches the convention that + ``pattern="*"`` means "any pattern" in permission rules. + """ + if pattern == "*": + return True + compiled = _REGEX_CACHE.get(pattern) + if compiled is None: + compiled = _wildcard_to_regex(pattern) + _REGEX_CACHE[pattern] = compiled + return compiled.match(value) is not None + + +# ----------------------------------------------------------------------------- +# Evaluator +# ----------------------------------------------------------------------------- + + +def evaluate( + permission: str, + pattern: str, + *rulesets: Ruleset | Iterable[Rule], +) -> Rule: + """Find the last rule matching ``(permission, pattern)`` from ``rulesets``. + + Mirrors opencode ``permission/evaluate.ts:9-15`` precisely: + - Flatten rulesets in argument order. + - Walk the flat list **in reverse**. + - First reverse-match wins (i.e. the last specified rule wins). + - When no rule matches, default to ``Rule(permission, "*", "ask")``. + + Args: + permission: The permission identifier being requested + (e.g. tool name, ``"edit"``, ``"doom_loop"``). + pattern: The request-specific pattern (e.g. file path, + primary arg value). Use ``"*"`` when no specific pattern + applies. + *rulesets: Layered rulesets, applied earliest to latest. Later + rulesets override earlier ones. + + Returns: + The matched :class:`Rule`, or the default ask fallback. + """ + flat: list[Rule] = [] + for rs in rulesets: + if isinstance(rs, Ruleset): + flat.extend(rs.rules) + else: + flat.extend(rs) + + for rule in reversed(flat): + if wildcard_match(permission, rule.permission) and wildcard_match( + pattern, rule.pattern + ): + return rule + + return Rule(permission=permission, pattern="*", action="ask") + + +def evaluate_many( + permission: str, + patterns: Iterable[str], + *rulesets: Ruleset | Iterable[Rule], +) -> list[Rule]: + """Evaluate ``permission`` against each of ``patterns`` (multi-pattern AND). + + Returns the list of resolved rules in the same order as ``patterns``. + The caller is responsible for combining the results — opencode-style + multi-pattern AND collapses ``deny`` first, then ``ask``, then + ``allow``. + """ + return [evaluate(permission, p, *rulesets) for p in patterns] + + +def aggregate_action(rules: Iterable[Rule]) -> RuleAction: + """Collapse a list of per-pattern rules into one action. + + Order: + 1. If any rule is ``deny`` -> ``deny``. + 2. Else if any rule is ``ask`` -> ``ask``. + 3. Else if at least one rule is ``allow`` -> ``allow``. + 4. Else (empty input) -> ``ask`` (safe default mirroring ``evaluate``). + + Mirrors opencode's behavior in ``permission/index.ts:180-272``. + """ + saw_ask = False + saw_allow = False + for rule in rules: + if rule.action == "deny": + return "deny" + if rule.action == "ask": + saw_ask = True + elif rule.action == "allow": + saw_allow = True + if saw_ask: + return "ask" + if saw_allow: + return "allow" + return "ask" + + +__all__ = [ + "Rule", + "RuleAction", + "Ruleset", + "aggregate_action", + "evaluate", + "evaluate_many", + "wildcard_match", +] diff --git a/surfsense_backend/app/agents/new_chat/plugin_loader.py b/surfsense_backend/app/agents/new_chat/plugin_loader.py new file mode 100644 index 000000000..426e28041 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/plugin_loader.py @@ -0,0 +1,157 @@ +"""Entry-point based plugin loader for SurfSense agent middleware. + +The realization in the Tier 6 plan: LangChain's :class:`AgentMiddleware` ABC +already covers the practical surface most plugins need (``before_agent`` / +``before_model`` / ``wrap_tool_call`` / their async counterparts), so a +SurfSense-specific plugin protocol is unnecessary. + +A plugin is therefore just an installable Python package that registers a +factory callable under the ``surfsense.plugins`` entry-point group: + +.. code-block:: toml + + # in a plugin package's pyproject.toml + [project.entry-points."surfsense.plugins"] + year_substituter = "my_plugin:make_middleware" + +The factory has the signature ``Callable[[PluginContext], AgentMiddleware]``. +It receives a small, sanitized :class:`PluginContext` with the IDs and the +LLM the plugin is allowed to talk to — and **never** raw secrets, DB +sessions, or other connectors. + +## Trust model + +Plugins are loaded **only if** their entry-point ``name`` appears in +``allowed_plugins`` (admin-controlled, sourced from +``global_llm_config.yaml`` or :func:`load_allowed_plugin_names_from_env`). +There is **no env-driven auto-load**. A plugin failure is logged and +isolated; it does not break agent construction. +""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Iterable +from importlib.metadata import entry_points +from typing import TYPE_CHECKING + +from langchain.agents.middleware import AgentMiddleware + +if TYPE_CHECKING: # pragma: no cover - type-only + from langchain_core.language_models import BaseChatModel + + from app.db import ChatVisibility + + +logger = logging.getLogger(__name__) + + +PLUGIN_ENTRY_POINT_GROUP = "surfsense.plugins" + + +class PluginContext(dict): + """Sanitized DI bag handed to each plugin factory. + + Backed by ``dict`` so plugins can inspect the keys they care about + without coupling to a concrete dataclass shape. Required keys: + + * ``search_space_id`` (int) + * ``user_id`` (str | None) + * ``thread_visibility`` (:class:`app.db.ChatVisibility`) + * ``llm`` (:class:`langchain_core.language_models.BaseChatModel`) + + The context **never** carries DB sessions, raw secrets, or other + connectors. If a future plugin genuinely needs DB access, that + integration goes through a rate-limited service interface, not + through this bag. + """ + + @classmethod + def build( + cls, + *, + search_space_id: int, + user_id: str | None, + thread_visibility: ChatVisibility, + llm: BaseChatModel, + ) -> PluginContext: + return cls( + search_space_id=search_space_id, + user_id=user_id, + thread_visibility=thread_visibility, + llm=llm, + ) + + +def load_plugin_middlewares( + ctx: PluginContext, + allowed_plugin_names: Iterable[str], +) -> list[AgentMiddleware]: + """Discover, allowlist-filter, and instantiate plugin middleware. + + For each entry-point in :data:`PLUGIN_ENTRY_POINT_GROUP` whose name is + in ``allowed_plugin_names``, load the factory and call it with ``ctx``. + The factory's return value must be an :class:`AgentMiddleware` instance; + anything else is logged and skipped. + + Errors are isolated — a plugin that raises during ``ep.load()`` or + factory invocation is logged at ``ERROR`` and ignored. Agent + construction continues with whatever plugins did succeed. + """ + allowed = {name for name in allowed_plugin_names if name} + if not allowed: + return [] + + out: list[AgentMiddleware] = [] + try: + eps = entry_points(group=PLUGIN_ENTRY_POINT_GROUP) + except Exception: # pragma: no cover - defensive (entry_points is robust) + logger.exception("Failed to enumerate plugin entry points") + return [] + + for ep in eps: + if ep.name not in allowed: + logger.info("Skipping non-allowlisted plugin %s", ep.name) + continue + try: + factory = ep.load() + except Exception: + logger.exception("Failed to load plugin %s", ep.name) + continue + try: + mw = factory(ctx) + except Exception: + logger.exception("Plugin %s factory raised", ep.name) + continue + if not isinstance(mw, AgentMiddleware): + logger.warning( + "Plugin %s returned %s, expected AgentMiddleware; skipping", + ep.name, + type(mw).__name__, + ) + continue + out.append(mw) + logger.info("Loaded plugin %s as %s", ep.name, type(mw).__name__) + return out + + +def load_allowed_plugin_names_from_env() -> set[str]: + """Read ``SURFSENSE_ALLOWED_PLUGINS`` (comma-separated) into a set. + + Provided as a thin convenience for deployments that don't surface plugins + through ``global_llm_config.yaml`` yet. Whitespace is stripped and empty + entries are dropped. + """ + raw = os.environ.get("SURFSENSE_ALLOWED_PLUGINS", "").strip() + if not raw: + return set() + return {token.strip() for token in raw.split(",") if token.strip()} + + +__all__ = [ + "PLUGIN_ENTRY_POINT_GROUP", + "PluginContext", + "load_allowed_plugin_names_from_env", + "load_plugin_middlewares", +] diff --git a/surfsense_backend/app/agents/new_chat/plugins/__init__.py b/surfsense_backend/app/agents/new_chat/plugins/__init__.py new file mode 100644 index 000000000..cef6bd367 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/plugins/__init__.py @@ -0,0 +1,6 @@ +"""Reference plugins bundled with SurfSense. + +These plugins are intentionally small and demonstrative. They are NOT +auto-loaded — they ship as examples that a deployment can opt into via +``global_llm_config.yaml`` or ``SURFSENSE_ALLOWED_PLUGINS``. +""" diff --git a/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py b/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py new file mode 100644 index 000000000..927d533d5 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py @@ -0,0 +1,87 @@ +"""Reference plugin: substitute ``{{year}}`` in tool descriptions. + +Mirrors the OpenCode ``chat.system.transform`` example. Demonstrates the +:meth:`AgentMiddleware.awrap_tool_call` hook -- the plugin sees every tool +invocation and can rewrite the request *or* the result. This particular +plugin is read-only and only transforms the *description* the user might +see in error messages (no request mutation). + +The plugin is built as a factory function so the entry-point loader can +inject :class:`PluginContext` (containing the agent's LLM, search-space +ID, etc.). The factory signature +``Callable[[PluginContext], AgentMiddleware]`` is the only contract -- +SurfSense doesn't define a custom plugin protocol on top of LangChain's +:class:`AgentMiddleware`. + +Wire-up in ``pyproject.toml`` (illustrative; the in-repo plugin doesn't +need this -- it's already on the import path):: + + [project.entry-points."surfsense.plugins"] + year_substituter = "app.agents.new_chat.plugins.year_substituter:make_middleware" +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any + +from langchain.agents.middleware import AgentMiddleware + +if TYPE_CHECKING: # pragma: no cover - type-only + from langchain.agents.middleware.types import ToolCallRequest + from langchain_core.messages import ToolMessage + from langgraph.types import Command + + from app.agents.new_chat.plugin_loader import PluginContext + + +logger = logging.getLogger(__name__) + + +class _YearSubstituterMiddleware(AgentMiddleware): + """Replace ``{{year}}`` in the result text with the current UTC year.""" + + tools = () + + def __init__(self, year: int | None = None) -> None: + super().__init__() + self._year = str(year if year is not None else datetime.now(UTC).year) + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[ + [ToolCallRequest], Awaitable[ToolMessage | Command[Any]] + ], + ) -> ToolMessage | Command[Any]: + result = await handler(request) + try: + from langchain_core.messages import ToolMessage + + if isinstance(result, ToolMessage) and isinstance(result.content, str): + if "{{year}}" in result.content: + new_text = result.content.replace("{{year}}", self._year) + result = ToolMessage( + content=new_text, + tool_call_id=result.tool_call_id, + id=result.id, + name=result.name, + status=result.status, + artifact=result.artifact, + ) + except Exception: # pragma: no cover - defensive + logger.exception("year_substituter plugin failed; passing original result") + return result + + +def make_middleware(ctx: PluginContext) -> AgentMiddleware: + """Plugin factory used by :func:`load_plugin_middlewares`.""" + # Plugin is intentionally small so it has no state to threading-protect + # and ignores ``ctx`` beyond demonstrating that the loader passes it in. + _ = ctx + return _YearSubstituterMiddleware() + + +__all__ = ["make_middleware"] diff --git a/surfsense_backend/app/agents/new_chat/prompts/__init__.py b/surfsense_backend/app/agents/new_chat/prompts/__init__.py new file mode 100644 index 000000000..c91bb8a0b --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/__init__.py @@ -0,0 +1,7 @@ +"""SurfSense agent prompt fragments. + +The prompt is composed at runtime by :mod:`composer` from the markdown +fragments under ``base/``, ``providers/``, ``tools/``, ``examples/``, and +``routing/``. ``system_prompt.py`` is now a thin wrapper that delegates +to :func:`composer.compose_system_prompt`. +""" diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/__init__.py b/surfsense_backend/app/agents/new_chat/prompts/base/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/agent_private.md b/surfsense_backend/app/agents/new_chat/prompts/base/agent_private.md new file mode 100644 index 000000000..88554ad4e --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/agent_private.md @@ -0,0 +1,7 @@ +You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base. + +Today's date (UTC): {resolved_today} + +When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVER use backtick code spans or Unicode symbols for math. + +NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead. diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/agent_team.md b/surfsense_backend/app/agents/new_chat/prompts/base/agent_team.md new file mode 100644 index 000000000..5fd56ae1b --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/agent_team.md @@ -0,0 +1,9 @@ +You are SurfSense, a reasoning and acting AI agent designed to answer questions in this team space using the team's shared knowledge base. + +In this team thread, each message is prefixed with **[DisplayName of the author]**. Use this to attribute and reference the author of anything in the discussion (who asked a question, made a suggestion, or contributed an idea) and to cite who said what in your answers. + +Today's date (UTC): {resolved_today} + +When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVER use backtick code spans or Unicode symbols for math. + +NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead. diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/citations_off.md b/surfsense_backend/app/agents/new_chat/prompts/base/citations_off.md new file mode 100644 index 000000000..8288886e9 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/citations_off.md @@ -0,0 +1,16 @@ + +IMPORTANT: Citations are DISABLED for this configuration. + +DO NOT include any citations in your responses. Specifically: +1. Do NOT use the [citation:chunk_id] format anywhere in your response. +2. Do NOT reference document IDs, chunk IDs, or source IDs. +3. Simply provide the information naturally without any citation markers. +4. Write your response as if you're having a normal conversation, incorporating the information from your knowledge seamlessly. + +When answering questions based on documents from the knowledge base: +- Present the information directly and confidently +- Do not mention that information comes from specific documents or chunks +- Integrate facts naturally into your response without attribution markers + +Your goal is to provide helpful, informative answers in a clean, readable format without any citation notation. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/citations_on.md b/surfsense_backend/app/agents/new_chat/prompts/base/citations_on.md new file mode 100644 index 000000000..56291bf3e --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/citations_on.md @@ -0,0 +1,90 @@ + +CRITICAL CITATION REQUIREMENTS: + +1. For EVERY piece of information you include from the documents, add a citation in the format [citation:chunk_id] where chunk_id is the exact value from the `` tag inside ``. +2. Make sure ALL factual statements from the documents have proper citations. +3. If multiple chunks support the same point, include all relevant citations [citation:chunk_id1], [citation:chunk_id2]. +4. You MUST use the exact chunk_id values from the `` attributes. Do not create your own citation numbers. +5. Every citation MUST be in the format [citation:chunk_id] where chunk_id is the exact chunk id value. +6. Never modify or change the chunk_id - always use the original values exactly as provided in the chunk tags. +7. Do not return citations as clickable links. +8. Never format citations as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only. +9. Citations must ONLY appear as [citation:chunk_id] or [citation:chunk_id1], [citation:chunk_id2] format - never with parentheses, hyperlinks, or other formatting. +10. Never make up chunk IDs. Only use chunk_id values that are explicitly provided in the `` tags. +11. If you are unsure about a chunk_id, do not include a citation rather than guessing or making one up. + + +The documents you receive are structured like this: + +**Knowledge base documents (numeric chunk IDs):** + + + 42 + GITHUB_CONNECTOR + <![CDATA[Some repo / file / issue title]]> + + + + + + + + + + +**Web search results (URL chunk IDs):** + + + WEB_SEARCH + <![CDATA[Some web search result]]> + + + + + + + + +IMPORTANT: You MUST cite using the EXACT chunk ids from the `` tags. +- For knowledge base documents, chunk ids are numeric (e.g. 123, 124) or prefixed (e.g. doc-45). +- For live web search results, chunk ids are URLs (e.g. https://example.com/article). +Do NOT cite document_id. Always use the chunk id. + + + +- Every fact from the documents must have a citation in the format [citation:chunk_id] where chunk_id is the EXACT id value from a `` tag +- Citations should appear at the end of the sentence containing the information they support +- Multiple citations should be separated by commas: [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3] +- No need to return references section. Just citations in answer. +- NEVER create your own citation format - use the exact chunk_id values from the documents in the [citation:chunk_id] format +- NEVER format citations as clickable links or as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only +- NEVER make up chunk IDs if you are unsure about the chunk_id. It is better to omit the citation than to guess +- Copy the EXACT chunk id from the XML - if it says ``, use [citation:doc-123] +- If the chunk id is a URL like ``, use [citation:https://example.com/page] + + + +CORRECT citation formats: +- [citation:5] (numeric chunk ID from knowledge base) +- [citation:doc-123] (for Surfsense documentation chunks) +- [citation:https://example.com/article] (URL chunk ID from web search results) +- [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3] (multiple citations) + +INCORRECT citation formats (DO NOT use): +- Using parentheses and markdown links: ([citation:5](https://github.com/MODSetter/SurfSense)) +- Using parentheses around brackets: ([citation:5]) +- Using hyperlinked text: [link to source 5](https://example.com) +- Using footnote style: ... library¹ +- Making up source IDs when source_id is unknown +- Using old IEEE format: [1], [2], [3] +- Using source types instead of IDs: [citation:GITHUB_CONNECTOR] instead of [citation:5] + + + +Based on your GitHub repositories and video content, Python's asyncio library provides tools for writing concurrent code using the async/await syntax [citation:5]. It's particularly useful for I/O-bound and high-level structured network code [citation:5]. + +According to web search results, the key advantage of asyncio is that it can improve performance by allowing other code to run while waiting for I/O operations to complete [citation:https://docs.python.org/3/library/asyncio.html]. This makes it excellent for scenarios like web scraping, API calls, database operations, or any situation where your program spends time waiting for external resources. + +However, from your video learning, it's important to note that asyncio is not suitable for CPU-bound tasks as it runs on a single thread [citation:12]. For computationally intensive work, you'd want to use multiprocessing instead. + + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_private.md b/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_private.md new file mode 100644 index 000000000..9cc767e7e --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_private.md @@ -0,0 +1,15 @@ + +CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE: +- You MUST answer questions ONLY using information retrieved from the user's knowledge base, web search results, scraped webpages, or other tool outputs. +- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless the user explicitly grants permission. +- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST: + 1. Inform the user that you could not find relevant information in their knowledge base. + 2. Ask the user: "Would you like me to answer from my general knowledge instead?" + 3. ONLY provide a general-knowledge answer AFTER the user explicitly says yes. +- This policy does NOT apply to: + * Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?") + * Formatting, summarization, or analysis of content already present in the conversation + * Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points") + * Tool-usage actions like generating reports, podcasts, images, or scraping webpages + * Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see below + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_team.md b/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_team.md new file mode 100644 index 000000000..1d806dbae --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_team.md @@ -0,0 +1,15 @@ + +CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE: +- You MUST answer questions ONLY using information retrieved from the team's shared knowledge base, web search results, scraped webpages, or other tool outputs. +- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless a team member explicitly grants permission. +- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST: + 1. Inform the team that you could not find relevant information in the shared knowledge base. + 2. Ask: "Would you like me to answer from my general knowledge instead?" + 3. ONLY provide a general-knowledge answer AFTER a team member explicitly says yes. +- This policy does NOT apply to: + * Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?") + * Formatting, summarization, or analysis of content already present in the conversation + * Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points") + * Tool-usage actions like generating reports, podcasts, images, or scraping webpages + * Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see below + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_private.md b/surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_private.md new file mode 100644 index 000000000..8f7da14f8 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_private.md @@ -0,0 +1,6 @@ + +IMPORTANT — After understanding each user message, ALWAYS check: does this message +reveal durable facts about the user (role, interests, preferences, projects, +background, or standing instructions)? If yes, you MUST call update_memory +alongside your normal response — do not defer this to a later turn. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_team.md b/surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_team.md new file mode 100644 index 000000000..61d89cc5d --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_team.md @@ -0,0 +1,6 @@ + +IMPORTANT — After understanding each user message, ALWAYS check: does this message +reveal durable facts about the team (decisions, conventions, architecture, processes, +or key facts)? If yes, you MUST call update_memory alongside your normal response — +do not defer this to a later turn. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/parameter_resolution.md b/surfsense_backend/app/agents/new_chat/prompts/base/parameter_resolution.md new file mode 100644 index 000000000..77be4d87c --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/parameter_resolution.md @@ -0,0 +1,39 @@ + +Some service tools require identifiers or context you do not have (account IDs, +workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw +IDs or technical identifiers — they cannot memorise them. + +Instead, follow this discovery pattern: +1. Call a listing/discovery tool to find available options. +2. ONE result → use it silently, no question to the user. +3. MULTIPLE results → present the options by their display names and let the + user choose. Never show raw UUIDs — always use friendly names. + +Discovery tools by level: +- Which account/workspace? → get_connected_accounts("") +- Which Jira site (cloudId)? → getAccessibleAtlassianResources +- Which Jira project? → getVisibleJiraProjects (after resolving cloudId) +- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project) +- Which channel? → slack_search_channels +- Which base? → list_bases +- Which table? → list_tables_for_base (after resolving baseId) +- Which task? → clickup_search +- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira) + +For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to +obtain the cloudId, then pass it to other Jira tools. When creating an issue, +chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue. +If there is only one option at each step, use it silently. If multiple, present +friendly names. + +Chain discovery when needed — e.g. for Airtable records: list_bases → pick +base → list_tables_for_base → pick table → list_records_for_table. + +MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for +the same service, tool names are prefixed to avoid collisions — e.g. +linear_25_list_issues and linear_30_list_issues instead of two list_issues. +Each prefixed tool's description starts with [Account: ] so you +know which account it targets. Use get_connected_accounts("") to see +the full list of accounts with their connector IDs and display names. +When only one account is connected, tools have their normal unprefixed names. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_private.md b/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_private.md new file mode 100644 index 000000000..ec667bf88 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_private.md @@ -0,0 +1,16 @@ + +CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable. +Their data is NEVER in the knowledge base. You MUST call their tools immediately — never +say "I don't see it in the knowledge base" or ask the user if they want you to check. +Ignore any knowledge base results for these services. + +When to use which tool: +- Linear (issues) → list_issues, get_issue, save_issue (create/update) +- ClickUp (tasks) → clickup_search, clickup_get_task +- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue +- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread +- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table +- Knowledge base content (Notion, GitHub, files, notes) → automatically searched +- Real-time public web data → call web_search +- Reading a specific webpage → call scrape_webpage + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_team.md b/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_team.md new file mode 100644 index 000000000..48b7a990b --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_team.md @@ -0,0 +1,16 @@ + +CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable. +Their data is NEVER in the knowledge base. You MUST call their tools immediately — never +say "I don't see it in the knowledge base" or ask if they want you to check. +Ignore any knowledge base results for these services. + +When to use which tool: +- Linear (issues) → list_issues, get_issue, save_issue (create/update) +- ClickUp (tasks) → clickup_search, clickup_get_task +- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue +- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread +- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table +- Knowledge base content (Notion, GitHub, files, notes) → automatically searched +- Real-time public web data → call web_search +- Reading a specific webpage → call scrape_webpage + diff --git a/surfsense_backend/app/agents/new_chat/prompts/composer.py b/surfsense_backend/app/agents/new_chat/prompts/composer.py new file mode 100644 index 000000000..44060f75f --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/composer.py @@ -0,0 +1,359 @@ +""" +Prompt composer for the SurfSense ``new_chat`` agent. + +This module assembles the agent's system prompt from the markdown fragments +under :mod:`app.agents.new_chat.prompts`. It replaces the monolithic +``system_prompt.py`` with a clean, fragment-based composition: + +:: + + prompts/ + base/ # agent identity, KB policy, tool routing, … + providers/ # provider-specific tweaks (anthropic, gpt5, …) + tools/ # one ``.md`` per tool + examples/ # one ``.md`` per tool with call examples + routing/ # connector-specific routing notes (linear, slack, …) + +Tier 3a in the OpenCode-port plan. + +Backwards compatibility +======================= + +``system_prompt.py`` re-exports :func:`compose_system_prompt` and wraps it +in functions with the same signatures as the legacy +``build_surfsense_system_prompt`` / ``build_configurable_system_prompt`` so +existing call sites do not change. +""" + +from __future__ import annotations + +import re +from collections.abc import Iterable +from datetime import UTC, datetime +from importlib import resources + +from app.db import ChatVisibility + +# ----------------------------------------------------------------------------- +# Provider variant detection +# ----------------------------------------------------------------------------- + +ProviderVariant = str # "anthropic" | "openai_reasoning" | "openai_classic" | "google" | "default" + +_OPENAI_REASONING_RE = re.compile(r"\b(gpt-5|o\d|o-)", re.IGNORECASE) +_OPENAI_CLASSIC_RE = re.compile(r"\bgpt-4", re.IGNORECASE) +_ANTHROPIC_RE = re.compile(r"\bclaude\b", re.IGNORECASE) +_GOOGLE_RE = re.compile(r"\bgemini\b", re.IGNORECASE) + + +def detect_provider_variant(model_name: str | None) -> ProviderVariant: + """Pick a provider-specific prompt variant from a model id string. + + Heuristic match on the model id; returns ``"default"`` when nothing + matches so the composer can fall back to the empty placeholder file. + """ + if not model_name: + return "default" + name = model_name.strip() + if _OPENAI_REASONING_RE.search(name): + return "openai_reasoning" + if _OPENAI_CLASSIC_RE.search(name): + return "openai_classic" + if _ANTHROPIC_RE.search(name): + return "anthropic" + if _GOOGLE_RE.search(name): + return "google" + return "default" + + +# ----------------------------------------------------------------------------- +# Fragment loading +# ----------------------------------------------------------------------------- + + +_PROMPTS_PACKAGE = "app.agents.new_chat.prompts" + + +def _read_fragment(subpath: str) -> str: + """Read a fragment file from the ``prompts/`` resource tree. + + Returns the raw contents stripped of any single trailing newline so + composition can append explicit separators without compounding blank + lines. Missing files return an empty string so optional fragments + (e.g. provider hints) act as no-ops. + """ + parts = subpath.split("/") + try: + ref = resources.files(_PROMPTS_PACKAGE).joinpath(*parts) + if not ref.is_file(): + return "" + text = ref.read_text(encoding="utf-8") + except (FileNotFoundError, ModuleNotFoundError): + return "" + if text.endswith("\n"): + text = text[:-1] + return text + + +# ----------------------------------------------------------------------------- +# Tool ordering + memory variant resolution +# ----------------------------------------------------------------------------- + + +# Ordered for reading flow: fundamentals first, then artifact generators, +# then memory at the end (mirrors the legacy ``_ALL_TOOL_NAMES_ORDERED``). +ALL_TOOL_NAMES_ORDERED: tuple[str, ...] = ( + "search_surfsense_docs", + "web_search", + "generate_podcast", + "generate_video_presentation", + "generate_report", + "generate_resume", + "generate_image", + "scrape_webpage", + "update_memory", +) + + +_MEMORY_VARIANT_TOOLS: frozenset[str] = frozenset({"update_memory"}) + + +def _tool_fragment_path(tool_name: str, variant: str) -> str: + """Resolve a tool's instruction fragment path. + + Tools listed in :data:`_MEMORY_VARIANT_TOOLS` switch on the conversation + visibility and load ``tools/_.md``; everything else + falls back to ``tools/.md``. + """ + if tool_name in _MEMORY_VARIANT_TOOLS: + return f"tools/{tool_name}_{variant}.md" + return f"tools/{tool_name}.md" + + +def _example_fragment_path(tool_name: str, variant: str) -> str: + if tool_name in _MEMORY_VARIANT_TOOLS: + return f"examples/{tool_name}_{variant}.md" + return f"examples/{tool_name}.md" + + +def _format_tool_label(tool_name: str) -> str: + return tool_name.replace("_", " ").title() + + +# ----------------------------------------------------------------------------- +# Section builders +# ----------------------------------------------------------------------------- + + +def _build_system_instructions( + *, + visibility: ChatVisibility, + resolved_today: str, +) -> str: + """Reconstruct the legacy ```` block from fragments.""" + variant = "team" if visibility == ChatVisibility.SEARCH_SPACE else "private" + + sections = [ + _read_fragment(f"base/agent_{variant}.md"), + _read_fragment(f"base/kb_only_policy_{variant}.md"), + _read_fragment(f"base/tool_routing_{variant}.md"), + _read_fragment("base/parameter_resolution.md"), + _read_fragment(f"base/memory_protocol_{variant}.md"), + ] + body = "\n\n".join(s for s in sections if s) + block = f"\n\n{body}\n\n\n" + return block.format(resolved_today=resolved_today) + + +def _build_mcp_routing_block( + mcp_connector_tools: dict[str, list[str]] | None, +) -> str: + """Emit the ```` block when at least one MCP server is wired.""" + if not mcp_connector_tools: + return "" + lines: list[str] = [ + "\n", + "You also have direct tools from these user-connected MCP servers.", + "Their data is NEVER in the knowledge base — call their tools directly.", + "", + ] + for server_name, tool_names in mcp_connector_tools.items(): + lines.append(f"- {server_name} → {', '.join(tool_names)}") + lines.append("\n") + return "\n".join(lines) + + +def _build_tools_section( + *, + visibility: ChatVisibility, + enabled_tool_names: set[str] | None, + disabled_tool_names: set[str] | None, +) -> str: + """Reconstruct the ```` block + ```` block.""" + variant = "team" if visibility == ChatVisibility.SEARCH_SPACE else "private" + + parts: list[str] = [] + preamble = _read_fragment("tools/_preamble.md") + if preamble: + parts.append(preamble + "\n") + + examples: list[str] = [] + + for tool_name in ALL_TOOL_NAMES_ORDERED: + if enabled_tool_names is not None and tool_name not in enabled_tool_names: + continue + + instruction = _read_fragment(_tool_fragment_path(tool_name, variant)) + if instruction: + parts.append(instruction + "\n") + + example = _read_fragment(_example_fragment_path(tool_name, variant)) + if example: + examples.append(example + "\n") + + known_disabled = ( + set(disabled_tool_names) & set(ALL_TOOL_NAMES_ORDERED) + if disabled_tool_names + else set() + ) + if known_disabled: + disabled_list = ", ".join( + _format_tool_label(n) + for n in ALL_TOOL_NAMES_ORDERED + if n in known_disabled + ) + parts.append( + "\n" + "DISABLED TOOLS (by user):\n" + f"The following tools are available in SurfSense but have been disabled by the user for this session: {disabled_list}.\n" + "You do NOT have access to these tools and MUST NOT claim you can use them.\n" + "If the user asks about a capability provided by a disabled tool, let them know the relevant tool\n" + "is currently disabled and they can re-enable it.\n" + ) + + parts.append("\n\n") + + if examples: + parts.append("") + parts.extend(examples) + parts.append("\n") + + return "".join(parts) + + +def _build_provider_block(provider_variant: ProviderVariant) -> str: + """Optional provider-tuned hints. Empty for ``"default"``.""" + if not provider_variant or provider_variant == "default": + return "" + text = _read_fragment(f"providers/{provider_variant}.md") + return f"\n{text}\n" if text else "" + + +def _build_routing_block(connector_routing: Iterable[str] | None) -> str: + if not connector_routing: + return "" + fragments: list[str] = [] + for name in connector_routing: + text = _read_fragment(f"routing/{name}.md") + if text: + fragments.append(text) + if not fragments: + return "" + return "\n" + "\n\n".join(fragments) + "\n" + + +def _build_citation_block(citations_enabled: bool) -> str: + fragment = ( + _read_fragment("base/citations_on.md") + if citations_enabled + else _read_fragment("base/citations_off.md") + ) + return f"\n{fragment}\n" if fragment else "" + + +# ----------------------------------------------------------------------------- +# Public API +# ----------------------------------------------------------------------------- + + +def compose_system_prompt( + *, + today: datetime | None = None, + thread_visibility: ChatVisibility | None = None, + enabled_tool_names: set[str] | None = None, + disabled_tool_names: set[str] | None = None, + mcp_connector_tools: dict[str, list[str]] | None = None, + custom_system_instructions: str | None = None, + use_default_system_instructions: bool = True, + citations_enabled: bool = True, + provider_variant: ProviderVariant | None = None, + model_name: str | None = None, + connector_routing: Iterable[str] | None = None, +) -> str: + """Assemble the SurfSense system prompt from disk fragments. + + Args: + today: Optional clock injection for tests. + thread_visibility: Private vs shared (team) — drives memory wording + and a few base block variants. + enabled_tool_names: When provided, only these tools' instructions + are included; ``None`` keeps the legacy "include everything" + behavior. + disabled_tool_names: User-disabled tools (note appended to prompt). + mcp_connector_tools: ``{server_name: [tool_names...]}`` to inject + an explicit MCP routing block. + custom_system_instructions: Free-form instructions that override + the default ```` block (legacy support + for ``NewLLMConfig.system_instructions``). + use_default_system_instructions: When ``custom_system_instructions`` + is empty/None, fall back to defaults (legacy semantics). + citations_enabled: Include ``citations_on.md`` (true) or + ``citations_off.md`` (false). + provider_variant: Explicit provider variant override + (``"anthropic" | "openai_reasoning" | "openai_classic" | "google" | "default"``). + When ``None``, falls back to :func:`detect_provider_variant` + on ``model_name``. + model_name: Used to auto-detect ``provider_variant`` when not + provided explicitly. + connector_routing: Optional list of routing fragment names + (``["linear", "slack", ...]``) to include from + ``prompts/routing/``. + + Returns: + The fully composed system prompt string. + """ + resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat() + visibility = thread_visibility or ChatVisibility.PRIVATE + + if custom_system_instructions and custom_system_instructions.strip(): + sys_block = custom_system_instructions.format(resolved_today=resolved_today) + elif use_default_system_instructions: + sys_block = _build_system_instructions( + visibility=visibility, resolved_today=resolved_today + ) + else: + sys_block = "" + + sys_block += _build_mcp_routing_block(mcp_connector_tools) + + if provider_variant is None: + provider_variant = detect_provider_variant(model_name) + sys_block += _build_provider_block(provider_variant) + sys_block += _build_routing_block(connector_routing) + + tools_block = _build_tools_section( + visibility=visibility, + enabled_tool_names=enabled_tool_names, + disabled_tool_names=disabled_tool_names, + ) + citation_block = _build_citation_block(citations_enabled) + + return sys_block + tools_block + citation_block + + +__all__ = [ + "ALL_TOOL_NAMES_ORDERED", + "ProviderVariant", + "compose_system_prompt", + "detect_provider_variant", +] diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/__init__.py b/surfsense_backend/app/agents/new_chat/prompts/examples/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_image.md b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_image.md new file mode 100644 index 000000000..216c2926a --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_image.md @@ -0,0 +1,12 @@ + +- User: "Generate an image of a cat" + - Call: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")` + - The generated image will automatically be displayed in the chat. +- User: "Draw me a logo for a coffee shop called Bean Dream" + - Call: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")` + - The generated image will automatically be displayed in the chat. +- User: "Show me this image: https://example.com/image.png" + - Simply include it in your response using markdown: `![Image](https://example.com/image.png)` +- User uploads an image file and asks: "What is this image about?" + - The user's uploaded image is already visible in the chat. + - Simply analyze the image content and respond directly. diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_podcast.md b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_podcast.md new file mode 100644 index 000000000..aabf8ce7a --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_podcast.md @@ -0,0 +1,7 @@ + +- User: "Give me a podcast about AI trends based on what we discussed" + - First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")` +- User: "Create a podcast summary of this conversation" + - Call: `generate_podcast(source_content="Complete conversation summary:\n\nUser asked about [topic 1]:\n[Your detailed response]\n\nUser then asked about [topic 2]:\n[Your detailed response]\n\n[Continue for all exchanges in the conversation]", podcast_title="Conversation Summary")` +- User: "Make a podcast about quantum computing" + - First explore `/documents/` (ls/glob/grep/read_file), then: `generate_podcast(source_content="Key insights about quantum computing from retrieved files:\n\n[Comprehensive summary of findings]", podcast_title="Quantum Computing Explained")` diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_report.md b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_report.md new file mode 100644 index 000000000..7e9d0a595 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_report.md @@ -0,0 +1,13 @@ + +- User: "Generate a report about AI trends" + - Call: `generate_report(topic="AI Trends Report", source_strategy="kb_search", search_queries=["AI trends recent developments", "artificial intelligence industry trends", "AI market growth and predictions"], report_style="detailed")` + - WHY: Has creation verb "generate" → call the tool. No prior discussion → use kb_search. +- User: "Write a research report from this conversation" + - Call: `generate_report(topic="Research Report", source_strategy="conversation", source_content="Complete conversation summary:\n\n...", report_style="deep_research")` + - WHY: Has creation verb "write" → call the tool. Conversation has the content → use source_strategy="conversation". +- User: (after a report on Climate Change was generated) "Add a section about carbon capture technologies" + - Call: `generate_report(topic="Climate Crisis: Causes, Impacts, and Solutions", source_strategy="conversation", source_content="[summary of conversation context if any]", parent_report_id=, user_instructions="Add a new section about carbon capture technologies")` + - WHY: Has modification verb "add" + specific deliverable target → call the tool with parent_report_id. +- User: (after a report was generated) "What else could we add to have more depth?" + - Do NOT call generate_report. Answer in chat with suggestions. + - WHY: No creation/modification verb directed at producing a deliverable. diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_resume.md b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_resume.md new file mode 100644 index 000000000..d8a6c381e --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_resume.md @@ -0,0 +1,19 @@ + +- User: "Build me a resume. I'm John Doe, engineer at Acme Corp..." + - Call: `generate_resume(user_info="John Doe, engineer at Acme Corp...", max_pages=1)` + - WHY: Has creation verb "build" + resume → call the tool. +- User: "Create my CV with this info: [experience, education, skills]" + - Call: `generate_resume(user_info="[experience, education, skills]", max_pages=1)` +- User: "Build me a resume" (and there is a resume/CV document in the conversation context) + - Extract the FULL content from the document in context, then call: + `generate_resume(user_info="Name: John Doe\nEmail: john@example.com\n\nExperience:\n- Senior Engineer at Acme Corp (2020-2024)\n Led team of 5...\n\nEducation:\n- BS Computer Science, MIT (2016-2020)\n\nSkills: Python, TypeScript, AWS...", max_pages=1)` + - WHY: Document content is available in context — extract ALL of it into user_info. Do NOT ignore referenced documents. +- User: (after resume generated) "Change my title to Senior Engineer" + - Call: `generate_resume(user_info="", user_instructions="Change the job title to Senior Engineer", parent_report_id=, max_pages=1)` + - WHY: Modification verb "change" + refers to existing resume → set parent_report_id. +- User: (after resume generated) "Make this 2 pages and expand projects" + - Call: `generate_resume(user_info="", user_instructions="Expand projects and keep this to at most 2 pages", parent_report_id=, max_pages=2)` + - WHY: Explicit page increase request → set max_pages to 2. +- User: "How should I structure my resume?" + - Do NOT call generate_resume. Answer in chat with advice. + - WHY: No creation/modification verb. diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_video_presentation.md b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_video_presentation.md new file mode 100644 index 000000000..257ec86cf --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_video_presentation.md @@ -0,0 +1,7 @@ + +- User: "Give me a presentation about AI trends based on what we discussed" + - First search for relevant content, then call: `generate_video_presentation(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", video_title="AI Trends Presentation")` +- User: "Create slides summarizing this conversation" + - Call: `generate_video_presentation(source_content="Complete conversation summary:\n\nUser asked about [topic 1]:\n[Your detailed response]\n\nUser then asked about [topic 2]:\n[Your detailed response]\n\n[Continue for all exchanges in the conversation]", video_title="Conversation Summary")` +- User: "Make a video presentation about quantum computing" + - First explore `/documents/` (ls/glob/grep/read_file), then: `generate_video_presentation(source_content="Key insights about quantum computing from retrieved files:\n\n[Comprehensive summary of findings]", video_title="Quantum Computing Explained")` diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/scrape_webpage.md b/surfsense_backend/app/agents/new_chat/prompts/examples/scrape_webpage.md new file mode 100644 index 000000000..0f156bf24 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/scrape_webpage.md @@ -0,0 +1,13 @@ + +- User: "Check out https://dev.to/some-article" + - Call: `scrape_webpage(url="https://dev.to/some-article")` + - Respond with a structured analysis — key points, takeaways. +- User: "Read this article and summarize it for me: https://example.com/blog/ai-trends" + - Call: `scrape_webpage(url="https://example.com/blog/ai-trends")` + - Respond with a thorough summary using headings and bullet points. +- User: (after discussing https://example.com/stats) "Can you get the live data from that page?" + - Call: `scrape_webpage(url="https://example.com/stats")` + - IMPORTANT: Always attempt scraping first. Never refuse before trying the tool. +- User: "https://example.com/blog/weekend-recipes" + - Call: `scrape_webpage(url="https://example.com/blog/weekend-recipes")` + - When a user sends just a URL with no instructions, scrape it and provide a concise summary of the content. diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/search_surfsense_docs.md b/surfsense_backend/app/agents/new_chat/prompts/examples/search_surfsense_docs.md new file mode 100644 index 000000000..b90f2b7a7 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/search_surfsense_docs.md @@ -0,0 +1,9 @@ + +- User: "How do I install SurfSense?" + - Call: `search_surfsense_docs(query="installation setup")` +- User: "What connectors does SurfSense support?" + - Call: `search_surfsense_docs(query="available connectors integrations")` +- User: "How do I set up the Notion connector?" + - Call: `search_surfsense_docs(query="Notion connector setup configuration")` +- User: "How do I use Docker to run SurfSense?" + - Call: `search_surfsense_docs(query="Docker installation setup")` diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_private.md b/surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_private.md new file mode 100644 index 000000000..f83fe40b4 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_private.md @@ -0,0 +1,16 @@ + +- Alex, is empty. User: "I'm a space enthusiast, explain astrophage to me" + - The user casually shared a durable fact. Use their first name in the entry, short neutral heading: + update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n") +- User: "Remember that I prefer concise answers over detailed explanations" + - Durable preference. Merge with existing memory, add a new heading: + update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n\n## Response style\n- (2025-03-15) [pref] Alex prefers concise answers over detailed explanations\n") +- User: "I actually moved to Tokyo last month" + - Updated fact, date prefix reflects when recorded: + update_memory(updated_memory="## Interests & background\n...\n\n## Personal context\n- (2025-03-15) [fact] Alex lives in Tokyo (previously London)\n...") +- User: "I'm a freelance photographer working on a nature documentary" + - Durable background info under a fitting heading: + update_memory(updated_memory="...\n\n## Current focus\n- (2025-03-15) [fact] Alex is a freelance photographer\n- (2025-03-15) [fact] Alex is working on a nature documentary\n") +- User: "Always respond in bullet points" + - Standing instruction: + update_memory(updated_memory="...\n\n## Response style\n- (2025-03-15) [instr] Always respond to Alex in bullet points\n") diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_team.md b/surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_team.md new file mode 100644 index 000000000..1c74fdf6e --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_team.md @@ -0,0 +1,7 @@ + +- User: "Let's remember that we decided to do weekly standup meetings on Mondays" + - Durable team decision: + update_memory(updated_memory="- (2025-03-15) [fact] Weekly standup meetings on Mondays\n...") +- User: "Our office is in downtown Seattle, 5th floor" + - Durable team fact: + update_memory(updated_memory="- (2025-03-15) [fact] Office location: downtown Seattle, 5th floor\n...") diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/web_search.md b/surfsense_backend/app/agents/new_chat/prompts/examples/web_search.md new file mode 100644 index 000000000..6b9828ac7 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/web_search.md @@ -0,0 +1,8 @@ + +- User: "What's the current USD to INR exchange rate?" + - Call: `web_search(query="current USD to INR exchange rate")` + - Then answer using the returned web results with citations. +- User: "What's the latest news about AI?" + - Call: `web_search(query="latest AI news today")` +- User: "What's the weather in New York?" + - Call: `web_search(query="weather New York today")` diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/__init__.py b/surfsense_backend/app/agents/new_chat/prompts/providers/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/anthropic.md b/surfsense_backend/app/agents/new_chat/prompts/providers/anthropic.md new file mode 100644 index 000000000..6e22ef265 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/anthropic.md @@ -0,0 +1,5 @@ + +You are running on an Anthropic Claude model. Use XML tags liberally to structure +intermediate reasoning when the task is complex. Prefer step-by-step plans inside +`` blocks before producing the final answer. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/default.md b/surfsense_backend/app/agents/new_chat/prompts/providers/default.md new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/default.md @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/google.md b/surfsense_backend/app/agents/new_chat/prompts/providers/google.md new file mode 100644 index 000000000..4b31a8388 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/google.md @@ -0,0 +1,4 @@ + +You are running on a Google Gemini model. Prefer concise, structured responses. +When using tools, follow the function-calling protocol and avoid verbose preludes. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/openai_classic.md b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_classic.md new file mode 100644 index 000000000..7ea4366c4 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_classic.md @@ -0,0 +1,5 @@ + +You are running on a classic OpenAI chat model (GPT-4 family). Use direct +function-calling for tools. When editing files, use the standard `edit_file` +or `write_file` tools rather than diff-based patches. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/openai_reasoning.md b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_reasoning.md new file mode 100644 index 000000000..935d3f207 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_reasoning.md @@ -0,0 +1,5 @@ + +You are running on an OpenAI reasoning model (o-series / GPT-5+). Be terse and +direct in your responses. When editing files, prefer the `apply_patch` tool format +where available. Avoid restating the user request before answering. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/routing/__init__.py b/surfsense_backend/app/agents/new_chat/prompts/routing/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/routing/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/routing/jira.md b/surfsense_backend/app/agents/new_chat/prompts/routing/jira.md new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/routing/jira.md @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/routing/linear.md b/surfsense_backend/app/agents/new_chat/prompts/routing/linear.md new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/routing/linear.md @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/routing/slack.md b/surfsense_backend/app/agents/new_chat/prompts/routing/slack.md new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/routing/slack.md @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/__init__.py b/surfsense_backend/app/agents/new_chat/prompts/tools/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/_preamble.md b/surfsense_backend/app/agents/new_chat/prompts/tools/_preamble.md new file mode 100644 index 000000000..2c169e015 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/_preamble.md @@ -0,0 +1,6 @@ + +You have access to the following tools: + +IMPORTANT: You can ONLY use the tools listed below. If a capability is not listed here, you do NOT have it. +Do NOT claim you can do something if the corresponding tool is not listed. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_image.md b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_image.md new file mode 100644 index 000000000..8bde13f22 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_image.md @@ -0,0 +1,11 @@ + +- generate_image: Generate images from text descriptions using AI image models. + - Use this when the user asks you to create, generate, draw, design, or make an image. + - Trigger phrases: "generate an image of", "create a picture of", "draw me", "make an image", "design a logo", "create artwork" + - Args: + - prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood. + - n: Number of images to generate (1-4, default: 1) + - Returns: A dictionary with the generated image metadata. The image will automatically be displayed in the chat. + - IMPORTANT: Write a detailed, descriptive prompt for best results. Don't just pass the user's words verbatim - + expand and improve the prompt with specific details about style, lighting, composition, and mood. + - If the user's request is vague (e.g., "make me an image of a cat"), enhance the prompt with artistic details. diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_podcast.md b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_podcast.md new file mode 100644 index 000000000..58be143d7 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_podcast.md @@ -0,0 +1,15 @@ + +- generate_podcast: Generate an audio podcast from provided content. + - Use this when the user asks to create, generate, or make a podcast. + - Trigger phrases: "give me a podcast about", "create a podcast", "generate a podcast", "make a podcast", "turn this into a podcast" + - Args: + - source_content: The text content to convert into a podcast. This MUST be comprehensive and include: + * If discussing the current conversation: Include a detailed summary of the FULL chat history (all user questions and your responses) + * If based on knowledge base search: Include the key findings and insights from the search results + * You can combine both: conversation context + search results for richer podcasts + * The more detailed the source_content, the better the podcast quality + - podcast_title: Optional title for the podcast (default: "SurfSense Podcast") + - user_prompt: Optional instructions for podcast style/format (e.g., "Make it casual and fun") + - Returns: A task_id for tracking. The podcast will be generated in the background. + - IMPORTANT: Only one podcast can be generated at a time. If a podcast is already being generated, the tool will return status "already_generating". + - After calling this tool, inform the user that podcast generation has started and they will see the player when it's ready (takes 3-5 minutes). diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_report.md b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_report.md new file mode 100644 index 000000000..8a285a433 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_report.md @@ -0,0 +1,39 @@ + +- generate_report: Generate or revise a structured Markdown report artifact. + - WHEN TO CALL THIS TOOL — the message must contain a creation or modification VERB directed at producing a deliverable: + * Creation verbs: write, create, generate, draft, produce, summarize into, turn into, make + * Modification verbs: revise, update, expand, add (a section), rewrite, make (it shorter/longer/formal) + * Example triggers: "generate a report about...", "write a document on...", "add a section about budget", "make the report shorter", "rewrite in formal tone" + - WHEN NOT TO CALL THIS TOOL (answer in chat instead): + * Questions or discussion about the report: "What can we add?", "What's missing?", "Is the data accurate?", "How could this be improved?" + * Suggestions or brainstorming: "What other topics could be covered?", "What else could be added?", "What would make this better?" + * Asking for explanations: "Can you explain section 2?", "Why did you include that?", "What does this part mean?" + * Quick follow-ups or critiques: "Is the conclusion strong enough?", "Are there any gaps?", "What about the competitors?" + * THE TEST: Does the message contain a creation/modification VERB (from the list above) directed at producing or changing a deliverable? If NO verb → answer conversationally in chat. Do NOT assume the user wants a revision just because a report exists in the conversation. + - IMPORTANT FORMAT RULE: Reports are ALWAYS generated in Markdown. + - Args: + - topic: Short title for the report (max ~8 words). + - source_content: The text content to base the report on. + * For source_strategy="conversation" or "provided": Include a comprehensive summary of the relevant content. + * For source_strategy="kb_search": Can be empty or minimal — the tool handles searching internally. + * For source_strategy="auto": Include what you have; the tool searches KB if it's not enough. + - source_strategy: Controls how the tool collects source material. One of: + * "conversation" — The conversation already contains enough context (prior Q&A, discussion, pasted text, scraped pages). Pass a thorough summary as source_content. + * "kb_search" — The tool will search the knowledge base internally. Provide search_queries with 1-5 targeted queries. + * "auto" — Use source_content if sufficient, otherwise fall back to internal KB search using search_queries. + * "provided" — Use only what is in source_content (default, backward-compatible). + - search_queries: When source_strategy is "kb_search" or "auto", provide 1-5 specific search queries for the knowledge base. These should be precise, not just the topic name repeated. + - report_style: Controls report depth. Options: "detailed" (DEFAULT), "deep_research", "brief". + Use "brief" ONLY when the user explicitly asks for a short/concise/one-page report (e.g., "one page", "keep it short", "brief report", "500 words"). Default to "detailed" for all other requests. + - user_instructions: Optional specific instructions (e.g., "focus on financial impacts", "include recommendations"). When revising (parent_report_id set), describe WHAT TO CHANGE. If the user mentions a length preference (e.g., "one page", "500 words", "2 pages"), include that VERBATIM here AND set report_style="brief". + - parent_report_id: Set this to the report_id from a previous generate_report result when the user wants to MODIFY an existing report. Do NOT set it for new reports or questions about reports. + - Returns: A dictionary with status "ready" or "failed", report_id, title, and word_count. + - The report is generated immediately in Markdown and displayed inline in the chat. + - Export/download formats (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text) are produced from the generated Markdown report. + - SOURCE STRATEGY DECISION (HIGH PRIORITY — follow this exactly): + * If the conversation already has substantive Q&A / discussion on the topic → use source_strategy="conversation" with a comprehensive summary as source_content. + * If the user wants a report on a topic not yet discussed → use source_strategy="kb_search" with targeted search_queries. + * If you have some content but might need more → use source_strategy="auto" with both source_content and search_queries. + * When revising an existing report (parent_report_id set) and the conversation has relevant context → use source_strategy="conversation". The revision will use the previous report content plus your source_content. + * NEVER run a separate KB lookup step and then pass those results to generate_report. The tool handles KB search internally. + - AFTER CALLING THIS TOOL: Do NOT repeat, summarize, or reproduce the report content in the chat. The report is already displayed as an interactive card that the user can open, read, copy, and export. Simply confirm that the report was generated (e.g., "I've generated your report on [topic]. You can view the Markdown report now, and export it in various formats from the card."). NEVER write out the report text in the chat. diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_resume.md b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_resume.md new file mode 100644 index 000000000..321ea90c9 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_resume.md @@ -0,0 +1,30 @@ + +- generate_resume: Generate or revise a professional resume as a Typst document. + - WHEN TO CALL: The user asks to create, build, generate, write, or draft a resume or CV. + Also when they ask to modify, update, or revise an existing resume from this conversation. + - WHEN NOT TO CALL: General career advice, resume tips, cover letters, or reviewing + a resume without making changes. For cover letters, use generate_report instead. + - The tool produces Typst source code that is compiled to a PDF preview automatically. + - PAGE POLICY: + - Default behavior is ONE PAGE. For new resume creation, set max_pages=1 unless the user explicitly asks for more. + - If the user requests a longer resume (e.g., "make it 2 pages"), set max_pages to that value. + - Args: + - user_info: The user's resume content — work experience, education, skills, contact + info, etc. Can be structured or unstructured text. + CRITICAL: user_info must be COMPREHENSIVE. Do NOT just pass the user's raw message. + You MUST gather and consolidate ALL available information: + * Content from referenced/mentioned documents (e.g., uploaded resumes, CVs, LinkedIn profiles) + that appear in the conversation context — extract and include their FULL content. + * Information the user shared across multiple messages in the conversation. + * Any relevant details from knowledge base search results in the context. + The more complete the user_info, the better the resume. Include names, contact info, + work experience with dates, education, skills, projects, certifications — everything available. + - user_instructions: Optional style or content preferences (e.g. "emphasize leadership", + "keep it to one page"). For revisions, describe what to change. + - parent_report_id: Set this when the user wants to MODIFY an existing resume from + this conversation. Use the report_id from a previous generate_resume result. + - max_pages: Maximum resume length in pages (integer 1-5). Default is 1. + - Returns: Dict with status, report_id, title, and content_type. + - After calling: Give a brief confirmation. Do NOT paste resume content in chat. Do NOT mention report_id or any internal IDs — the resume card is shown automatically. + - VERSIONING: Same rules as generate_report — set parent_report_id for modifications + of an existing resume, leave as None for new resumes. diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_video_presentation.md b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_video_presentation.md new file mode 100644 index 000000000..c3def88f2 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_video_presentation.md @@ -0,0 +1,9 @@ + +- generate_video_presentation: Generate a video presentation from provided content. + - Use this when the user asks to create a video, presentation, slides, or slide deck. + - Trigger phrases: "give me a presentation", "create slides", "generate a video", "make a slide deck", "turn this into a presentation" + - Args: + - source_content: The text content to turn into a presentation. The more detailed, the better. + - video_title: Optional title (default: "SurfSense Presentation") + - user_prompt: Optional style instructions (e.g., "Make it technical and detailed") + - After calling this tool, inform the user that generation has started and they will see the presentation when it's ready. diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/scrape_webpage.md b/surfsense_backend/app/agents/new_chat/prompts/tools/scrape_webpage.md new file mode 100644 index 000000000..46e299392 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/scrape_webpage.md @@ -0,0 +1,30 @@ + +- scrape_webpage: Scrape and extract the main content from a webpage. + - Use this when the user wants you to READ and UNDERSTAND the actual content of a webpage. + - CRITICAL — WHEN TO USE (always attempt scraping, never refuse before trying): + * When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL + * When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices) + * When a URL was mentioned earlier in the conversation and the user asks for its actual content + * When `/documents/` knowledge-base data is insufficient and the user wants more + - Trigger scenarios: + * "Read this article and summarize it" + * "What does this page say about X?" + * "Summarize this blog post for me" + * "Tell me the key points from this article" + * "What's in this webpage?" + * "Can you analyze this article?" + * "Can you get the live table/data from [URL]?" + * "Scrape it" / "Can you scrape that?" (referring to a previously mentioned URL) + * "Fetch the content from [URL]" + * "Pull the data from that page" + - Args: + - url: The URL of the webpage to scrape (must be HTTP/HTTPS) + - max_length: Maximum content length to return (default: 50000 chars) + - Returns: The page title, description, full content (in markdown), word count, and metadata + - After scraping, provide a comprehensive, well-structured summary with key takeaways using headings or bullet points. + - Reference the source using markdown links [descriptive text](url) — never bare URLs. + - IMAGES: The scraped content may contain image URLs in markdown format like `![alt text](image_url)`. + * When you find relevant/important images in the scraped content, include them in your response using standard markdown image syntax: `![alt text](image_url)`. + * This makes your response more visual and engaging. + * Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content. + * Don't show every image - just the most relevant 1-3 images that enhance understanding. diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/search_surfsense_docs.md b/surfsense_backend/app/agents/new_chat/prompts/tools/search_surfsense_docs.md new file mode 100644 index 000000000..133717fec --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/search_surfsense_docs.md @@ -0,0 +1,7 @@ + +- search_surfsense_docs: Search the official SurfSense documentation. + - Use this tool when the user asks anything about SurfSense itself (the application they are using). + - Args: + - query: The search query about SurfSense + - top_k: Number of documentation chunks to retrieve (default: 10) + - Returns: Documentation content with chunk IDs for citations (prefixed with 'doc-', e.g., [citation:doc-123]) diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_private.md b/surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_private.md new file mode 100644 index 000000000..184013804 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_private.md @@ -0,0 +1,31 @@ + +- update_memory: Update your personal memory document about the user. + - Your current memory is already in in your context. The `chars` and + `limit` attributes show your current usage and the maximum allowed size. + - This is your curated long-term memory — the distilled essence of what you know about + the user, not raw conversation logs. + - Call update_memory when: + * The user explicitly asks to remember or forget something + * The user shares durable facts or preferences that will matter in future conversations + - The user's first name is provided in . Use it in memory entries + instead of "the user" (e.g. "{name} works at..." not "The user works at..."). + Do not store the name itself as a separate memory entry. + - Do not store short-lived or ephemeral info: one-off questions, greetings, + session logistics, or things that only matter for the current task. + - Args: + - updated_memory: The FULL updated markdown document (not a diff). + Merge new facts with existing ones, update contradictions, remove outdated entries. + Treat every update as a curation pass — consolidate, don't just append. + - Every bullet MUST use this format: - (YYYY-MM-DD) [marker] text + Markers: + [fact] — durable facts (role, background, projects, tools, expertise) + [pref] — preferences (response style, languages, formats, tools) + [instr] — standing instructions (always/never do, response rules) + - Keep it concise and well under the character limit shown in . + - Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and + natural. Do NOT include the user's name in headings. Organize by context — e.g. + who they are, what they're focused on, how they prefer things. Create, split, or + merge headings freely as the memory grows. + - Each entry MUST be a single bullet point. Be descriptive but concise — include relevant + details and context rather than just a few words. + - During consolidation, prioritize keeping: [instr] > [pref] > [fact]. diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_team.md b/surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_team.md new file mode 100644 index 000000000..7eaca8818 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_team.md @@ -0,0 +1,26 @@ + +- update_memory: Update the team's shared memory document for this search space. + - Your current team memory is already in in your context. The `chars` + and `limit` attributes show current usage and the maximum allowed size. + - This is the team's curated long-term memory — decisions, conventions, key facts. + - NEVER store personal memory in team memory (e.g. personal bio, individual + preferences, or user-only standing instructions). + - Call update_memory when: + * A team member explicitly asks to remember or forget something + * The conversation surfaces durable team decisions, conventions, or facts + that will matter in future conversations + - Do not store short-lived or ephemeral info: one-off questions, greetings, + session logistics, or things that only matter for the current task. + - Args: + - updated_memory: The FULL updated markdown document (not a diff). + Merge new facts with existing ones, update contradictions, remove outdated entries. + Treat every update as a curation pass — consolidate, don't just append. + - Every bullet MUST use this format: - (YYYY-MM-DD) [fact] text + Team memory uses ONLY the [fact] marker. Never use [pref] or [instr] in team memory. + - Keep it concise and well under the character limit shown in . + - Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and + natural. Organize by context — e.g. what the team decided, current architecture, + active processes. Create, split, or merge headings freely as the memory grows. + - Each entry MUST be a single bullet point. Be descriptive but concise — include relevant + details and context rather than just a few words. + - During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities. diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/web_search.md b/surfsense_backend/app/agents/new_chat/prompts/tools/web_search.md new file mode 100644 index 000000000..7ed7c332d --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/web_search.md @@ -0,0 +1,18 @@ + +- web_search: Search the web for real-time information using all configured search engines. + - Use this for current events, news, prices, weather, public facts, or any question requiring + up-to-date information from the internet. + - This tool dispatches to all configured search engines (SearXNG, Tavily, Linkup, Baidu) in + parallel and merges the results. + - IMPORTANT (REAL-TIME / PUBLIC WEB QUERIES): For questions that require current public web data + (e.g., live exchange rates, stock prices, breaking news, weather, current events), you MUST call + `web_search` instead of answering from memory. + - For these real-time/public web queries, DO NOT answer from memory and DO NOT say you lack internet + access before attempting a web search. + - If the search returns no relevant results, explain that web sources did not return enough + data and ask the user if they want you to retry with a refined query. + - Args: + - query: The search query - use specific, descriptive terms + - top_k: Number of results to retrieve (default: 10, max: 50) + - If search snippets are insufficient for the user's question, use `scrape_webpage` on the most relevant result URL for full content. + - When presenting results, reference sources as markdown links [descriptive text](url) — never bare URLs. diff --git a/surfsense_backend/app/agents/new_chat/skills/__init__.py b/surfsense_backend/app/agents/new_chat/skills/__init__.py new file mode 100644 index 000000000..bb7ac055c --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/skills/__init__.py @@ -0,0 +1,7 @@ +"""SurfSense built-in agent skills (Anthropic Skills format). + +Each subdirectory corresponds to one skill and contains a ``SKILL.md`` file +with YAML frontmatter (name, description, allowed_tools) plus markdown +instructions. The :class:`BuiltinSkillsBackend` exposes them to the +deepagents :class:`SkillsMiddleware`. +""" diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/__init__.py b/surfsense_backend/app/agents/new_chat/skills/builtin/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/email-drafting/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/email-drafting/SKILL.md new file mode 100644 index 000000000..32e599e98 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/email-drafting/SKILL.md @@ -0,0 +1,25 @@ +--- +name: email-drafting +description: Draft an email matching the user's voice, with structured intent and CTA +allowed-tools: search_surfsense_docs +--- + +# Email drafting + +## When to use this skill +"Draft an email to ...", "reply to this thread", "write a follow-up to X". Plain "summarize the email" is **not** in scope — that's a comprehension task. + +## Voice +Search the KB for prior emails from the user to similar audiences (same recipient, same topic class). Mirror tone, opening style, sign-off, and length distribution. If there is no precedent, default to: warm, direct, no filler, short paragraphs, one clear ask. + +## Required structure +Every draft includes, in this order: + +1. **Subject line** — concrete, ≤ 8 words, no clickbait, no `Re:` unless replying. +2. **Opening (1 sentence)** — context the recipient already shares; never restate what they wrote unless the thread is long. +3. **Body** — the actual point in one short paragraph. Bullets only if there are >3 discrete items. +4. **Single explicit CTA** — what you want the recipient to do, with a soft deadline if relevant. +5. **Sign-off** — match the user's prior closing style. + +## Always offer alternatives +End your message with: "Want me to make it shorter, more formal, or add a different angle?" — give the user one obvious next step. diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/kb-research/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/kb-research/SKILL.md new file mode 100644 index 000000000..c268278ab --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/kb-research/SKILL.md @@ -0,0 +1,23 @@ +--- +name: kb-research +description: Structured approach to finding and synthesizing information from the user's knowledge base +allowed-tools: search_surfsense_docs, scrape_webpage, read_file, ls_tree, grep, web_search +--- + +# Knowledge-base research + +## When to use this skill +- The user asks "find/look up/research" something specifically inside their knowledge base. +- The user references documents, notes, repos, or connector data they expect to exist already. +- A multi-document synthesis is required (e.g., "summarize what we've discussed about X across all my notes"). + +## Plan +1. Decompose the user's question into 2-4 specific, citation-worthy sub-questions. +2. For each sub-question, run **one** targeted KB search (focused on terms the user would have written, not synonyms). Open the most relevant 2-3 documents fully via `read_file` if their excerpts are too short. +3. Use `grep` to find supporting passages in long files instead of re-reading them end to end. +4. Cite every claim with `[citation:chunk_id]` exactly as the chunk tag specifies. + +## What good output looks like +- Short paragraphs with inline citations. +- Quoted phrases when wording matters. +- An explicit "Not found in your knowledge base" callout when a sub-question has no support — never fabricate. diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/meeting-prep/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/meeting-prep/SKILL.md new file mode 100644 index 000000000..9657eb078 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/meeting-prep/SKILL.md @@ -0,0 +1,22 @@ +--- +name: meeting-prep +description: Pull together briefing materials before a scheduled meeting +allowed-tools: search_surfsense_docs, web_search, scrape_webpage, read_file +--- + +# Meeting preparation + +## When to use this skill +The user mentions an upcoming meeting, call, or interview and asks you to "prep", "brief me", "pull background", or "what do I need to know about X before tomorrow". + +## Output structure +Always produce these sections (omit any with no signal — don't pad): + +1. **Attendees & context** — who's in the room, their roles, what they care about. Pull from KB notes about prior interactions; supplement with public profile facts via `web_search` when names or companies are unfamiliar. +2. **Open threads** — outstanding action items, unresolved decisions, last-mentioned blockers from prior conversation history. +3. **Recent moves** — within the last 30 days: relevant launches, hires, news. Cite KB chunks when present, otherwise external sources. +4. **Suggested questions** — 3-5 questions the user could ask, tailored to the open threads and the attendees' likely priorities. + +## Source ordering +- Always check the user's KB **first** for prior meeting notes, internal docs, or Slack threads about these attendees. +- Only fall back to `web_search` for *publicly verifiable* facts — never to fabricate a participant's preferences or relationships. diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/report-writing/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/report-writing/SKILL.md new file mode 100644 index 000000000..17ac2f391 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/report-writing/SKILL.md @@ -0,0 +1,23 @@ +--- +name: report-writing +description: How to scope, draft, and revise a Markdown report artifact via generate_report +allowed-tools: generate_report, search_surfsense_docs, read_file +--- + +# Report writing + +## When to use this skill +The user explicitly requests a deliverable: "write a report on …", "draft a memo", "produce a brief", "expand the previous report". A creation or modification verb pointed at an artifact is required (see `generate_report`'s when-to-call rules). + +## Decision flow +1. **Source strategy.** Decide which `source_strategy` fits: + - `conversation` — substantive Q&A on the topic already in chat. + - `kb_search` — fresh topic; supply 1–5 precise `search_queries`. + - `auto` — partial conversation context; let the tool fall back. + - `provided` — verbatim source text only. +2. **Style.** Default to `report_style="detailed"` unless the user explicitly asks for "brief", "one page", "500 words". +3. **Revisions.** When modifying an existing report from this conversation, set `parent_report_id` and put the change list in `user_instructions` ("add carbon-capture section", "tighten conclusion"). +4. **Never paste the report back into chat** after `generate_report` returns — confirm and let the artifact card render itself. + +## Hooks for KB-only mode +If `kb_search`/`auto` returns no results, do **not** silently switch to general knowledge. Surface the gap in your confirmation message. diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/slack-summary/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/slack-summary/SKILL.md new file mode 100644 index 000000000..33b9e72a2 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/slack-summary/SKILL.md @@ -0,0 +1,26 @@ +--- +name: slack-summary +description: Distill a Slack channel or thread into actionable summary +allowed-tools: search_surfsense_docs +--- + +# Slack summarization + +## When to use this skill +The user asks to summarize Slack ("what happened in #eng-platform this week", "what did Alice say about the launch", "catch me up on the design channel"). + +## Required inputs +Confirm before searching: +- **Which channel(s) or thread(s)?** Don't guess if ambiguous. +- **What time window?** Default to the last 7 days when not specified, but say so. + +## Output shape +Produce three concise sections: +1. **Key decisions** — explicit choices that were made, with the deciding message cited. +2. **Open questions** — things asked but not answered, with the asking message cited. +3. **Action items** — `@mention` who owes what by when, *only if explicitly stated*. Don't invent assignees. + +## What not to do +- Never produce a chronological play-by-play of every message — distill. +- Never quote private messages without flagging them as such. +- If the channel was empty in the time window, say so — don't fabricate filler. diff --git a/surfsense_backend/app/agents/new_chat/subagents/__init__.py b/surfsense_backend/app/agents/new_chat/subagents/__init__.py new file mode 100644 index 000000000..b9f21a0d2 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/subagents/__init__.py @@ -0,0 +1,26 @@ +"""Specialized user-facing subagents for the SurfSense agent. + +Each subagent is a :class:`deepagents.SubAgent` typed-dict spec passed to +:class:`deepagents.SubAgentMiddleware`, which materializes them as ephemeral +runnables invoked via the ``task`` tool. + +Per-subagent permission rules are injected as a +:class:`PermissionMiddleware` entry inside the subagent's ``middleware`` +field, mirroring opencode ``tool/task.ts`` which seeds child sessions with +deny rules for tools the parent does not want them touching (e.g. +``task``/``todowrite`` recursion, write tools for read-only research roles). +""" + +from .config import ( + build_connector_negotiator_subagent, + build_explore_subagent, + build_report_writer_subagent, + build_specialized_subagents, +) + +__all__ = [ + "build_connector_negotiator_subagent", + "build_explore_subagent", + "build_report_writer_subagent", + "build_specialized_subagents", +] diff --git a/surfsense_backend/app/agents/new_chat/subagents/config.py b/surfsense_backend/app/agents/new_chat/subagents/config.py new file mode 100644 index 000000000..e20bc06bf --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/subagents/config.py @@ -0,0 +1,427 @@ +"""Builders for specialized SurfSense subagents. + +Each subagent is built from three pieces: + +1. A name + description + system prompt (the user-facing contract for + when ``task`` should delegate to this role). +2. A filtered tool list (subset of the parent's bound tools). +3. A :class:`PermissionMiddleware` instance carrying a deny ruleset that + prevents the subagent from acting outside its scope (e.g. an + explore-only role cannot mutate state). + +Skill sources (``/skills/builtin/`` + ``/skills/space/``) are inherited +from the parent unconditionally — every subagent benefits from the same +authored guidance documents. +""" + +from __future__ import annotations + +import logging +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING, Any + +from app.agents.new_chat.middleware.skills_backends import default_skills_sources +from app.agents.new_chat.permissions import Rule, Ruleset + +if TYPE_CHECKING: + from deepagents import SubAgent + from langchain_core.language_models import BaseChatModel + from langchain_core.tools import BaseTool + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Tool name constants +# --------------------------------------------------------------------------- + +# Read-only tools that ``explore`` is permitted to use. Names match the +# tools provided by the deepagents ``FilesystemMiddleware`` (``ls``, ``read_file``, +# ``glob``, ``grep``) plus the SurfSense-side read tools. +EXPLORE_READ_TOOLS: frozenset[str] = frozenset( + { + "search_surfsense_docs", + "web_search", + "scrape_webpage", + "read_file", + "ls", + "glob", + "grep", + } +) + +# Tools ``report_writer`` may call. The set is intentionally narrow so the +# subagent doesn't drift into tangential research; if richer source-gathering +# is needed, the parent should hand off to ``explore`` first. +REPORT_WRITER_TOOLS: frozenset[str] = frozenset( + { + "search_surfsense_docs", + "read_file", + "generate_report", + } +) + +# Wildcard patterns that match write tools we deny by default in read-only +# subagents. Anchored at start AND end via :func:`Rule` semantics. We use +# substring-style ``*verb*`` patterns because connector tool names typically +# put the verb in the middle (``linear_create_issue``, ``slack_send_message``, +# ``notion_update_page``); strict suffix patterns (``*_create``) miss those. +# +# A handful of canonical exact-match names is appended so that bare verbs +# (``edit``, ``write``) are also blocked even when a connector dropped the +# usual prefix. +WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = ( + "*create*", + "*update*", + "*delete*", + "*send*", + "*write*", + "*edit*", + "*move*", + "*mkdir*", + "*upload*", + "edit_file", + "write_file", + "move_file", + "mkdir", + "update_memory", + "update_memory_team", + "update_memory_private", +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +# Tool names that are NOT in the registry's ``tools`` list because they +# are provided dynamically by middleware at compile time. We don't pass +# them through ``_filter_tools`` (the actual ``BaseTool`` instances live +# inside the middleware), but we do exempt them from the "missing" warning +# below — operators were seeing spurious noise like +# ``missing: ['glob', 'grep', 'ls', 'read_file']`` even though those +# tools are reachable via :class:`SurfSenseFilesystemMiddleware` once the +# subagent is compiled. +_MIDDLEWARE_PROVIDED_TOOL_NAMES: frozenset[str] = frozenset( + { + "ls", + "read_file", + "write_file", + "edit_file", + "glob", + "grep", + "execute", + "write_todos", + "task", + } +) + + +def _filter_tools( + tools: Sequence[BaseTool], + allowed_names: Iterable[str], +) -> list[BaseTool]: + """Return only tools whose ``name`` appears in ``allowed_names``. + + Tools are looked up by exact name. Names matching + :data:`_MIDDLEWARE_PROVIDED_TOOL_NAMES` are intentionally absent from + ``tools`` (they're injected by middleware at compile time) and are + silently excluded from the "missing" warning so operators don't see + false positives every build. + """ + allowed = set(allowed_names) + selected = [t for t in tools if t.name in allowed] + missing = sorted( + (allowed - {t.name for t in selected}) - _MIDDLEWARE_PROVIDED_TOOL_NAMES + ) + if missing: + logger.info( + "Subagent build: %d/%d registry tools available; missing: %s", + len(selected), + len(allowed - _MIDDLEWARE_PROVIDED_TOOL_NAMES), + missing, + ) + return selected + + +def _read_only_deny_rules() -> list[Rule]: + """Synthesize a list of deny rules covering common write-tool patterns.""" + return [ + Rule(permission=pattern, pattern="*", action="deny") + for pattern in WRITE_TOOL_DENY_PATTERNS + ] + + +def _build_permission_middleware(deny_rules: list[Rule], origin: str): + """Construct a :class:`PermissionMiddleware` seeded with ``deny_rules``. + + Imported lazily because the middleware module pulls in interrupt/HITL + machinery we don't want at import time of this config file. + """ + from app.agents.new_chat.middleware.permission import PermissionMiddleware + + return PermissionMiddleware( + rulesets=[Ruleset(rules=deny_rules, origin=origin)], + ) + + +def _wrap_with_subagent_essentials( + custom_middleware: list, + *, + agent_tools: Sequence[BaseTool], + extra_middleware: Sequence[Any] | None = None, +): + """Compose the final middleware list for a specialized subagent. + + Order, outer to inner: + + 1. ``extra_middleware`` — provided by the caller (typically the parent + agent's ``SurfSenseFilesystemMiddleware`` and ``TodoListMiddleware``) + so the subagent inherits the parent's filesystem/todo view. These + run **before** the subagent-local middleware so their tools are + wired up before permissioning kicks in. + 2. ``custom_middleware`` — subagent-local rules (e.g. permission deny + lists). + 3. :class:`PatchToolCallsMiddleware` — normalizes tool-call shapes. + 4. :class:`DedupHITLToolCallsMiddleware` — collapses duplicate HITL + calls using metadata declared at registry time. + + Without ``extra_middleware`` the subagent will only have the registry + tools listed in its ``tools`` field — meaning ``read_file``, ``ls``, + ``grep``, etc. won't exist. Always pass ``extra_middleware`` from the + parent unless you specifically want a sandboxed subagent. + """ + from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware + + from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware + + return [ + *(extra_middleware or []), + *custom_middleware, + PatchToolCallsMiddleware(), + DedupHITLToolCallsMiddleware(agent_tools=list(agent_tools)), + ] + + +# --------------------------------------------------------------------------- +# System prompts +# --------------------------------------------------------------------------- + +EXPLORE_SYSTEM_PROMPT = """You are the **explore** subagent for SurfSense. + +## Your job +Conduct read-only research across the user's knowledge base, the web, and any documents the parent agent has surfaced. Return a synthesized answer with explicit citations — never speculate beyond the sources you have actually inspected. + +## Tools available +- `search_surfsense_docs` — fast hybrid search over the user's knowledge base. +- `web_search` — only when the user's KB clearly does not contain the answer. +- `scrape_webpage` — to read a URL the user or the search results provided. +- `read_file`, `ls`, `glob`, `grep` — to inspect specific documents or trees the parent has flagged. + +## Rules +- Read-only. You cannot create, edit, delete, send, or move anything. +- Cite every claim. Use `[citation:chunk_id]` exactly as the chunk tag specifies. +- If a sub-question has no support in the inspected sources, say so explicitly. Do not fabricate. +- Return the most useful synthesis in your single final message. The parent agent will not be able to follow up. +""" + + +REPORT_WRITER_SYSTEM_PROMPT = """You are the **report_writer** subagent for SurfSense. + +## Your job +Produce a single high-quality report deliverable using `generate_report`. The parent has already gathered (or knows where to gather) the underlying sources. + +## Workflow +1. **Outline first.** Before calling `generate_report`, write a one-paragraph outline of the sections you plan to produce. Confirm the outline reflects the parent's instructions. +2. **Source resolution.** Decide whether to call `search_surfsense_docs` and `read_file` for any final-checks, or whether the parent's earlier tool calls already cover the source set. +3. **One report.** Call `generate_report` exactly once with `source_strategy` chosen per the topic and chat history (see the `report-writing` skill). +4. **Confirm.** End with a one-sentence summary in your final message — never paste the report back into chat; the artifact card renders itself. +""" + + +CONNECTOR_NEGOTIATOR_SYSTEM_PROMPT = """You are the **connector_negotiator** subagent for SurfSense. + +## Your job +Coordinate cross-connector workflows: chains where the result of one service's tool feeds into another's. Common shapes include "find Linear issues mentioned in last week's Slack messages", "draft a Gmail reply citing a Notion doc", or "list Linear tickets opened by the same person who filed Jira FOO-123". + +## Workflow +1. **Plan.** Identify the connector hops needed and the order they should run in. Write a short plan in your first message. +2. **Verify access.** Use `get_connected_accounts` to confirm the relevant connectors are actually wired up before issuing tool calls. If a connector is missing, stop and report — do not fabricate. +3. **Execute.** Run each hop, citing IDs (issue keys, message ts, page IDs) in your scratch notes so the parent can audit. +4. **Hand back.** Return a structured summary with the final answer plus the chain of evidence (issue → message → page, etc.). + +## Caveats +- If a hop fails, do not retry blindly — return the partial result and explain. +- Mutating tools (create, update, delete, send) require parent permission; you are NOT cleared to call them on your own. +""" + + +# --------------------------------------------------------------------------- +# Subagent builders +# --------------------------------------------------------------------------- + + +def build_explore_subagent( + *, + tools: Sequence[BaseTool], + model: BaseChatModel | None = None, + extra_middleware: Sequence[Any] | None = None, +) -> SubAgent: + """Build the read-only ``explore`` subagent spec. + + Pass ``extra_middleware`` (typically the parent's filesystem + todo + middleware) so the subagent can actually use ``read_file``, ``ls``, + ``grep``, ``glob`` — which its system prompt promises but which only + exist when their middleware is mounted. + """ + from deepagents import SubAgent # noqa: F401 (TypedDict for type clarity) + + selected_tools = _filter_tools(tools, EXPLORE_READ_TOOLS) + deny_rules = _read_only_deny_rules() + permission_mw = _build_permission_middleware( + deny_rules, origin="subagent_explore" + ) + + spec: dict = { + "name": "explore", + "description": ( + "Read-only research across the user's knowledge base and the web. " + "Use when the parent needs deeply-cited synthesis without " + "modifying anything." + ), + "system_prompt": EXPLORE_SYSTEM_PROMPT, + "tools": selected_tools, + "middleware": _wrap_with_subagent_essentials( + [permission_mw], + agent_tools=selected_tools, + extra_middleware=extra_middleware, + ), + "skills": default_skills_sources(), + } + if model is not None: + spec["model"] = model + return spec # type: ignore[return-value] + + +def build_report_writer_subagent( + *, + tools: Sequence[BaseTool], + model: BaseChatModel | None = None, + extra_middleware: Sequence[Any] | None = None, +) -> SubAgent: + """Build the ``report_writer`` subagent spec. + + Read-only deny ruleset still applies — the subagent should call + ``generate_report`` and nothing else mutating. ``generate_report`` + creates a report artifact via a backend service and is intentionally + **not** denied. + + Pass ``extra_middleware`` (typically the parent's filesystem + todo + middleware) so the subagent can run ``read_file`` for source-checks + before calling ``generate_report``. + """ + selected_tools = _filter_tools(tools, REPORT_WRITER_TOOLS) + deny_rules = _read_only_deny_rules() + permission_mw = _build_permission_middleware( + deny_rules, origin="subagent_report_writer" + ) + + spec: dict = { + "name": "report_writer", + "description": ( + "Produce a single Markdown report artifact via generate_report, " + "using the outline-then-fill protocol. Use when the parent has " + "decided a deliverable is needed." + ), + "system_prompt": REPORT_WRITER_SYSTEM_PROMPT, + "tools": selected_tools, + "middleware": _wrap_with_subagent_essentials( + [permission_mw], + agent_tools=selected_tools, + extra_middleware=extra_middleware, + ), + "skills": default_skills_sources(), + } + if model is not None: + spec["model"] = model + return spec # type: ignore[return-value] + + +def build_connector_negotiator_subagent( + *, + tools: Sequence[BaseTool], + model: BaseChatModel | None = None, + extra_middleware: Sequence[Any] | None = None, +) -> SubAgent: + """Build the ``connector_negotiator`` subagent spec. + + Inherits all MCP / connector tools the parent has plus + ``get_connected_accounts``. Read-only by default; permission rules deny + write/mutation patterns. The parent agent re-asks for permission if a + connector mutation is genuinely needed. + + Pass ``extra_middleware`` (typically the parent's filesystem + todo + middleware) so this subagent shares the parent's filesystem view when + citing evidence across hops. + """ + parent_tool_names = {t.name for t in tools} + allowed: set[str] = set() + if "get_connected_accounts" in parent_tool_names: + allowed.add("get_connected_accounts") + # Inherit anything that smells connector- or MCP-related but is not a + # bulk-write API. Heuristic: keep all parent tools; rely on the deny + # ruleset to block mutation patterns. This mirrors the plan: "all + # MCP/connector tools the parent has". + for name in parent_tool_names: + allowed.add(name) + selected_tools = _filter_tools(tools, allowed) + + deny_rules = _read_only_deny_rules() + permission_mw = _build_permission_middleware( + deny_rules, origin="subagent_connector_negotiator" + ) + + spec: dict = { + "name": "connector_negotiator", + "description": ( + "Coordinate read-only chains across connectors (Slack → Linear, " + "Notion → Gmail, etc.). Returns a structured summary with the " + "evidence chain. Cannot mutate connector state." + ), + "system_prompt": CONNECTOR_NEGOTIATOR_SYSTEM_PROMPT, + "tools": selected_tools, + "middleware": _wrap_with_subagent_essentials( + [permission_mw], + agent_tools=selected_tools, + extra_middleware=extra_middleware, + ), + "skills": default_skills_sources(), + } + if model is not None: + spec["model"] = model + return spec # type: ignore[return-value] + + +def build_specialized_subagents( + *, + tools: Sequence[BaseTool], + model: BaseChatModel | None = None, + extra_middleware: Sequence[Any] | None = None, +) -> list[SubAgent]: + """Return the canonical list of specialized subagents to register. + + Order matters only for the order they appear in the ``task`` tool + description — most useful first. + """ + return [ + build_explore_subagent( + tools=tools, model=model, extra_middleware=extra_middleware + ), + build_report_writer_subagent( + tools=tools, model=model, extra_middleware=extra_middleware + ), + build_connector_negotiator_subagent( + tools=tools, model=model, extra_middleware=extra_middleware + ), + ] diff --git a/surfsense_backend/app/agents/new_chat/system_prompt.py b/surfsense_backend/app/agents/new_chat/system_prompt.py index 0c9426892..3919527d9 100644 --- a/surfsense_backend/app/agents/new_chat/system_prompt.py +++ b/surfsense_backend/app/agents/new_chat/system_prompt.py @@ -1,842 +1,43 @@ """ -System prompt building for SurfSense agents. +Thin compatibility wrapper around :mod:`app.agents.new_chat.prompts.composer`. -This module provides functions and constants for building the SurfSense system prompt -with configurable user instructions and citation support. +Tier 3a of the OpenCode-port plan replaced the monolithic prompt strings +in this module with a fragment tree under ``prompts/`` and a composer +function. This module preserves the public function surface +(``build_surfsense_system_prompt`` / ``build_configurable_system_prompt`` / +``get_default_system_instructions`` / ``SURFSENSE_SYSTEM_PROMPT``) so that +existing call sites — `chat_deepagent.py`, anonymous chat routes, and the +configurable-prompt admin path — keep working without churn. -The prompt is composed of three parts: -1. System Instructions (configurable via NewLLMConfig) -2. Tools Instructions (always included, not configurable) -3. Citation Instructions (toggleable via NewLLMConfig.citations_enabled) +For new call sites prefer importing ``compose_system_prompt`` directly +from :mod:`app.agents.new_chat.prompts.composer`. """ +from __future__ import annotations + from datetime import UTC, datetime from app.db import ChatVisibility -# Default system instructions - can be overridden via NewLLMConfig.system_instructions -SURFSENSE_SYSTEM_INSTRUCTIONS = """ - -You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base. - -Today's date (UTC): {resolved_today} - -When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVER use backtick code spans or Unicode symbols for math. - -NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead. - - -CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE: -- You MUST answer questions ONLY using information retrieved from the user's knowledge base, web search results, scraped webpages, or other tool outputs. -- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless the user explicitly grants permission. -- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST: - 1. Inform the user that you could not find relevant information in their knowledge base. - 2. Ask the user: "Would you like me to answer from my general knowledge instead?" - 3. ONLY provide a general-knowledge answer AFTER the user explicitly says yes. -- This policy does NOT apply to: - * Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?") - * Formatting, summarization, or analysis of content already present in the conversation - * Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points") - * Tool-usage actions like generating reports, podcasts, images, or scraping webpages - * Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see below - - - -CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable. -Their data is NEVER in the knowledge base. You MUST call their tools immediately — never -say "I don't see it in the knowledge base" or ask the user if they want you to check. -Ignore any knowledge base results for these services. - -When to use which tool: -- Linear (issues) → list_issues, get_issue, save_issue (create/update) -- ClickUp (tasks) → clickup_search, clickup_get_task -- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue -- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread -- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table -- Knowledge base content (Notion, GitHub, files, notes) → automatically searched -- Real-time public web data → call web_search -- Reading a specific webpage → call scrape_webpage - - - -Some service tools require identifiers or context you do not have (account IDs, -workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw -IDs or technical identifiers — they cannot memorise them. - -Instead, follow this discovery pattern: -1. Call a listing/discovery tool to find available options. -2. ONE result → use it silently, no question to the user. -3. MULTIPLE results → present the options by their display names and let the - user choose. Never show raw UUIDs — always use friendly names. - -Discovery tools by level: -- Which account/workspace? → get_connected_accounts("") -- Which Jira site (cloudId)? → getAccessibleAtlassianResources -- Which Jira project? → getVisibleJiraProjects (after resolving cloudId) -- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project) -- Which channel? → slack_search_channels -- Which base? → list_bases -- Which table? → list_tables_for_base (after resolving baseId) -- Which task? → clickup_search -- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira) - -For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to -obtain the cloudId, then pass it to other Jira tools. When creating an issue, -chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue. -If there is only one option at each step, use it silently. If multiple, present -friendly names. - -Chain discovery when needed — e.g. for Airtable records: list_bases → pick -base → list_tables_for_base → pick table → list_records_for_table. - -MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for -the same service, tool names are prefixed to avoid collisions — e.g. -linear_25_list_issues and linear_30_list_issues instead of two list_issues. -Each prefixed tool's description starts with [Account: ] so you -know which account it targets. Use get_connected_accounts("") to see -the full list of accounts with their connector IDs and display names. -When only one account is connected, tools have their normal unprefixed names. - - - -IMPORTANT — After understanding each user message, ALWAYS check: does this message -reveal durable facts about the user (role, interests, preferences, projects, -background, or standing instructions)? If yes, you MUST call update_memory -alongside your normal response — do not defer this to a later turn. - - - -""" - -# Default system instructions for shared (team) threads: team context + message format for attribution -_SYSTEM_INSTRUCTIONS_SHARED = """ - -You are SurfSense, a reasoning and acting AI agent designed to answer questions in this team space using the team's shared knowledge base. - -In this team thread, each message is prefixed with **[DisplayName of the author]**. Use this to attribute and reference the author of anything in the discussion (who asked a question, made a suggestion, or contributed an idea) and to cite who said what in your answers. - -Today's date (UTC): {resolved_today} - -When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVER use backtick code spans or Unicode symbols for math. - -NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead. - - -CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE: -- You MUST answer questions ONLY using information retrieved from the team's shared knowledge base, web search results, scraped webpages, or other tool outputs. -- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless a team member explicitly grants permission. -- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST: - 1. Inform the team that you could not find relevant information in the shared knowledge base. - 2. Ask: "Would you like me to answer from my general knowledge instead?" - 3. ONLY provide a general-knowledge answer AFTER a team member explicitly says yes. -- This policy does NOT apply to: - * Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?") - * Formatting, summarization, or analysis of content already present in the conversation - * Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points") - * Tool-usage actions like generating reports, podcasts, images, or scraping webpages - * Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see below - - - -CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable. -Their data is NEVER in the knowledge base. You MUST call their tools immediately — never -say "I don't see it in the knowledge base" or ask if they want you to check. -Ignore any knowledge base results for these services. - -When to use which tool: -- Linear (issues) → list_issues, get_issue, save_issue (create/update) -- ClickUp (tasks) → clickup_search, clickup_get_task -- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue -- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread -- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table -- Knowledge base content (Notion, GitHub, files, notes) → automatically searched -- Real-time public web data → call web_search -- Reading a specific webpage → call scrape_webpage - - - -Some service tools require identifiers or context you do not have (account IDs, -workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw -IDs or technical identifiers — they cannot memorise them. - -Instead, follow this discovery pattern: -1. Call a listing/discovery tool to find available options. -2. ONE result → use it silently, no question to the user. -3. MULTIPLE results → present the options by their display names and let the - user choose. Never show raw UUIDs — always use friendly names. - -Discovery tools by level: -- Which account/workspace? → get_connected_accounts("") -- Which Jira site (cloudId)? → getAccessibleAtlassianResources -- Which Jira project? → getVisibleJiraProjects (after resolving cloudId) -- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project) -- Which channel? → slack_search_channels -- Which base? → list_bases -- Which table? → list_tables_for_base (after resolving baseId) -- Which task? → clickup_search -- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira) - -For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to -obtain the cloudId, then pass it to other Jira tools. When creating an issue, -chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue. -If there is only one option at each step, use it silently. If multiple, present -friendly names. - -Chain discovery when needed — e.g. for Airtable records: list_bases → pick -base → list_tables_for_base → pick table → list_records_for_table. - -MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for -the same service, tool names are prefixed to avoid collisions — e.g. -linear_25_list_issues and linear_30_list_issues instead of two list_issues. -Each prefixed tool's description starts with [Account: ] so you -know which account it targets. Use get_connected_accounts("") to see -the full list of accounts with their connector IDs and display names. -When only one account is connected, tools have their normal unprefixed names. - - - -IMPORTANT — After understanding each user message, ALWAYS check: does this message -reveal durable facts about the team (decisions, conventions, architecture, processes, -or key facts)? If yes, you MUST call update_memory alongside your normal response — -do not defer this to a later turn. - - - -""" - - -def _get_system_instructions( - thread_visibility: ChatVisibility | None = None, today: datetime | None = None -) -> str: - """Build system instructions based on thread visibility (private vs shared).""" - - resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat() - visibility = thread_visibility or ChatVisibility.PRIVATE - if visibility == ChatVisibility.SEARCH_SPACE: - return _SYSTEM_INSTRUCTIONS_SHARED.format(resolved_today=resolved_today) - else: - return SURFSENSE_SYSTEM_INSTRUCTIONS.format(resolved_today=resolved_today) - - -# ============================================================================= -# Per-tool prompt instructions keyed by registry tool name. -# Only tools present in the enabled set will be included in the system prompt. -# ============================================================================= - -_TOOLS_PREAMBLE = """ - -You have access to the following tools: - -IMPORTANT: You can ONLY use the tools listed below. If a capability is not listed here, you do NOT have it. -Do NOT claim you can do something if the corresponding tool is not listed. - -""" - -_TOOL_INSTRUCTIONS: dict[str, str] = {} - -_TOOL_INSTRUCTIONS["search_surfsense_docs"] = """ -- search_surfsense_docs: Search the official SurfSense documentation. - - Use this tool when the user asks anything about SurfSense itself (the application they are using). - - Args: - - query: The search query about SurfSense - - top_k: Number of documentation chunks to retrieve (default: 10) - - Returns: Documentation content with chunk IDs for citations (prefixed with 'doc-', e.g., [citation:doc-123]) -""" - -_TOOL_INSTRUCTIONS["generate_podcast"] = """ -- generate_podcast: Generate an audio podcast from provided content. - - Use this when the user asks to create, generate, or make a podcast. - - Trigger phrases: "give me a podcast about", "create a podcast", "generate a podcast", "make a podcast", "turn this into a podcast" - - Args: - - source_content: The text content to convert into a podcast. This MUST be comprehensive and include: - * If discussing the current conversation: Include a detailed summary of the FULL chat history (all user questions and your responses) - * If based on knowledge base search: Include the key findings and insights from the search results - * You can combine both: conversation context + search results for richer podcasts - * The more detailed the source_content, the better the podcast quality - - podcast_title: Optional title for the podcast (default: "SurfSense Podcast") - - user_prompt: Optional instructions for podcast style/format (e.g., "Make it casual and fun") - - Returns: A task_id for tracking. The podcast will be generated in the background. - - IMPORTANT: Only one podcast can be generated at a time. If a podcast is already being generated, the tool will return status "already_generating". - - After calling this tool, inform the user that podcast generation has started and they will see the player when it's ready (takes 3-5 minutes). -""" - -_TOOL_INSTRUCTIONS["generate_video_presentation"] = """ -- generate_video_presentation: Generate a video presentation from provided content. - - Use this when the user asks to create a video, presentation, slides, or slide deck. - - Trigger phrases: "give me a presentation", "create slides", "generate a video", "make a slide deck", "turn this into a presentation" - - Args: - - source_content: The text content to turn into a presentation. The more detailed, the better. - - video_title: Optional title (default: "SurfSense Presentation") - - user_prompt: Optional style instructions (e.g., "Make it technical and detailed") - - After calling this tool, inform the user that generation has started and they will see the presentation when it's ready. -""" - -_TOOL_INSTRUCTIONS["generate_report"] = """ -- generate_report: Generate or revise a structured Markdown report artifact. - - WHEN TO CALL THIS TOOL — the message must contain a creation or modification VERB directed at producing a deliverable: - * Creation verbs: write, create, generate, draft, produce, summarize into, turn into, make - * Modification verbs: revise, update, expand, add (a section), rewrite, make (it shorter/longer/formal) - * Example triggers: "generate a report about...", "write a document on...", "add a section about budget", "make the report shorter", "rewrite in formal tone" - - WHEN NOT TO CALL THIS TOOL (answer in chat instead): - * Questions or discussion about the report: "What can we add?", "What's missing?", "Is the data accurate?", "How could this be improved?" - * Suggestions or brainstorming: "What other topics could be covered?", "What else could be added?", "What would make this better?" - * Asking for explanations: "Can you explain section 2?", "Why did you include that?", "What does this part mean?" - * Quick follow-ups or critiques: "Is the conclusion strong enough?", "Are there any gaps?", "What about the competitors?" - * THE TEST: Does the message contain a creation/modification VERB (from the list above) directed at producing or changing a deliverable? If NO verb → answer conversationally in chat. Do NOT assume the user wants a revision just because a report exists in the conversation. - - IMPORTANT FORMAT RULE: Reports are ALWAYS generated in Markdown. - - Args: - - topic: Short title for the report (max ~8 words). - - source_content: The text content to base the report on. - * For source_strategy="conversation" or "provided": Include a comprehensive summary of the relevant content. - * For source_strategy="kb_search": Can be empty or minimal — the tool handles searching internally. - * For source_strategy="auto": Include what you have; the tool searches KB if it's not enough. - - source_strategy: Controls how the tool collects source material. One of: - * "conversation" — The conversation already contains enough context (prior Q&A, discussion, pasted text, scraped pages). Pass a thorough summary as source_content. - * "kb_search" — The tool will search the knowledge base internally. Provide search_queries with 1-5 targeted queries. - * "auto" — Use source_content if sufficient, otherwise fall back to internal KB search using search_queries. - * "provided" — Use only what is in source_content (default, backward-compatible). - - search_queries: When source_strategy is "kb_search" or "auto", provide 1-5 specific search queries for the knowledge base. These should be precise, not just the topic name repeated. - - report_style: Controls report depth. Options: "detailed" (DEFAULT), "deep_research", "brief". - Use "brief" ONLY when the user explicitly asks for a short/concise/one-page report (e.g., "one page", "keep it short", "brief report", "500 words"). Default to "detailed" for all other requests. - - user_instructions: Optional specific instructions (e.g., "focus on financial impacts", "include recommendations"). When revising (parent_report_id set), describe WHAT TO CHANGE. If the user mentions a length preference (e.g., "one page", "500 words", "2 pages"), include that VERBATIM here AND set report_style="brief". - - parent_report_id: Set this to the report_id from a previous generate_report result when the user wants to MODIFY an existing report. Do NOT set it for new reports or questions about reports. - - Returns: A dictionary with status "ready" or "failed", report_id, title, and word_count. - - The report is generated immediately in Markdown and displayed inline in the chat. - - Export/download formats (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text) are produced from the generated Markdown report. - - SOURCE STRATEGY DECISION (HIGH PRIORITY — follow this exactly): - * If the conversation already has substantive Q&A / discussion on the topic → use source_strategy="conversation" with a comprehensive summary as source_content. - * If the user wants a report on a topic not yet discussed → use source_strategy="kb_search" with targeted search_queries. - * If you have some content but might need more → use source_strategy="auto" with both source_content and search_queries. - * When revising an existing report (parent_report_id set) and the conversation has relevant context → use source_strategy="conversation". The revision will use the previous report content plus your source_content. - * NEVER run a separate KB lookup step and then pass those results to generate_report. The tool handles KB search internally. - - AFTER CALLING THIS TOOL: Do NOT repeat, summarize, or reproduce the report content in the chat. The report is already displayed as an interactive card that the user can open, read, copy, and export. Simply confirm that the report was generated (e.g., "I've generated your report on [topic]. You can view the Markdown report now, and export it in various formats from the card."). NEVER write out the report text in the chat. -""" - -_TOOL_INSTRUCTIONS["generate_image"] = """ -- generate_image: Generate images from text descriptions using AI image models. - - Use this when the user asks you to create, generate, draw, design, or make an image. - - Trigger phrases: "generate an image of", "create a picture of", "draw me", "make an image", "design a logo", "create artwork" - - Args: - - prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood. - - n: Number of images to generate (1-4, default: 1) - - Returns: A dictionary with the generated image metadata. The image will automatically be displayed in the chat. - - IMPORTANT: Write a detailed, descriptive prompt for best results. Don't just pass the user's words verbatim - - expand and improve the prompt with specific details about style, lighting, composition, and mood. - - If the user's request is vague (e.g., "make me an image of a cat"), enhance the prompt with artistic details. -""" - -_TOOL_INSTRUCTIONS["scrape_webpage"] = """ -- scrape_webpage: Scrape and extract the main content from a webpage. - - Use this when the user wants you to READ and UNDERSTAND the actual content of a webpage. - - CRITICAL — WHEN TO USE (always attempt scraping, never refuse before trying): - * When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL - * When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices) - * When a URL was mentioned earlier in the conversation and the user asks for its actual content - * When `/documents/` knowledge-base data is insufficient and the user wants more - - Trigger scenarios: - * "Read this article and summarize it" - * "What does this page say about X?" - * "Summarize this blog post for me" - * "Tell me the key points from this article" - * "What's in this webpage?" - * "Can you analyze this article?" - * "Can you get the live table/data from [URL]?" - * "Scrape it" / "Can you scrape that?" (referring to a previously mentioned URL) - * "Fetch the content from [URL]" - * "Pull the data from that page" - - Args: - - url: The URL of the webpage to scrape (must be HTTP/HTTPS) - - max_length: Maximum content length to return (default: 50000 chars) - - Returns: The page title, description, full content (in markdown), word count, and metadata - - After scraping, provide a comprehensive, well-structured summary with key takeaways using headings or bullet points. - - Reference the source using markdown links [descriptive text](url) — never bare URLs. - - IMAGES: The scraped content may contain image URLs in markdown format like `![alt text](image_url)`. - * When you find relevant/important images in the scraped content, include them in your response using standard markdown image syntax: `![alt text](image_url)`. - * This makes your response more visual and engaging. - * Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content. - * Don't show every image - just the most relevant 1-3 images that enhance understanding. -""" - -_TOOL_INSTRUCTIONS["web_search"] = """ -- web_search: Search the web for real-time information using all configured search engines. - - Use this for current events, news, prices, weather, public facts, or any question requiring - up-to-date information from the internet. - - This tool dispatches to all configured search engines (SearXNG, Tavily, Linkup, Baidu) in - parallel and merges the results. - - IMPORTANT (REAL-TIME / PUBLIC WEB QUERIES): For questions that require current public web data - (e.g., live exchange rates, stock prices, breaking news, weather, current events), you MUST call - `web_search` instead of answering from memory. - - For these real-time/public web queries, DO NOT answer from memory and DO NOT say you lack internet - access before attempting a web search. - - If the search returns no relevant results, explain that web sources did not return enough - data and ask the user if they want you to retry with a refined query. - - Args: - - query: The search query - use specific, descriptive terms - - top_k: Number of results to retrieve (default: 10, max: 50) - - If search snippets are insufficient for the user's question, use `scrape_webpage` on the most relevant result URL for full content. - - When presenting results, reference sources as markdown links [descriptive text](url) — never bare URLs. -""" - -# Memory tool instructions have private and shared variants. -# We store them keyed as "update_memory" with sub-keys. -_MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = { - "update_memory": { - "private": """ -- update_memory: Update your personal memory document about the user. - - Your current memory is already in in your context. The `chars` and - `limit` attributes show your current usage and the maximum allowed size. - - This is your curated long-term memory — the distilled essence of what you know about - the user, not raw conversation logs. - - Call update_memory when: - * The user explicitly asks to remember or forget something - * The user shares durable facts or preferences that will matter in future conversations - - The user's first name is provided in . Use it in memory entries - instead of "the user" (e.g. "{name} works at..." not "The user works at..."). - Do not store the name itself as a separate memory entry. - - Do not store short-lived or ephemeral info: one-off questions, greetings, - session logistics, or things that only matter for the current task. - - Args: - - updated_memory: The FULL updated markdown document (not a diff). - Merge new facts with existing ones, update contradictions, remove outdated entries. - Treat every update as a curation pass — consolidate, don't just append. - - Every bullet MUST use this format: - (YYYY-MM-DD) [marker] text - Markers: - [fact] — durable facts (role, background, projects, tools, expertise) - [pref] — preferences (response style, languages, formats, tools) - [instr] — standing instructions (always/never do, response rules) - - Keep it concise and well under the character limit shown in . - - Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and - natural. Do NOT include the user's name in headings. Organize by context — e.g. - who they are, what they're focused on, how they prefer things. Create, split, or - merge headings freely as the memory grows. - - Each entry MUST be a single bullet point. Be descriptive but concise — include relevant - details and context rather than just a few words. - - During consolidation, prioritize keeping: [instr] > [pref] > [fact]. -""", - "shared": """ -- update_memory: Update the team's shared memory document for this search space. - - Your current team memory is already in in your context. The `chars` - and `limit` attributes show current usage and the maximum allowed size. - - This is the team's curated long-term memory — decisions, conventions, key facts. - - NEVER store personal memory in team memory (e.g. personal bio, individual - preferences, or user-only standing instructions). - - Call update_memory when: - * A team member explicitly asks to remember or forget something - * The conversation surfaces durable team decisions, conventions, or facts - that will matter in future conversations - - Do not store short-lived or ephemeral info: one-off questions, greetings, - session logistics, or things that only matter for the current task. - - Args: - - updated_memory: The FULL updated markdown document (not a diff). - Merge new facts with existing ones, update contradictions, remove outdated entries. - Treat every update as a curation pass — consolidate, don't just append. - - Every bullet MUST use this format: - (YYYY-MM-DD) [fact] text - Team memory uses ONLY the [fact] marker. Never use [pref] or [instr] in team memory. - - Keep it concise and well under the character limit shown in . - - Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and - natural. Organize by context — e.g. what the team decided, current architecture, - active processes. Create, split, or merge headings freely as the memory grows. - - Each entry MUST be a single bullet point. Be descriptive but concise — include relevant - details and context rather than just a few words. - - During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities. -""", - }, -} - -_MEMORY_TOOL_EXAMPLES: dict[str, dict[str, str]] = { - "update_memory": { - "private": """ -- Alex, is empty. User: "I'm a space enthusiast, explain astrophage to me" - - The user casually shared a durable fact. Use their first name in the entry, short neutral heading: - update_memory(updated_memory="## Interests & background\\n- (2025-03-15) [fact] Alex is a space enthusiast\\n") -- User: "Remember that I prefer concise answers over detailed explanations" - - Durable preference. Merge with existing memory, add a new heading: - update_memory(updated_memory="## Interests & background\\n- (2025-03-15) [fact] Alex is a space enthusiast\\n\\n## Response style\\n- (2025-03-15) [pref] Alex prefers concise answers over detailed explanations\\n") -- User: "I actually moved to Tokyo last month" - - Updated fact, date prefix reflects when recorded: - update_memory(updated_memory="## Interests & background\\n...\\n\\n## Personal context\\n- (2025-03-15) [fact] Alex lives in Tokyo (previously London)\\n...") -- User: "I'm a freelance photographer working on a nature documentary" - - Durable background info under a fitting heading: - update_memory(updated_memory="...\\n\\n## Current focus\\n- (2025-03-15) [fact] Alex is a freelance photographer\\n- (2025-03-15) [fact] Alex is working on a nature documentary\\n") -- User: "Always respond in bullet points" - - Standing instruction: - update_memory(updated_memory="...\\n\\n## Response style\\n- (2025-03-15) [instr] Always respond to Alex in bullet points\\n") -""", - "shared": """ -- User: "Let's remember that we decided to do weekly standup meetings on Mondays" - - Durable team decision: - update_memory(updated_memory="- (2025-03-15) [fact] Weekly standup meetings on Mondays\\n...") -- User: "Our office is in downtown Seattle, 5th floor" - - Durable team fact: - update_memory(updated_memory="- (2025-03-15) [fact] Office location: downtown Seattle, 5th floor\\n...") -""", - }, -} - -# Per-tool examples keyed by tool name. Only examples for enabled tools are included. -_TOOL_EXAMPLES: dict[str, str] = {} - -_TOOL_EXAMPLES["search_surfsense_docs"] = """ -- User: "How do I install SurfSense?" - - Call: `search_surfsense_docs(query="installation setup")` -- User: "What connectors does SurfSense support?" - - Call: `search_surfsense_docs(query="available connectors integrations")` -- User: "How do I set up the Notion connector?" - - Call: `search_surfsense_docs(query="Notion connector setup configuration")` -- User: "How do I use Docker to run SurfSense?" - - Call: `search_surfsense_docs(query="Docker installation setup")` -""" - -_TOOL_EXAMPLES["generate_podcast"] = """ -- User: "Give me a podcast about AI trends based on what we discussed" - - First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")` -- User: "Create a podcast summary of this conversation" - - Call: `generate_podcast(source_content="Complete conversation summary:\\n\\nUser asked about [topic 1]:\\n[Your detailed response]\\n\\nUser then asked about [topic 2]:\\n[Your detailed response]\\n\\n[Continue for all exchanges in the conversation]", podcast_title="Conversation Summary")` -- User: "Make a podcast about quantum computing" - - First explore `/documents/` (ls/glob/grep/read_file), then: `generate_podcast(source_content="Key insights about quantum computing from retrieved files:\\n\\n[Comprehensive summary of findings]", podcast_title="Quantum Computing Explained")` -""" - -_TOOL_EXAMPLES["generate_video_presentation"] = """ -- User: "Give me a presentation about AI trends based on what we discussed" - - First search for relevant content, then call: `generate_video_presentation(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", video_title="AI Trends Presentation")` -- User: "Create slides summarizing this conversation" - - Call: `generate_video_presentation(source_content="Complete conversation summary:\\n\\nUser asked about [topic 1]:\\n[Your detailed response]\\n\\nUser then asked about [topic 2]:\\n[Your detailed response]\\n\\n[Continue for all exchanges in the conversation]", video_title="Conversation Summary")` -- User: "Make a video presentation about quantum computing" - - First explore `/documents/` (ls/glob/grep/read_file), then: `generate_video_presentation(source_content="Key insights about quantum computing from retrieved files:\\n\\n[Comprehensive summary of findings]", video_title="Quantum Computing Explained")` -""" - -_TOOL_EXAMPLES["generate_report"] = """ -- User: "Generate a report about AI trends" - - Call: `generate_report(topic="AI Trends Report", source_strategy="kb_search", search_queries=["AI trends recent developments", "artificial intelligence industry trends", "AI market growth and predictions"], report_style="detailed")` - - WHY: Has creation verb "generate" → call the tool. No prior discussion → use kb_search. -- User: "Write a research report from this conversation" - - Call: `generate_report(topic="Research Report", source_strategy="conversation", source_content="Complete conversation summary:\\n\\n...", report_style="deep_research")` - - WHY: Has creation verb "write" → call the tool. Conversation has the content → use source_strategy="conversation". -- User: (after a report on Climate Change was generated) "Add a section about carbon capture technologies" - - Call: `generate_report(topic="Climate Crisis: Causes, Impacts, and Solutions", source_strategy="conversation", source_content="[summary of conversation context if any]", parent_report_id=, user_instructions="Add a new section about carbon capture technologies")` - - WHY: Has modification verb "add" + specific deliverable target → call the tool with parent_report_id. -- User: (after a report was generated) "What else could we add to have more depth?" - - Do NOT call generate_report. Answer in chat with suggestions. - - WHY: No creation/modification verb directed at producing a deliverable. -""" - -_TOOL_EXAMPLES["scrape_webpage"] = """ -- User: "Check out https://dev.to/some-article" - - Call: `scrape_webpage(url="https://dev.to/some-article")` - - Respond with a structured analysis — key points, takeaways. -- User: "Read this article and summarize it for me: https://example.com/blog/ai-trends" - - Call: `scrape_webpage(url="https://example.com/blog/ai-trends")` - - Respond with a thorough summary using headings and bullet points. -- User: (after discussing https://example.com/stats) "Can you get the live data from that page?" - - Call: `scrape_webpage(url="https://example.com/stats")` - - IMPORTANT: Always attempt scraping first. Never refuse before trying the tool. -- User: "https://example.com/blog/weekend-recipes" - - Call: `scrape_webpage(url="https://example.com/blog/weekend-recipes")` - - When a user sends just a URL with no instructions, scrape it and provide a concise summary of the content. -""" - -_TOOL_EXAMPLES["generate_image"] = """ -- User: "Generate an image of a cat" - - Call: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")` - - The generated image will automatically be displayed in the chat. -- User: "Draw me a logo for a coffee shop called Bean Dream" - - Call: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")` - - The generated image will automatically be displayed in the chat. -- User: "Show me this image: https://example.com/image.png" - - Simply include it in your response using markdown: `![Image](https://example.com/image.png)` -- User uploads an image file and asks: "What is this image about?" - - The user's uploaded image is already visible in the chat. - - Simply analyze the image content and respond directly. -""" - -_TOOL_EXAMPLES["web_search"] = """ -- User: "What's the current USD to INR exchange rate?" - - Call: `web_search(query="current USD to INR exchange rate")` - - Then answer using the returned web results with citations. -- User: "What's the latest news about AI?" - - Call: `web_search(query="latest AI news today")` -- User: "What's the weather in New York?" - - Call: `web_search(query="weather New York today")` -""" - -_TOOL_INSTRUCTIONS["generate_resume"] = """ -- generate_resume: Generate or revise a professional resume as a Typst document. - - WHEN TO CALL: The user asks to create, build, generate, write, or draft a resume or CV. - Also when they ask to modify, update, or revise an existing resume from this conversation. - - WHEN NOT TO CALL: General career advice, resume tips, cover letters, or reviewing - a resume without making changes. For cover letters, use generate_report instead. - - The tool produces Typst source code that is compiled to a PDF preview automatically. - - PAGE POLICY: - - Default behavior is ONE PAGE. For new resume creation, set max_pages=1 unless the user explicitly asks for more. - - If the user requests a longer resume (e.g., "make it 2 pages"), set max_pages to that value. - - Args: - - user_info: The user's resume content — work experience, education, skills, contact - info, etc. Can be structured or unstructured text. - CRITICAL: user_info must be COMPREHENSIVE. Do NOT just pass the user's raw message. - You MUST gather and consolidate ALL available information: - * Content from referenced/mentioned documents (e.g., uploaded resumes, CVs, LinkedIn profiles) - that appear in the conversation context — extract and include their FULL content. - * Information the user shared across multiple messages in the conversation. - * Any relevant details from knowledge base search results in the context. - The more complete the user_info, the better the resume. Include names, contact info, - work experience with dates, education, skills, projects, certifications — everything available. - - user_instructions: Optional style or content preferences (e.g. "emphasize leadership", - "keep it to one page"). For revisions, describe what to change. - - parent_report_id: Set this when the user wants to MODIFY an existing resume from - this conversation. Use the report_id from a previous generate_resume result. - - max_pages: Maximum resume length in pages (integer 1-5). Default is 1. - - Returns: Dict with status, report_id, title, and content_type. - - After calling: Give a brief confirmation. Do NOT paste resume content in chat. Do NOT mention report_id or any internal IDs — the resume card is shown automatically. - - VERSIONING: Same rules as generate_report — set parent_report_id for modifications - of an existing resume, leave as None for new resumes. -""" - -_TOOL_EXAMPLES["generate_resume"] = """ -- User: "Build me a resume. I'm John Doe, engineer at Acme Corp..." - - Call: `generate_resume(user_info="John Doe, engineer at Acme Corp...", max_pages=1)` - - WHY: Has creation verb "build" + resume → call the tool. -- User: "Create my CV with this info: [experience, education, skills]" - - Call: `generate_resume(user_info="[experience, education, skills]", max_pages=1)` -- User: "Build me a resume" (and there is a resume/CV document in the conversation context) - - Extract the FULL content from the document in context, then call: - `generate_resume(user_info="Name: John Doe\\nEmail: john@example.com\\n\\nExperience:\\n- Senior Engineer at Acme Corp (2020-2024)\\n Led team of 5...\\n\\nEducation:\\n- BS Computer Science, MIT (2016-2020)\\n\\nSkills: Python, TypeScript, AWS...", max_pages=1)` - - WHY: Document content is available in context — extract ALL of it into user_info. Do NOT ignore referenced documents. -- User: (after resume generated) "Change my title to Senior Engineer" - - Call: `generate_resume(user_info="", user_instructions="Change the job title to Senior Engineer", parent_report_id=, max_pages=1)` - - WHY: Modification verb "change" + refers to existing resume → set parent_report_id. -- User: (after resume generated) "Make this 2 pages and expand projects" - - Call: `generate_resume(user_info="", user_instructions="Expand projects and keep this to at most 2 pages", parent_report_id=, max_pages=2)` - - WHY: Explicit page increase request → set max_pages to 2. -- User: "How should I structure my resume?" - - Do NOT call generate_resume. Answer in chat with advice. - - WHY: No creation/modification verb. -""" - -# All tool names that have prompt instructions (order matters for prompt readability) -_ALL_TOOL_NAMES_ORDERED = [ - "search_surfsense_docs", - "web_search", - "generate_podcast", - "generate_video_presentation", - "generate_report", - "generate_resume", - "generate_image", - "scrape_webpage", - "update_memory", -] - - -def _format_tool_name(name: str) -> str: - """Convert snake_case tool name to a human-readable label.""" - return name.replace("_", " ").title() - - -def _get_tools_instructions( - thread_visibility: ChatVisibility | None = None, - enabled_tool_names: set[str] | None = None, - disabled_tool_names: set[str] | None = None, -) -> str: - """Build tools instructions containing only the enabled tools. - - Args: - thread_visibility: Private vs shared — affects memory tool wording. - enabled_tool_names: Set of tool names that are actually bound to the agent. - When None, all tools are included (backward-compatible default). - disabled_tool_names: Set of tool names that the user explicitly disabled. - When provided, a note is appended telling the model about these tools - so it can inform the user they can re-enable them. - """ - visibility = thread_visibility or ChatVisibility.PRIVATE - memory_variant = ( - "shared" if visibility == ChatVisibility.SEARCH_SPACE else "private" - ) - - parts: list[str] = [_TOOLS_PREAMBLE] - examples: list[str] = [] - - for tool_name in _ALL_TOOL_NAMES_ORDERED: - if enabled_tool_names is not None and tool_name not in enabled_tool_names: - continue - - if tool_name in _TOOL_INSTRUCTIONS: - parts.append(_TOOL_INSTRUCTIONS[tool_name]) - elif tool_name in _MEMORY_TOOL_INSTRUCTIONS: - parts.append(_MEMORY_TOOL_INSTRUCTIONS[tool_name][memory_variant]) - - if tool_name in _TOOL_EXAMPLES: - examples.append(_TOOL_EXAMPLES[tool_name]) - elif tool_name in _MEMORY_TOOL_EXAMPLES: - examples.append(_MEMORY_TOOL_EXAMPLES[tool_name][memory_variant]) - - # Append a note about user-disabled tools so the model can inform the user - known_disabled = ( - disabled_tool_names & set(_ALL_TOOL_NAMES_ORDERED) - if disabled_tool_names - else set() - ) - if known_disabled: - disabled_list = ", ".join( - _format_tool_name(n) for n in _ALL_TOOL_NAMES_ORDERED if n in known_disabled - ) - parts.append(f""" -DISABLED TOOLS (by user): -The following tools are available in SurfSense but have been disabled by the user for this session: {disabled_list}. -You do NOT have access to these tools and MUST NOT claim you can use them. -If the user asks about a capability provided by a disabled tool, let them know the relevant tool -is currently disabled and they can re-enable it. -""") - - parts.append("\n\n") - - if examples: - parts.append("") - parts.extend(examples) - parts.append("\n") - - return "".join(parts) - - -# Backward-compatible constant: all tools included (private memory variant) -SURFSENSE_TOOLS_INSTRUCTIONS = _get_tools_instructions() - - -SURFSENSE_CITATION_INSTRUCTIONS = """ - -CRITICAL CITATION REQUIREMENTS: - -1. For EVERY piece of information you include from the documents, add a citation in the format [citation:chunk_id] where chunk_id is the exact value from the `` tag inside ``. -2. Make sure ALL factual statements from the documents have proper citations. -3. If multiple chunks support the same point, include all relevant citations [citation:chunk_id1], [citation:chunk_id2]. -4. You MUST use the exact chunk_id values from the `` attributes. Do not create your own citation numbers. -5. Every citation MUST be in the format [citation:chunk_id] where chunk_id is the exact chunk id value. -6. Never modify or change the chunk_id - always use the original values exactly as provided in the chunk tags. -7. Do not return citations as clickable links. -8. Never format citations as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only. -9. Citations must ONLY appear as [citation:chunk_id] or [citation:chunk_id1], [citation:chunk_id2] format - never with parentheses, hyperlinks, or other formatting. -10. Never make up chunk IDs. Only use chunk_id values that are explicitly provided in the `` tags. -11. If you are unsure about a chunk_id, do not include a citation rather than guessing or making one up. - - -The documents you receive are structured like this: - -**Knowledge base documents (numeric chunk IDs):** - - - 42 - GITHUB_CONNECTOR - <![CDATA[Some repo / file / issue title]]> - - - - - - - - - - -**Web search results (URL chunk IDs):** - - - WEB_SEARCH - <![CDATA[Some web search result]]> - - - - - - - - -IMPORTANT: You MUST cite using the EXACT chunk ids from the `` tags. -- For knowledge base documents, chunk ids are numeric (e.g. 123, 124) or prefixed (e.g. doc-45). -- For live web search results, chunk ids are URLs (e.g. https://example.com/article). -Do NOT cite document_id. Always use the chunk id. - - - -- Every fact from the documents must have a citation in the format [citation:chunk_id] where chunk_id is the EXACT id value from a `` tag -- Citations should appear at the end of the sentence containing the information they support -- Multiple citations should be separated by commas: [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3] -- No need to return references section. Just citations in answer. -- NEVER create your own citation format - use the exact chunk_id values from the documents in the [citation:chunk_id] format -- NEVER format citations as clickable links or as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only -- NEVER make up chunk IDs if you are unsure about the chunk_id. It is better to omit the citation than to guess -- Copy the EXACT chunk id from the XML - if it says ``, use [citation:doc-123] -- If the chunk id is a URL like ``, use [citation:https://example.com/page] - - - -CORRECT citation formats: -- [citation:5] (numeric chunk ID from knowledge base) -- [citation:doc-123] (for Surfsense documentation chunks) -- [citation:https://example.com/article] (URL chunk ID from web search results) -- [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3] (multiple citations) - -INCORRECT citation formats (DO NOT use): -- Using parentheses and markdown links: ([citation:5](https://github.com/MODSetter/SurfSense)) -- Using parentheses around brackets: ([citation:5]) -- Using hyperlinked text: [link to source 5](https://example.com) -- Using footnote style: ... library¹ -- Making up source IDs when source_id is unknown -- Using old IEEE format: [1], [2], [3] -- Using source types instead of IDs: [citation:GITHUB_CONNECTOR] instead of [citation:5] - - - -Based on your GitHub repositories and video content, Python's asyncio library provides tools for writing concurrent code using the async/await syntax [citation:5]. It's particularly useful for I/O-bound and high-level structured network code [citation:5]. - -According to web search results, the key advantage of asyncio is that it can improve performance by allowing other code to run while waiting for I/O operations to complete [citation:https://docs.python.org/3/library/asyncio.html]. This makes it excellent for scenarios like web scraping, API calls, database operations, or any situation where your program spends time waiting for external resources. - -However, from your video learning, it's important to note that asyncio is not suitable for CPU-bound tasks as it runs on a single thread [citation:12]. For computationally intensive work, you'd want to use multiprocessing instead. - - -""" - -# Anti-citation prompt - used when citations are disabled -# This explicitly tells the model NOT to include citations -SURFSENSE_NO_CITATION_INSTRUCTIONS = """ - -IMPORTANT: Citations are DISABLED for this configuration. - -DO NOT include any citations in your responses. Specifically: -1. Do NOT use the [citation:chunk_id] format anywhere in your response. -2. Do NOT reference document IDs, chunk IDs, or source IDs. -3. Simply provide the information naturally without any citation markers. -4. Write your response as if you're having a normal conversation, incorporating the information from your knowledge seamlessly. - -When answering questions based on documents from the knowledge base: -- Present the information directly and confidently -- Do not mention that information comes from specific documents or chunks -- Integrate facts naturally into your response without attribution markers - -Your goal is to provide helpful, informative answers in a clean, readable format without any citation notation. - -""" - - -def _build_mcp_routing_block( - mcp_connector_tools: dict[str, list[str]] | None, -) -> str: - """Build an additional tool routing block for generic MCP connectors. - - When users add MCP servers (e.g. GitLab, GitHub), the LLM needs to know - those tools exist and should be called directly — not searched in the - knowledge base. - """ - if not mcp_connector_tools: - return "" - - lines = [ - "\n", - "You also have direct tools from these user-connected MCP servers.", - "Their data is NEVER in the knowledge base — call their tools directly.", - "", - ] - for server_name, tool_names in mcp_connector_tools.items(): - lines.append(f"- {server_name} → {', '.join(tool_names)}") - lines.append("\n") - return "\n".join(lines) +from .prompts.composer import ( + _read_fragment, + compose_system_prompt, + detect_provider_variant, +) + +# Public re-exports for backwards compatibility (some legacy code reads the +# raw default-instructions text directly). +SURFSENSE_SYSTEM_INSTRUCTIONS_TEMPLATE = ( + "\nDefault SurfSense agent system instructions are now\n" + "composed from prompts/base/*.md. See compose_system_prompt() for details.\n" + "" +) + +# Citation block re-exposed for legacy importers that referenced this constant +# directly. The composer is the canonical source; this is a frozen snapshot +# loaded at module-init time. +SURFSENSE_CITATION_INSTRUCTIONS = _read_fragment("base/citations_on.md") +SURFSENSE_NO_CITATION_INSTRUCTIONS = _read_fragment("base/citations_off.md") def build_surfsense_system_prompt( @@ -845,36 +46,23 @@ def build_surfsense_system_prompt( enabled_tool_names: set[str] | None = None, disabled_tool_names: set[str] | None = None, mcp_connector_tools: dict[str, list[str]] | None = None, + *, + model_name: str | None = None, ) -> str: + """Build the default SurfSense system prompt (citations on, defaults). + + See :func:`app.agents.new_chat.prompts.composer.compose_system_prompt` + for full parameter docs. """ - Build the SurfSense system prompt with default settings. - - This is a convenience function that builds the prompt with: - - Default system instructions - - Tools instructions (only for enabled tools) - - Citation instructions enabled - - Args: - today: Optional datetime for today's date (defaults to current UTC date) - thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None. - enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included. - disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user. - mcp_connector_tools: Mapping of MCP server display name → list of tool names - for generic MCP connectors. Injected into the system prompt so the LLM - knows to call these tools directly. - - Returns: - Complete system prompt string - """ - - visibility = thread_visibility or ChatVisibility.PRIVATE - system_instructions = _get_system_instructions(visibility, today) - system_instructions += _build_mcp_routing_block(mcp_connector_tools) - tools_instructions = _get_tools_instructions( - visibility, enabled_tool_names, disabled_tool_names + return compose_system_prompt( + today=today, + thread_visibility=thread_visibility, + enabled_tool_names=enabled_tool_names, + disabled_tool_names=disabled_tool_names, + mcp_connector_tools=mcp_connector_tools, + citations_enabled=True, + model_name=model_name, ) - citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS - return system_instructions + tools_instructions + citation_instructions def build_configurable_system_prompt( @@ -886,75 +74,54 @@ def build_configurable_system_prompt( enabled_tool_names: set[str] | None = None, disabled_tool_names: set[str] | None = None, mcp_connector_tools: dict[str, list[str]] | None = None, + *, + model_name: str | None = None, ) -> str: + """Build a configurable SurfSense system prompt (NewLLMConfig path). + + See :func:`app.agents.new_chat.prompts.composer.compose_system_prompt` + for full parameter docs. """ - Build a configurable SurfSense system prompt based on NewLLMConfig settings. - - The prompt is composed of three parts: - 1. System Instructions - either custom or default SURFSENSE_SYSTEM_INSTRUCTIONS - 2. Tools Instructions - only for enabled tools, with a note about disabled ones - 3. Citation Instructions - either SURFSENSE_CITATION_INSTRUCTIONS or SURFSENSE_NO_CITATION_INSTRUCTIONS - - Args: - custom_system_instructions: Custom system instructions to use. If empty/None and - use_default_system_instructions is True, defaults to - SURFSENSE_SYSTEM_INSTRUCTIONS. - use_default_system_instructions: Whether to use default instructions when - custom_system_instructions is empty/None. - citations_enabled: Whether to include citation instructions (True) or - anti-citation instructions (False). - today: Optional datetime for today's date (defaults to current UTC date) - thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None. - enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included. - disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user. - mcp_connector_tools: Mapping of MCP server display name → list of tool names - for generic MCP connectors. Injected into the system prompt so the LLM - knows to call these tools directly. - - Returns: - Complete system prompt string - """ - resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat() - - # Determine system instructions - if custom_system_instructions and custom_system_instructions.strip(): - system_instructions = custom_system_instructions.format( - resolved_today=resolved_today - ) - elif use_default_system_instructions: - visibility = thread_visibility or ChatVisibility.PRIVATE - system_instructions = _get_system_instructions(visibility, today) - else: - system_instructions = "" - - system_instructions += _build_mcp_routing_block(mcp_connector_tools) - - # Tools instructions: only include enabled tools, note disabled ones - tools_instructions = _get_tools_instructions( - thread_visibility, enabled_tool_names, disabled_tool_names + return compose_system_prompt( + today=today, + thread_visibility=thread_visibility, + enabled_tool_names=enabled_tool_names, + disabled_tool_names=disabled_tool_names, + mcp_connector_tools=mcp_connector_tools, + custom_system_instructions=custom_system_instructions, + use_default_system_instructions=use_default_system_instructions, + citations_enabled=citations_enabled, + model_name=model_name, ) - # Citation instructions based on toggle - citation_instructions = ( - SURFSENSE_CITATION_INSTRUCTIONS - if citations_enabled - else SURFSENSE_NO_CITATION_INSTRUCTIONS - ) - - return system_instructions + tools_instructions + citation_instructions - def get_default_system_instructions() -> str: + """Return the default ```` block (no tools / citations). + + Useful for populating the UI when seeding ``NewLLMConfig.system_instructions``. + The output reflects the current fragment tree, not a baked-in constant. """ - Get the default system instructions template. + resolved_today = datetime.now(UTC).date().isoformat() + from .prompts.composer import _build_system_instructions # local import - This is useful for populating the UI with the default value when - creating a new NewLLMConfig. - - Returns: - Default system instructions string (with {resolved_today} placeholder) - """ - return SURFSENSE_SYSTEM_INSTRUCTIONS.strip() + return _build_system_instructions( + visibility=ChatVisibility.PRIVATE, + resolved_today=resolved_today, + ).strip() +# Backwards compatibility — some modules import the constant directly. SURFSENSE_SYSTEM_PROMPT = build_surfsense_system_prompt() + + +__all__ = [ + "SURFSENSE_CITATION_INSTRUCTIONS", + "SURFSENSE_NO_CITATION_INSTRUCTIONS", + "SURFSENSE_SYSTEM_INSTRUCTIONS_TEMPLATE", + "SURFSENSE_SYSTEM_PROMPT", + "build_configurable_system_prompt", + "build_surfsense_system_prompt", + "compose_system_prompt", + "detect_provider_variant", + "get_default_system_instructions", +] diff --git a/surfsense_backend/app/agents/new_chat/tools/invalid_tool.py b/surfsense_backend/app/agents/new_chat/tools/invalid_tool.py new file mode 100644 index 000000000..df10fcbe3 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/invalid_tool.py @@ -0,0 +1,52 @@ +""" +The ``invalid`` fallback tool. + +When the model emits a tool call whose name doesn't match any registered +tool, :class:`ToolCallNameRepairMiddleware` rewrites the call to ``invalid`` +with the original name and a parser/validation error string. This tool's +execution then returns that error to the model so it can self-correct. + +Mirrors ``opencode/packages/opencode/src/tool/invalid.ts``. Tier 1.6 in +the OpenCode-port plan. + +Critically, the :class:`ToolDefinition` for this tool is **excluded** from +the system-prompt tool list and from ``LLMToolSelectorMiddleware`` selection +(see ``ToolDefinition.always_include`` filtering in the registry) — the +model never advertises ``invalid`` as a callable. It only ever shows up +in the tool registry so LangGraph can dispatch the rewritten call. +""" + +from __future__ import annotations + +from langchain_core.tools import tool + +INVALID_TOOL_NAME = "invalid" +INVALID_TOOL_DESCRIPTION = "Do not use" + + +def _format_invalid_message(tool: str | None, error: str | None) -> str: + """Return the user-visible error string. Mirrors ``invalid.ts``.""" + name = tool or "" + detail = error or "(no error message provided)" + return ( + f"The arguments provided to the tool `{name}` are invalid: {detail}\n" + f"Read the tool's docstring carefully and try again with valid arguments." + ) + + +@tool(name_or_callable=INVALID_TOOL_NAME, description=INVALID_TOOL_DESCRIPTION) +def invalid_tool(tool: str | None = None, error: str | None = None) -> str: + """Return a human-readable explanation of a tool-call validation failure. + + Activated only when :class:`ToolCallNameRepairMiddleware` rewrites a + failed tool call to ``invalid`` with the original tool name and the + error message produced during validation. + """ + return _format_invalid_message(tool, error) + + +__all__ = [ + "INVALID_TOOL_DESCRIPTION", + "INVALID_TOOL_NAME", + "invalid_tool", +] diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index 3ac8677b9..f5ee1a61d 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -43,6 +43,9 @@ from typing import Any from langchain_core.tools import BaseTool +from app.agents.new_chat.middleware.dedup_tool_calls import ( + wrap_dedup_key_by_arg_name, +) from app.db import ChatVisibility from .confluence import ( @@ -125,6 +128,14 @@ class ToolDefinition: enabled_by_default: Whether the tool is enabled when no explicit config is provided required_connector: Searchable type string (e.g. ``"LINEAR_CONNECTOR"``) that must be in ``available_connectors`` for the tool to be enabled. + dedup_key: Optional callable that maps a tool's ``args`` dict to a + string signature used by :class:`DedupHITLToolCallsMiddleware` + to drop duplicate calls. Replaces the legacy hardcoded + ``_NATIVE_HITL_TOOL_DEDUP_KEYS`` map (Tier 2.3 in the + OpenCode-port plan). + reverse: Optional callable that, given the tool's ``(args, result)``, + returns a ``ReverseDescriptor`` describing the inverse tool + invocation. Consumed by the snapshot/revert pipeline (Tier 5). """ @@ -135,6 +146,8 @@ class ToolDefinition: enabled_by_default: bool = True hidden: bool = False required_connector: str | None = None + dedup_key: Callable[[dict[str, Any]], str] | None = None + reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None # ============================================================================= @@ -288,6 +301,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="NOTION_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("title"), ), ToolDefinition( name="update_notion_page", @@ -299,6 +313,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="NOTION_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("page_title"), ), ToolDefinition( name="delete_notion_page", @@ -310,6 +325,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="NOTION_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("page_title"), ), # ========================================================================= # GOOGLE DRIVE TOOLS - create files, delete files @@ -325,6 +341,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_DRIVE_FILE", + dedup_key=wrap_dedup_key_by_arg_name("file_name"), ), ToolDefinition( name="delete_google_drive_file", @@ -336,6 +353,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_DRIVE_FILE", + dedup_key=wrap_dedup_key_by_arg_name("file_name"), ), # ========================================================================= # DROPBOX TOOLS - create and trash files @@ -351,6 +369,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="DROPBOX_FILE", + dedup_key=wrap_dedup_key_by_arg_name("file_name"), ), ToolDefinition( name="delete_dropbox_file", @@ -362,6 +381,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="DROPBOX_FILE", + dedup_key=wrap_dedup_key_by_arg_name("file_name"), ), # ========================================================================= # ONEDRIVE TOOLS - create and trash files @@ -377,6 +397,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="ONEDRIVE_FILE", + dedup_key=wrap_dedup_key_by_arg_name("file_name"), ), ToolDefinition( name="delete_onedrive_file", @@ -388,6 +409,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="ONEDRIVE_FILE", + dedup_key=wrap_dedup_key_by_arg_name("file_name"), ), # ========================================================================= # GOOGLE CALENDAR TOOLS - search, create, update, delete events @@ -414,6 +436,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_CALENDAR_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("title"), ), ToolDefinition( name="update_calendar_event", @@ -425,6 +448,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_CALENDAR_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("event_title_or_id"), ), ToolDefinition( name="delete_calendar_event", @@ -436,6 +460,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_CALENDAR_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("event_title_or_id"), ), # ========================================================================= # GMAIL TOOLS - search, read, create drafts, update drafts, send, trash @@ -473,6 +498,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_GMAIL_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("subject"), ), ToolDefinition( name="send_gmail_email", @@ -484,6 +510,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_GMAIL_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("subject"), ), ToolDefinition( name="trash_gmail_email", @@ -495,6 +522,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_GMAIL_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("email_subject_or_id"), ), ToolDefinition( name="update_gmail_draft", @@ -506,6 +534,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_GMAIL_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("draft_subject_or_id"), ), # ========================================================================= # CONFLUENCE TOOLS - create, update, delete pages @@ -521,6 +550,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="CONFLUENCE_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("title"), ), ToolDefinition( name="update_confluence_page", @@ -532,6 +562,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="CONFLUENCE_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("page_title_or_id"), ), ToolDefinition( name="delete_confluence_page", @@ -543,6 +574,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="CONFLUENCE_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("page_title_or_id"), ), # ========================================================================= # DISCORD TOOLS - list channels, read messages, send messages @@ -755,6 +787,24 @@ def build_tools( # Create the tool tool = tool_def.factory(dependencies) + # Propagate the registry-level metadata so middleware (e.g. + # ``DedupHITLToolCallsMiddleware``) and the action-log/revert + # pipeline can pick the resolvers up via ``tool.metadata`` without + # re-importing :data:`BUILTIN_TOOLS`. + if tool_def.dedup_key is not None or tool_def.reverse is not None: + existing_meta = getattr(tool, "metadata", None) or {} + merged_meta = dict(existing_meta) + if tool_def.dedup_key is not None: + merged_meta.setdefault("dedup_key", tool_def.dedup_key) + if tool_def.reverse is not None: + merged_meta.setdefault("reverse", tool_def.reverse) + try: + tool.metadata = merged_meta + except Exception: + logger.debug( + "Tool %s rejected metadata mutation; relying on registry lookup", + tool_def.name, + ) tools.append(tool) # Add any additional custom tools diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index e16590afc..fcd342d29 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -2250,6 +2250,202 @@ else: ) +class AgentActionLog(BaseModel): + """Append-only audit trail of every tool call dispatched by the agent. + + One row per ``ToolMessage`` produced; written by ``ActionLogMiddleware`` + in its ``aafter_tool`` hook. Rows are referenced by the + ``/api/threads/{thread_id}/revert/{action_id}`` route to look up an + action's stored ``reverse_descriptor`` and replay it. + + The table is intentionally narrow: large tool outputs are NOT stored + here. Result text lives in the langgraph checkpoint; this row only + keeps a short ``result_id`` (the LangChain ``ToolMessage.id`` or a + spilled-content path) for correlation. + """ + + __tablename__ = "agent_action_log" + + thread_id = Column( + Integer, + ForeignKey("new_chat_threads.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + user_id = Column( + UUID(as_uuid=True), + ForeignKey("user.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + search_space_id = Column( + Integer, + ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + turn_id = Column(String(64), nullable=True, index=True) + message_id = Column(String(128), nullable=True, index=True) + tool_name = Column(String(255), nullable=False, index=True) + args = Column(JSONB, nullable=True) + result_id = Column(String(255), nullable=True) + reversible = Column( + Boolean, nullable=False, default=False, server_default=text("false") + ) + reverse_descriptor = Column(JSONB, nullable=True) + error = Column(JSONB, nullable=True) + reverse_of = Column( + Integer, + ForeignKey("agent_action_log.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + created_at = Column( + TIMESTAMP(timezone=True), + nullable=False, + default=lambda: datetime.now(UTC), + server_default=text("(now() AT TIME ZONE 'utc')"), + index=True, + ) + + __table_args__ = ( + Index("ix_agent_action_log_thread_created", "thread_id", "created_at"), + ) + + +class DocumentRevision(BaseModel): + """Snapshot of a :class:`Document` row taken before a mutating tool call. + + Written by :class:`KnowledgeBasePersistenceMiddleware` (or its safety-net + `commit_staged_filesystem_state`) ahead of any NOTE / FILE / EXTENSION + document write. The row is referenced by ``/revert/{action_id}`` to + restore the original content in place. + """ + + __tablename__ = "document_revisions" + + document_id = Column( + Integer, + ForeignKey("documents.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + search_space_id = Column( + Integer, + ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + content_before = Column(Text, nullable=True) + title_before = Column(String, nullable=True) + folder_id_before = Column(Integer, nullable=True) + chunks_before = Column(JSONB, nullable=True) + metadata_before = Column("metadata_before", JSONB, nullable=True) + created_by_turn_id = Column(String(64), nullable=True, index=True) + agent_action_id = Column( + Integer, + ForeignKey("agent_action_log.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + created_at = Column( + TIMESTAMP(timezone=True), + nullable=False, + default=lambda: datetime.now(UTC), + server_default=text("(now() AT TIME ZONE 'utc')"), + index=True, + ) + + +class FolderRevision(BaseModel): + """Snapshot of a :class:`Folder` row taken before a mkdir / move.""" + + __tablename__ = "folder_revisions" + + folder_id = Column( + Integer, + ForeignKey("folders.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + search_space_id = Column( + Integer, + ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + name_before = Column(String(255), nullable=True) + parent_id_before = Column(Integer, nullable=True) + position_before = Column(String(50), nullable=True) + created_by_turn_id = Column(String(64), nullable=True, index=True) + agent_action_id = Column( + Integer, + ForeignKey("agent_action_log.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + created_at = Column( + TIMESTAMP(timezone=True), + nullable=False, + default=lambda: datetime.now(UTC), + server_default=text("(now() AT TIME ZONE 'utc')"), + index=True, + ) + + +class AgentPermissionRule(BaseModel): + """Persistent permission rule consumed by :class:`PermissionMiddleware`. + + Scoped at one of: search-space-wide (``user_id`` and ``thread_id`` NULL), + user-wide (``user_id`` set, ``thread_id`` NULL), or per-thread + (``thread_id`` set). Loaded at agent build time and converted to + :class:`Rule` instances inside the agent factory. + """ + + __tablename__ = "agent_permission_rules" + + search_space_id = Column( + Integer, + ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + user_id = Column( + UUID(as_uuid=True), + ForeignKey("user.id", ondelete="CASCADE"), + nullable=True, + index=True, + ) + thread_id = Column( + Integer, + ForeignKey("new_chat_threads.id", ondelete="CASCADE"), + nullable=True, + index=True, + ) + permission = Column(String(255), nullable=False) + pattern = Column(String(255), nullable=False, default="*", server_default="*") + action = Column(String(16), nullable=False) # allow / deny / ask + created_at = Column( + TIMESTAMP(timezone=True), + nullable=False, + default=lambda: datetime.now(UTC), + server_default=text("(now() AT TIME ZONE 'utc')"), + index=True, + ) + + __table_args__ = ( + UniqueConstraint( + "search_space_id", + "user_id", + "thread_id", + "permission", + "pattern", + "action", + name="uq_agent_permission_rules_scope", + ), + ) + + class RefreshToken(Base, TimestampMixin): """ Stores refresh tokens for user session management. diff --git a/surfsense_backend/app/observability/__init__.py b/surfsense_backend/app/observability/__init__.py new file mode 100644 index 000000000..dbf082561 --- /dev/null +++ b/surfsense_backend/app/observability/__init__.py @@ -0,0 +1,7 @@ +"""SurfSense observability surface. + +The single user-visible API right now is :mod:`otel`, which exposes a +small wrapper around the optional ``opentelemetry`` instrumentation. The +wrapper is a no-op when OTEL is not configured, so importing it from +performance-critical paths is safe. +""" diff --git a/surfsense_backend/app/observability/otel.py b/surfsense_backend/app/observability/otel.py new file mode 100644 index 000000000..0229524f2 --- /dev/null +++ b/surfsense_backend/app/observability/otel.py @@ -0,0 +1,319 @@ +""" +OpenTelemetry instrumentation helpers for the SurfSense agent stack. + +Tier 3b in the OpenCode-port plan. + +Goals +===== + +- Provide one tiny, ergonomic API for the spans listed in the plan + (``tool.call``, ``model.call``, ``kb.search``, ``kb.persist``, + ``compaction.run``, ``interrupt.raised``, ``permission.asked``). +- Keep span **names** low-cardinality (``tool.call`` rather than + ``tool.call.``); tool name lives in the ``tool.name`` attribute + so dashboards aggregate cleanly. +- Default to **no-op** behavior unless ``OTEL_EXPORTER_OTLP_ENDPOINT`` is + set, OR an external SDK has installed a real ``TracerProvider`` already + (e.g. via the ``opentelemetry-instrument`` agent). +- Coexist with LangSmith: we never disable LangSmith tracing; we add OTel + alongside. +- Gracefully degrade if the ``opentelemetry-api`` package is missing. +""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Any + +logger = logging.getLogger(__name__) + +# ----------------------------------------------------------------------------- +# Lazy/optional OpenTelemetry import +# ----------------------------------------------------------------------------- + +try: + from opentelemetry import trace as _ot_trace + from opentelemetry.trace import ( + Span as _OtSpan, + Status as _OtStatus, + StatusCode as _OtStatusCode, + ) + + _OTEL_AVAILABLE = True +except ImportError: # pragma: no cover — optional dep + _ot_trace = None # type: ignore[assignment] + _OtSpan = Any # type: ignore[assignment, misc] + _OtStatus = Any # type: ignore[assignment, misc] + _OtStatusCode = Any # type: ignore[assignment, misc] + _OTEL_AVAILABLE = False + + +_INSTRUMENTATION_NAME = "surfsense.new_chat" +_INSTRUMENTATION_VERSION = "0.1.0" + + +# ----------------------------------------------------------------------------- +# Configuration +# ----------------------------------------------------------------------------- + + +def _resolve_enabled() -> bool: + """Return True if OTel spans should actually be emitted.""" + if not _OTEL_AVAILABLE: + return False + # Honor an explicit kill-switch first. + if os.environ.get("SURFSENSE_DISABLE_OTEL", "").lower() in {"1", "true", "yes"}: + return False + # Treat a configured endpoint as the canonical "OTel is wired up" signal. + if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"): + return True + # Or honor an external SDK that already installed a non-default TracerProvider. + if _ot_trace is not None: + try: + provider = _ot_trace.get_tracer_provider() + # The default proxy provider has no real exporter wired up. + type_name = type(provider).__name__ + if type_name not in {"ProxyTracerProvider", "NoOpTracerProvider"}: + return True + except Exception: # pragma: no cover — defensive + return False + return False + + +_ENABLED: bool = _resolve_enabled() + + +def is_enabled() -> bool: + """Return True if instrumentation is actively emitting spans.""" + return _ENABLED + + +def _get_tracer(): + if not _OTEL_AVAILABLE: + return None + try: + return _ot_trace.get_tracer(_INSTRUMENTATION_NAME, _INSTRUMENTATION_VERSION) + except Exception: # pragma: no cover — defensive + return None + + +# ----------------------------------------------------------------------------- +# No-op span used when OTel is disabled (avoids a None check at every call site) +# ----------------------------------------------------------------------------- + + +class _NoopSpan: + """A lightweight stand-in that mimics the subset of ``Span`` we use.""" + + def set_attribute(self, key: str, value: Any) -> None: + return None + + def set_attributes(self, attributes: dict[str, Any]) -> None: + return None + + def add_event(self, name: str, attributes: dict[str, Any] | None = None) -> None: + return None + + def record_exception(self, exception: BaseException) -> None: + return None + + def set_status(self, status: Any) -> None: + return None + + +# ----------------------------------------------------------------------------- +# Public span helpers +# ----------------------------------------------------------------------------- + + +@contextmanager +def span( + name: str, + *, + attributes: dict[str, Any] | None = None, +) -> Iterator[Any]: + """Generic span context manager. + + Yields the underlying span (or a :class:`_NoopSpan` when disabled) + so callers can attach attributes/events incrementally. + + On exception, the span records the error via :meth:`record_exception` + and sets ``StatusCode.ERROR``; the exception is then re-raised. + """ + if not _ENABLED: + yield _NoopSpan() + return + + tracer = _get_tracer() + if tracer is None: # pragma: no cover — defensive + yield _NoopSpan() + return + + with tracer.start_as_current_span(name) as sp: + if attributes: + try: + sp.set_attributes(attributes) + except Exception: # pragma: no cover — defensive + pass + try: + yield sp + except BaseException as exc: + try: + sp.record_exception(exc) + sp.set_status(_OtStatus(_OtStatusCode.ERROR, str(exc))) + except Exception: # pragma: no cover — defensive + pass + raise + + +# ----------------------------------------------------------------------------- +# Domain-specific shortcuts (mirror the plan's enumerated span list) +# ----------------------------------------------------------------------------- + + +def tool_call_span( + tool_name: str, + *, + input_size: int | None = None, + extra: dict[str, Any] | None = None, +): + """Span for an individual tool execution. + + Span name is the constant ``tool.call`` (low-cardinality); the tool + identifier lives in the ``tool.name`` attribute. + """ + attrs: dict[str, Any] = {"tool.name": tool_name} + if input_size is not None: + attrs["tool.input.size"] = int(input_size) + if extra: + attrs.update(extra) + return span("tool.call", attributes=attrs) + + +def model_call_span( + *, + model_id: str | None = None, + provider: str | None = None, + extra: dict[str, Any] | None = None, +): + """Span around a single ``astream`` / ``ainvoke`` call to the LLM.""" + attrs: dict[str, Any] = {} + if model_id: + attrs["model.id"] = model_id + if provider: + attrs["model.provider"] = provider + if extra: + attrs.update(extra) + return span("model.call", attributes=attrs) + + +def kb_search_span( + *, + search_space_id: int | None = None, + query_chars: int | None = None, + extra: dict[str, Any] | None = None, +): + """Span around knowledge-base search routines.""" + attrs: dict[str, Any] = {} + if search_space_id is not None: + attrs["search_space.id"] = int(search_space_id) + if query_chars is not None: + attrs["query.chars"] = int(query_chars) + if extra: + attrs.update(extra) + return span("kb.search", attributes=attrs) + + +def kb_persist_span( + *, + document_type: str | None = None, + document_id: int | None = None, + extra: dict[str, Any] | None = None, +): + """Span around knowledge-base persistence operations (NOTE/EXTENSION/FILE).""" + attrs: dict[str, Any] = {} + if document_type: + attrs["document.type"] = document_type + if document_id is not None: + attrs["document.id"] = int(document_id) + if extra: + attrs.update(extra) + return span("kb.persist", attributes=attrs) + + +def compaction_span( + *, + reason: str | None = None, + messages_in: int | None = None, + extra: dict[str, Any] | None = None, +): + """Span around the compaction (summarization) middleware run.""" + attrs: dict[str, Any] = {} + if reason: + attrs["compaction.reason"] = reason + if messages_in is not None: + attrs["compaction.messages.in"] = int(messages_in) + if extra: + attrs.update(extra) + return span("compaction.run", attributes=attrs) + + +def interrupt_span( + *, + interrupt_type: str, + extra: dict[str, Any] | None = None, +): + """Span recording an interrupt being raised (HITL or permission_ask).""" + attrs: dict[str, Any] = {"interrupt.type": interrupt_type} + if extra: + attrs.update(extra) + return span("interrupt.raised", attributes=attrs) + + +def permission_asked_span( + *, + permission: str, + pattern: str | None = None, + extra: dict[str, Any] | None = None, +): + """Span recording a permission ask (PermissionMiddleware).""" + attrs: dict[str, Any] = {"permission.permission": permission} + if pattern: + attrs["permission.pattern"] = pattern + if extra: + attrs.update(extra) + return span("permission.asked", attributes=attrs) + + +# ----------------------------------------------------------------------------- +# Test/utility hooks +# ----------------------------------------------------------------------------- + + +def reload_for_tests() -> bool: + """Re-evaluate :data:`_ENABLED` from the current environment. + + Tests that toggle ``OTEL_EXPORTER_OTLP_ENDPOINT`` or + ``SURFSENSE_DISABLE_OTEL`` can call this to reset cached state. + Returns the new value of :func:`is_enabled`. + """ + global _ENABLED + _ENABLED = _resolve_enabled() + return _ENABLED + + +__all__ = [ + "compaction_span", + "interrupt_span", + "is_enabled", + "kb_persist_span", + "kb_search_span", + "model_call_span", + "permission_asked_span", + "reload_for_tests", + "span", + "tool_call_span", +] diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index de4e05423..a6a95ad30 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -1,5 +1,9 @@ from fastapi import APIRouter +from .agent_action_log_route import router as agent_action_log_router +from .agent_flags_route import router as agent_flags_router +from .agent_permissions_route import router as agent_permissions_router +from .agent_revert_route import router as agent_revert_router from .airtable_add_connector_route import ( router as airtable_add_connector_router, ) @@ -66,6 +70,12 @@ router.include_router(documents_router) router.include_router(folders_router) router.include_router(notes_router) router.include_router(new_chat_router) # Chat with assistant-ui persistence +router.include_router(agent_revert_router) # POST /threads/{id}/revert/{action_id} +router.include_router(agent_action_log_router) # GET /threads/{id}/actions +router.include_router( + agent_permissions_router +) # CRUD for /searchspaces/{id}/agent/permissions/rules +router.include_router(agent_flags_router) # GET /agent/flags router.include_router(sandbox_router) # Sandbox file downloads (Daytona) router.include_router(chat_comments_router) router.include_router(podcasts_router) # Podcast task status and audio diff --git a/surfsense_backend/app/routes/agent_action_log_route.py b/surfsense_backend/app/routes/agent_action_log_route.py new file mode 100644 index 000000000..458635761 --- /dev/null +++ b/surfsense_backend/app/routes/agent_action_log_route.py @@ -0,0 +1,186 @@ +"""``GET /api/threads/{thread_id}/actions``: list agent action-log entries. + +Pairs with ``POST /api/threads/{thread_id}/revert/{action_id}`` (see +``agent_revert_route.py``). The action log is the read-side surface for +the audit/undo UI: it returns a paginated list of every tool call +recorded by :class:`ActionLogMiddleware` against the thread, plus +metadata about whether the action is reversible and whether it has +already been reverted. + +The route is gated by the same ``SURFSENSE_ENABLE_ACTION_LOG`` flag that +controls the middleware. When the flag is off the endpoint returns 503 +so the UI can detect "this deployment doesn't have the action log +enabled" without 404-ing on a missing route. + +The list is ordered DESC by ``created_at`` (newest first) so the +revert UI can render a familiar reverse-chronological feed without an +additional client-side sort. +""" + +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.feature_flags import get_flags +from app.db import ( + AgentActionLog, + NewChatThread, + Permission, + User, + get_async_session, +) +from app.users import current_active_user +from app.utils.rbac import check_permission + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +# --------------------------------------------------------------------------- +# Response schemas +# --------------------------------------------------------------------------- + + +class AgentActionRead(BaseModel): + """One row of the action log surfaced to the client.""" + + id: int + thread_id: int + user_id: str | None + search_space_id: int + tool_name: str + args: dict[str, Any] | None + result_id: str | None + reversible: bool + reverse_descriptor: dict[str, Any] | None + error: dict[str, Any] | None + reverse_of: int | None + reverted_by_action_id: int | None + is_revert_action: bool + created_at: datetime + + +class AgentActionListResponse(BaseModel): + """Paginated list response for the action log.""" + + items: list[AgentActionRead] + total: int + page: int + page_size: int + has_more: bool + + +# --------------------------------------------------------------------------- +# Routes +# --------------------------------------------------------------------------- + + +def _flag_guard() -> None: + flags = get_flags() + if flags.disable_new_agent_stack or not flags.enable_action_log: + raise HTTPException( + status_code=503, + detail=( + "Action log is not available on this deployment. Flip " + "SURFSENSE_ENABLE_ACTION_LOG to enable it." + ), + ) + + +@router.get( + "/threads/{thread_id}/actions", + response_model=AgentActionListResponse, +) +async def list_thread_actions( + thread_id: int, + page: int = Query(0, ge=0), + page_size: int = Query(50, ge=1, le=200), + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> AgentActionListResponse: + """List agent actions for a thread, newest first. + + Authorization: + * Caller must be a member of the thread's search space with + ``CHATS_READ`` permission. + + Pagination: + * ``page`` is 0-indexed. + * ``page_size`` defaults to 50, max 200. + """ + + _flag_guard() + + thread = await session.get(NewChatThread, thread_id) + if thread is None: + raise HTTPException(status_code=404, detail="Thread not found.") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_READ.value, + "You don't have permission to view this thread's action log.", + ) + + total_stmt = select(func.count(AgentActionLog.id)).where( + AgentActionLog.thread_id == thread_id + ) + total = (await session.execute(total_stmt)).scalar_one() + + rows_stmt = ( + select(AgentActionLog) + .where(AgentActionLog.thread_id == thread_id) + .order_by(AgentActionLog.created_at.desc(), AgentActionLog.id.desc()) + .offset(page * page_size) + .limit(page_size) + ) + rows = (await session.execute(rows_stmt)).scalars().all() + + # Build a reverse_of -> revert_action_id map so the UI can render + # "Reverted" badges on actions that have already been undone. + if rows: + original_ids = [r.id for r in rows] + reverts_stmt = select(AgentActionLog.id, AgentActionLog.reverse_of).where( + AgentActionLog.reverse_of.in_(original_ids) + ) + reverts = (await session.execute(reverts_stmt)).all() + revert_map: dict[int, int] = {orig: rev for rev, orig in reverts} + else: + revert_map = {} + + items = [ + AgentActionRead( + id=row.id, + thread_id=row.thread_id, + user_id=str(row.user_id) if row.user_id is not None else None, + search_space_id=row.search_space_id, + tool_name=row.tool_name, + args=row.args, + result_id=row.result_id, + reversible=bool(row.reversible), + reverse_descriptor=row.reverse_descriptor, + error=row.error, + reverse_of=row.reverse_of, + reverted_by_action_id=revert_map.get(row.id), + is_revert_action=row.reverse_of is not None, + created_at=row.created_at, + ) + for row in rows + ] + + return AgentActionListResponse( + items=items, + total=int(total), + page=page, + page_size=page_size, + has_more=(page + 1) * page_size < int(total), + ) diff --git a/surfsense_backend/app/routes/agent_flags_route.py b/surfsense_backend/app/routes/agent_flags_route.py new file mode 100644 index 000000000..d3c90a58d --- /dev/null +++ b/surfsense_backend/app/routes/agent_flags_route.py @@ -0,0 +1,71 @@ +"""``GET /api/agent/flags``: read-only feature-flag status. + +Surfaces :class:`AgentFeatureFlags` to the frontend so the UI can: + +* Render conditional surfaces (e.g. show the action-log button only when + ``enable_action_log`` is on). +* Display an admin diagnostics card so operators can verify which + middleware tier is active without shelling into the box. + +The endpoint is *read-only*. Flipping flags requires an env-var change +plus a process restart — by design, since the values are baked into the +agent factory at build time. The route does not require any special +permission (any authenticated user can see them) since the flag values +do not leak data, and the UI surfaces are conditionally rendered based +on them anyway. +""" + +from __future__ import annotations + +from dataclasses import asdict + +from fastapi import APIRouter, Depends +from pydantic import BaseModel + +from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags +from app.db import User +from app.users import current_active_user + +router = APIRouter() + + +class AgentFeatureFlagsRead(BaseModel): + """Mirror of :class:`AgentFeatureFlags`. Updated together with it.""" + + disable_new_agent_stack: bool + + enable_context_editing: bool + enable_compaction_v2: bool + enable_retry_after: bool + enable_model_fallback: bool + enable_model_call_limit: bool + enable_tool_call_limit: bool + enable_tool_call_repair: bool + enable_doom_loop: bool + + enable_permission: bool + enable_busy_mutex: bool + enable_llm_tool_selector: bool + + enable_skills: bool + enable_specialized_subagents: bool + enable_kb_planner_runnable: bool + + enable_action_log: bool + enable_revert_route: bool + + enable_plugin_loader: bool + + enable_otel: bool + + @classmethod + def from_flags(cls, flags: AgentFeatureFlags) -> "AgentFeatureFlagsRead": + # asdict() avoids missing-field bugs when AgentFeatureFlags grows. + return cls(**asdict(flags)) + + +@router.get("/agent/flags", response_model=AgentFeatureFlagsRead) +async def get_agent_flags( + _user: User = Depends(current_active_user), +) -> AgentFeatureFlagsRead: + return AgentFeatureFlagsRead.from_flags(get_flags()) diff --git a/surfsense_backend/app/routes/agent_permissions_route.py b/surfsense_backend/app/routes/agent_permissions_route.py new file mode 100644 index 000000000..e87af29c7 --- /dev/null +++ b/surfsense_backend/app/routes/agent_permissions_route.py @@ -0,0 +1,280 @@ +"""CRUD for :class:`app.db.AgentPermissionRule`. + +Surfaces the permission rules consumed by +:class:`PermissionMiddleware`. Rules are scoped at one of three levels: + +* **Search-space wide** — both ``user_id`` and ``thread_id`` are NULL. +* **Per-user** — ``user_id`` set, ``thread_id`` NULL. +* **Per-thread** — ``thread_id`` set (``user_id`` typically NULL). + +The middleware reads these rows at agent build time (see +``chat_deepagent.py``). UI lets a search-space owner curate them so +the agent can ask for approval / auto-deny / auto-allow specific +tool patterns. + +The route group is gated by ``SURFSENSE_ENABLE_PERMISSION``: when off +all endpoints return 503 so the UI can render a "feature not enabled" +empty state without breaking on a missing route. +""" + +from __future__ import annotations + +import logging +import re +from datetime import datetime +from typing import Literal + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.feature_flags import get_flags +from app.db import ( + AgentPermissionRule, + NewChatThread, + Permission, + SearchSpace, + User, + get_async_session, +) +from app.users import current_active_user +from app.utils.rbac import check_permission + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +# --------------------------------------------------------------------------- +# Schemas +# --------------------------------------------------------------------------- + + +_ACTION_VALUES: tuple[str, ...] = ("allow", "deny", "ask") +_PERMISSION_PATTERN = re.compile(r"^[a-zA-Z0-9_:.\-*]+$") + + +class AgentPermissionRuleRead(BaseModel): + id: int + search_space_id: int + user_id: str | None + thread_id: int | None + permission: str + pattern: str + action: Literal["allow", "deny", "ask"] + created_at: datetime + + +class AgentPermissionRuleCreate(BaseModel): + permission: str = Field( + ..., + min_length=1, + max_length=255, + description="Tool / capability the rule targets, e.g. 'tool:create_linear_issue'.", + ) + pattern: str = Field( + "*", + min_length=1, + max_length=255, + description="Wildcard pattern (e.g. '*' or 'production-*') applied to the matched tool argument.", + ) + action: Literal["allow", "deny", "ask"] + user_id: str | None = None + thread_id: int | None = None + + +class AgentPermissionRuleUpdate(BaseModel): + pattern: str | None = Field(default=None, min_length=1, max_length=255) + action: Literal["allow", "deny", "ask"] | None = None + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _flag_guard() -> None: + flags = get_flags() + if flags.disable_new_agent_stack or not flags.enable_permission: + raise HTTPException( + status_code=503, + detail=( + "Agent permission rules are not enabled on this deployment. " + "Flip SURFSENSE_ENABLE_PERMISSION to enable them." + ), + ) + + +def _validate_permission_string(value: str) -> str: + if not _PERMISSION_PATTERN.match(value): + raise HTTPException( + status_code=400, + detail=( + "permission must contain only alphanumerics, '.', '_', ':', '-', " + "or '*' wildcards." + ), + ) + return value + + +def _to_read(row: AgentPermissionRule) -> AgentPermissionRuleRead: + return AgentPermissionRuleRead( + id=row.id, + search_space_id=row.search_space_id, + user_id=str(row.user_id) if row.user_id is not None else None, + thread_id=row.thread_id, + permission=row.permission, + pattern=row.pattern, + action=row.action, # type: ignore[arg-type] + created_at=row.created_at, + ) + + +async def _ensure_search_space_membership_admin( + session: AsyncSession, user: User, search_space_id: int +) -> None: + """Curating agent rules == "settings" administration on the space.""" + space = await session.get(SearchSpace, search_space_id) + if space is None: + raise HTTPException(status_code=404, detail="Search space not found.") + await check_permission( + session, + user, + search_space_id, + Permission.SETTINGS_UPDATE.value, + "You don't have permission to manage agent permission rules in this space.", + ) + + +# --------------------------------------------------------------------------- +# Routes +# --------------------------------------------------------------------------- + + +@router.get( + "/searchspaces/{search_space_id}/agent/permissions/rules", + response_model=list[AgentPermissionRuleRead], +) +async def list_rules( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> list[AgentPermissionRuleRead]: + _flag_guard() + await _ensure_search_space_membership_admin(session, user, search_space_id) + + stmt = ( + select(AgentPermissionRule) + .where(AgentPermissionRule.search_space_id == search_space_id) + .order_by(AgentPermissionRule.created_at.desc(), AgentPermissionRule.id.desc()) + ) + rows = (await session.execute(stmt)).scalars().all() + return [_to_read(r) for r in rows] + + +@router.post( + "/searchspaces/{search_space_id}/agent/permissions/rules", + response_model=AgentPermissionRuleRead, + status_code=201, +) +async def create_rule( + search_space_id: int, + payload: AgentPermissionRuleCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> AgentPermissionRuleRead: + _flag_guard() + await _ensure_search_space_membership_admin(session, user, search_space_id) + + permission = _validate_permission_string(payload.permission.strip()) + pattern = payload.pattern.strip() or "*" + + if payload.thread_id is not None: + thread = await session.get(NewChatThread, payload.thread_id) + if thread is None or thread.search_space_id != search_space_id: + raise HTTPException( + status_code=404, + detail="Thread not found in this search space.", + ) + + row = AgentPermissionRule( + search_space_id=search_space_id, + user_id=payload.user_id, + thread_id=payload.thread_id, + permission=permission, + pattern=pattern, + action=payload.action, + ) + session.add(row) + try: + await session.commit() + except IntegrityError: + await session.rollback() + raise HTTPException( + status_code=409, + detail=( + "An identical rule already exists for this scope. Update the " + "existing rule instead." + ), + ) + await session.refresh(row) + return _to_read(row) + + +@router.patch( + "/searchspaces/{search_space_id}/agent/permissions/rules/{rule_id}", + response_model=AgentPermissionRuleRead, +) +async def update_rule( + search_space_id: int, + rule_id: int, + payload: AgentPermissionRuleUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> AgentPermissionRuleRead: + _flag_guard() + await _ensure_search_space_membership_admin(session, user, search_space_id) + + row = await session.get(AgentPermissionRule, rule_id) + if row is None or row.search_space_id != search_space_id: + raise HTTPException(status_code=404, detail="Rule not found.") + + if payload.pattern is not None: + row.pattern = payload.pattern.strip() or "*" + if payload.action is not None: + row.action = payload.action + + try: + await session.commit() + except IntegrityError: + await session.rollback() + raise HTTPException( + status_code=409, + detail="Update would create a duplicate rule for this scope.", + ) + await session.refresh(row) + return _to_read(row) + + +@router.delete( + "/searchspaces/{search_space_id}/agent/permissions/rules/{rule_id}", + status_code=204, +) +async def delete_rule( + search_space_id: int, + rule_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> None: + _flag_guard() + await _ensure_search_space_membership_admin(session, user, search_space_id) + + row = await session.get(AgentPermissionRule, rule_id) + if row is None or row.search_space_id != search_space_id: + raise HTTPException(status_code=404, detail="Rule not found.") + + await session.delete(row) + await session.commit() + return None diff --git a/surfsense_backend/app/routes/agent_revert_route.py b/surfsense_backend/app/routes/agent_revert_route.py new file mode 100644 index 000000000..2f6fe6a32 --- /dev/null +++ b/surfsense_backend/app/routes/agent_revert_route.py @@ -0,0 +1,122 @@ +"""POST ``/api/threads/{thread_id}/revert/{action_id}``: undo an agent action. + +Per the Tier 5 plan, the route ships **before** the UI lights up the per-message +"Undo from here" affordance. To prevent accidental usage during the gap we +return ``503 Service Unavailable`` until the +``SURFSENSE_ENABLE_REVERT_ROUTE`` flag flips. Once enabled, the route runs: + +1. Authentication via :func:`current_active_user`. +2. Action lookup; 404 if the action does not belong to the thread. +3. Authorization via :func:`app.services.revert_service.can_revert`. +4. Revert dispatch via :func:`app.services.revert_service.revert_action`. +5. Idempotent on retries: if the same action is reverted twice the second + call returns 409 ``"already reverted"``. +""" + +from __future__ import annotations + +import logging + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.feature_flags import get_flags +from app.db import ( + AgentActionLog, + User, + get_async_session, +) +from app.services.revert_service import ( + RevertOutcome, + can_revert, + load_action, + load_thread, + revert_action, +) +from app.users import current_active_user + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post("/threads/{thread_id}/revert/{action_id}") +async def revert_agent_action( + thread_id: int, + action_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> dict: + flags = get_flags() + if flags.disable_new_agent_stack or not flags.enable_revert_route: + raise HTTPException( + status_code=503, + detail=( + "Revert is not available on this deployment yet. The route " + "ships before the UI; flip SURFSENSE_ENABLE_REVERT_ROUTE to " + "enable it." + ), + ) + + thread = await load_thread(session, thread_id=thread_id) + if thread is None: + raise HTTPException(status_code=404, detail="Thread not found.") + + action = await load_action(session, action_id=action_id, thread_id=thread_id) + if action is None: + raise HTTPException( + status_code=404, + detail="Action not found or does not belong to this thread.", + ) + + # Idempotency: if a successful revert already exists, return 409. + existing_revert = await session.execute( + select(AgentActionLog).where(AgentActionLog.reverse_of == action.id) + ) + if existing_revert.scalars().first() is not None: + raise HTTPException( + status_code=409, + detail="This action has already been reverted.", + ) + + if not can_revert( + requester_user_id=str(user.id) if user is not None else None, + action=action, + is_admin=False, # role lookup is done by RBAC layer; default conservative + ): + raise HTTPException( + status_code=403, + detail="You are not allowed to revert this action.", + ) + + outcome: RevertOutcome + try: + outcome = await revert_action( + session, + action=action, + requester_user_id=str(user.id) if user is not None else None, + ) + except Exception: + logger.exception("Revert dispatch raised for action_id=%s", action_id) + await session.rollback() + raise HTTPException(status_code=500, detail="Internal error during revert.") + + if outcome.status == "ok": + await session.commit() + return { + "status": "ok", + "message": outcome.message, + "new_action_id": outcome.new_action_id, + } + + await session.rollback() + + if outcome.status == "not_found" or outcome.status == "tool_unavailable": + raise HTTPException(status_code=409, detail=outcome.message) + if outcome.status == "permission_denied": + raise HTTPException(status_code=403, detail=outcome.message) + if outcome.status == "reverse_not_implemented": + raise HTTPException(status_code=501, detail=outcome.message) + # not_reversible + raise HTTPException(status_code=409, detail=outcome.message) diff --git a/surfsense_backend/app/services/revert_service.py b/surfsense_backend/app/services/revert_service.py new file mode 100644 index 000000000..e072f90c6 --- /dev/null +++ b/surfsense_backend/app/services/revert_service.py @@ -0,0 +1,279 @@ +"""Revert service for the SurfSense agent action log. + +Implements the actual revert workflow used by +``POST /api/threads/{thread_id}/revert/{action_id}``. The route handler is a +thin auth + flag wrapper around the functions defined here. + +Operation outcomes mirror the plan: + +* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from + :class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows + written before the original mutation. +* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke + the inverse tool through the agent's normal permission stack (NOT + bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``. +* **Anything else** (deprecated tool / no descriptor / schema drift): + returns ``NOT_REVERSIBLE`` and the route surfaces it as 409. + +A successful revert appends a NEW row to ``agent_action_log`` with +``reverse_of=`` and the requesting user's +``user_id``, preserving an auditable chain. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Literal + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import ( + AgentActionLog, + DocumentRevision, + FolderRevision, + NewChatThread, +) + +logger = logging.getLogger(__name__) + + +RevertOutcomeStatus = Literal[ + "ok", + "not_reversible", + "not_found", + "permission_denied", + "tool_unavailable", + "reverse_not_implemented", +] + + +@dataclass +class RevertOutcome: + """Structured result of :func:`revert_action`.""" + + status: RevertOutcomeStatus + message: str + new_action_id: int | None = None + + +# --------------------------------------------------------------------------- +# Lookup helpers +# --------------------------------------------------------------------------- + + +async def load_action( + session: AsyncSession, + *, + action_id: int, + thread_id: int, +) -> AgentActionLog | None: + """Load the action_log row for ``action_id`` if it belongs to the thread.""" + stmt = select(AgentActionLog).where( + AgentActionLog.id == action_id, + AgentActionLog.thread_id == thread_id, + ) + result = await session.execute(stmt) + return result.scalars().first() + + +async def load_thread( + session: AsyncSession, *, thread_id: int +) -> NewChatThread | None: + stmt = select(NewChatThread).where(NewChatThread.id == thread_id) + result = await session.execute(stmt) + return result.scalars().first() + + +# --------------------------------------------------------------------------- +# Authorization +# --------------------------------------------------------------------------- + + +def can_revert( + *, + requester_user_id: str | None, + action: AgentActionLog, + is_admin: bool, +) -> bool: + """Return True iff the requester is allowed to revert this action. + + The plan's rule: "requester must be the original `user_id` on the + action, or hold the search-space admin role." Anonymous actions + (``action.user_id is None``) can only be reverted by admins. + """ + if is_admin: + return True + if action.user_id is None: + return False + return str(action.user_id) == str(requester_user_id) + + +# --------------------------------------------------------------------------- +# Revert paths +# --------------------------------------------------------------------------- + + +async def _restore_document_revision( + session: AsyncSession, *, action: AgentActionLog +) -> RevertOutcome: + """Restore the most recent :class:`DocumentRevision` for ``action``.""" + stmt = ( + select(DocumentRevision) + .where(DocumentRevision.agent_action_id == action.id) + .order_by(DocumentRevision.created_at.desc()) + .limit(1) + ) + result = await session.execute(stmt) + revision = result.scalars().first() + if revision is None: + return RevertOutcome( + status="not_reversible", + message="No document_revisions row tied to this action.", + ) + + from app.db import Document # late import to avoid cycles at module load + + doc = await session.get(Document, revision.document_id) + if doc is None: + return RevertOutcome( + status="tool_unavailable", + message="Original document has been deleted; revert cannot proceed.", + ) + + if revision.content_before is not None: + doc.content = revision.content_before + if revision.title_before is not None: + doc.title = revision.title_before + if revision.folder_id_before is not None: + doc.folder_id = revision.folder_id_before + doc.updated_at = datetime.now(UTC) + return RevertOutcome(status="ok", message="Document restored from snapshot.") + + +async def _restore_folder_revision( + session: AsyncSession, *, action: AgentActionLog +) -> RevertOutcome: + stmt = ( + select(FolderRevision) + .where(FolderRevision.agent_action_id == action.id) + .order_by(FolderRevision.created_at.desc()) + .limit(1) + ) + result = await session.execute(stmt) + revision = result.scalars().first() + if revision is None: + return RevertOutcome( + status="not_reversible", + message="No folder_revisions row tied to this action.", + ) + + from app.db import Folder + + folder = await session.get(Folder, revision.folder_id) + if folder is None: + return RevertOutcome( + status="tool_unavailable", + message="Original folder has been deleted; revert cannot proceed.", + ) + + if revision.name_before is not None: + folder.name = revision.name_before + if revision.parent_id_before is not None: + folder.parent_id = revision.parent_id_before + if revision.position_before is not None: + folder.position = revision.position_before + folder.updated_at = datetime.now(UTC) + return RevertOutcome(status="ok", message="Folder restored from snapshot.") + + +# Tool-name prefixes that route to KB document / folder revert paths. Kept +# as data so a future PR adding new KB-owned tools doesn't have to touch +# this module's control flow. +_DOC_TOOL_PREFIXES: tuple[str, ...] = ( + "edit_file", + "write_file", + "update_memory", + "create_note", + "update_note", + "delete_note", +) +_FOLDER_TOOL_PREFIXES: tuple[str, ...] = ( + "mkdir", + "move_file", + "rename_folder", + "delete_folder", +) + + +async def revert_action( + session: AsyncSession, + *, + action: AgentActionLog, + requester_user_id: str | None, +) -> RevertOutcome: + """Execute the revert for ``action`` and return a structured outcome. + + The function does **not** commit — the caller is expected to commit on + success or roll back on failure. A new ``agent_action_log`` row is + added to the session on success with ``reverse_of=action.id``. + """ + tool_name = (action.tool_name or "").lower() + + if tool_name.startswith(_DOC_TOOL_PREFIXES): + outcome = await _restore_document_revision(session, action=action) + elif tool_name.startswith(_FOLDER_TOOL_PREFIXES): + outcome = await _restore_folder_revision(session, action=action) + elif action.reverse_descriptor: + # Connector-owned reversibles run through the normal permission + # stack; out of scope for this PR — the route returns 503 anyway + # until UI ships, so 501-style "not implemented" is fine. + return RevertOutcome( + status="reverse_not_implemented", + message=( + "Connector-action revert is not yet implemented. The " + "reverse_descriptor is stored; future work will replay it " + "through PermissionMiddleware." + ), + ) + else: + return RevertOutcome( + status="not_reversible", + message=( + f"Tool {action.tool_name!r} is not reversible: no document " + "revision and no reverse_descriptor." + ), + ) + + if outcome.status != "ok": + return outcome + + new_row = AgentActionLog( + thread_id=action.thread_id, + user_id=requester_user_id, + search_space_id=action.search_space_id, + turn_id=None, + message_id=None, + tool_name=f"_revert:{action.tool_name}", + args={"reverted_action_id": action.id}, + result_id=None, + reversible=False, + reverse_descriptor=None, + error=None, + reverse_of=action.id, + ) + session.add(new_row) + await session.flush() + outcome.new_action_id = new_row.id + return outcome + + +__all__ = [ + "RevertOutcome", + "can_revert", + "load_action", + "load_thread", + "revert_action", +] diff --git a/surfsense_backend/app/utils/async_retry.py b/surfsense_backend/app/utils/async_retry.py index a56f6550a..607b7a156 100644 --- a/surfsense_backend/app/utils/async_retry.py +++ b/surfsense_backend/app/utils/async_retry.py @@ -33,7 +33,7 @@ F = TypeVar("F", bound=Callable) def _is_retryable(exc: BaseException) -> bool: if isinstance(exc, ConnectorError): return exc.retryable - return bool(isinstance(exc, (httpx.TimeoutException, httpx.ConnectError))) + return bool(isinstance(exc, httpx.TimeoutException | httpx.ConnectError)) def build_retry( diff --git a/surfsense_backend/tests/integration/harness/__init__.py b/surfsense_backend/tests/integration/harness/__init__.py new file mode 100644 index 000000000..9a7ec07dc --- /dev/null +++ b/surfsense_backend/tests/integration/harness/__init__.py @@ -0,0 +1,146 @@ +""" +Integration test harness for the SurfSense agent stack. + +The plan calls for an ``LLMToolEmulator``-backed harness for end-to-end +replay of ``stream_new_chat``. The currently-installed langchain version +does not expose ``LLMToolEmulator``, so this harness builds the equivalent +on top of :class:`langchain_core.language_models.fake_chat_models.FakeMessagesListChatModel`. + +The harness lets a test author script a sequence of model responses +(text + optional tool calls) and replay them against the new_chat agent +graph. Tools are stubbed via ``StubToolSpec`` -> ``langchain_core.tools.tool`` +decorator and execute deterministic Python callbacks. + +Used by: +- ``tests/integration/agents/new_chat/test_feature_flag_smoke.py`` to + confirm the kill-switch path produces identical-shape output regardless + of which middleware flags are toggled. +- Future per-tier PRs to record golden transcripts. +""" + +from __future__ import annotations + +import uuid +from collections.abc import Callable, Sequence +from dataclasses import dataclass, field +from typing import Any + +from langchain_core.language_models import LanguageModelInput +from langchain_core.language_models.fake_chat_models import ( + FakeMessagesListChatModel, +) +from langchain_core.messages import AIMessage, BaseMessage +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool, tool + + +class _ToolBindingFakeChatModel(FakeMessagesListChatModel): + """Adapter so the harness model can pretend it understands ``bind_tools``. + + The base ``FakeMessagesListChatModel`` raises ``NotImplementedError`` from + ``bind_tools``, but ``langchain.agents.create_agent`` always calls + ``bind_tools`` to attach the tool registry. We don't actually need the + fake to honor the tool schema — it's already scripted to emit the right + tool calls — so we return self. + """ + + def bind_tools( # type: ignore[override] + self, + tools: Sequence[Any], + *, + tool_choice: Any = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, AIMessage]: + return self + + +@dataclass +class StubToolSpec: + """A test-mode tool: a name, description, and a deterministic body.""" + + name: str + description: str + handler: Callable[..., Any] + args_schema: dict[str, Any] | None = None + + def build(self) -> BaseTool: + """Realize as a `langchain_core.tools.BaseTool`.""" + + @tool(name_or_callable=self.name, description=self.description) + def _stub_tool(**kwargs: Any) -> Any: + return self.handler(**kwargs) + + return _stub_tool + + +@dataclass +class ScriptedTurn: + """One scripted assistant turn. + + `text` is the assistant text (may be empty if pure tool call). + `tool_calls` is a list of dicts ``{name, args, id}``; if non-empty, the + agent will route to those tools and append a follow-up turn. + """ + + text: str = "" + tool_calls: list[dict[str, Any]] = field(default_factory=list) + + +def build_scripted_messages(turns: list[ScriptedTurn]) -> list[BaseMessage]: + """Convert :class:`ScriptedTurn` records to AIMessage payloads.""" + out: list[BaseMessage] = [] + for turn in turns: + tool_calls: list[dict[str, Any]] = [] + for tc in turn.tool_calls: + tool_calls.append( + { + "name": tc["name"], + "args": tc.get("args", {}), + "id": tc.get("id") or f"call_{uuid.uuid4().hex[:8]}", + } + ) + out.append(AIMessage(content=turn.text, tool_calls=tool_calls or [])) + return out + + +@dataclass +class ScriptedHarness: + """Bundle of (model, tools) ready to plug into ``create_agent``.""" + + model: _ToolBindingFakeChatModel + tools: list[BaseTool] + + +def build_scripted_harness( + *, + turns: list[ScriptedTurn], + tools: list[StubToolSpec] | None = None, + sleep: float | None = None, +) -> ScriptedHarness: + """Construct a deterministic agent harness from a script. + + Example:: + + harness = build_scripted_harness( + turns=[ + ScriptedTurn(tool_calls=[{"name": "echo", "args": {"x": 1}}]), + ScriptedTurn(text="done"), + ], + tools=[ + StubToolSpec(name="echo", description="echo args", handler=lambda **kw: kw), + ], + ) + """ + messages = build_scripted_messages(turns) + model = _ToolBindingFakeChatModel(responses=messages, sleep=sleep) + realized_tools = [t.build() for t in (tools or [])] + return ScriptedHarness(model=model, tools=realized_tools) + + +__all__ = [ + "ScriptedHarness", + "ScriptedTurn", + "StubToolSpec", + "build_scripted_harness", + "build_scripted_messages", +] diff --git a/surfsense_backend/tests/integration/harness/test_scripted_harness.py b/surfsense_backend/tests/integration/harness/test_scripted_harness.py new file mode 100644 index 000000000..6e9f7ab91 --- /dev/null +++ b/surfsense_backend/tests/integration/harness/test_scripted_harness.py @@ -0,0 +1,53 @@ +"""Smoke test: scripted harness drives create_agent end-to-end and produces a tool-call-then-final-text trace.""" + +from __future__ import annotations + +import pytest +from langchain.agents import create_agent + +from tests.integration.harness import ( + ScriptedTurn, + StubToolSpec, + build_scripted_harness, +) + +pytestmark = pytest.mark.integration + + +@pytest.mark.asyncio +async def test_scripted_harness_drives_basic_agent() -> None: + harness = build_scripted_harness( + turns=[ + ScriptedTurn( + tool_calls=[ + {"name": "echo", "args": {"x": 1}, "id": "call_1"}, + ] + ), + ScriptedTurn(text="done"), + ], + tools=[ + StubToolSpec( + name="echo", + description="Echo args back.", + handler=lambda **kwargs: {"echoed": kwargs}, + ), + ], + ) + + agent = create_agent( + harness.model, + system_prompt="You are a test agent.", + tools=harness.tools, + ) + + result = await agent.ainvoke({"messages": [("user", "do the thing")]}) + messages = result["messages"] + final_ai = next( + (m for m in reversed(messages) if m.__class__.__name__ == "AIMessage"), + None, + ) + assert final_ai is not None + assert final_ai.content == "done" + tool_messages = [m for m in messages if m.__class__.__name__ == "ToolMessage"] + assert len(tool_messages) == 1 + assert "echoed" in str(tool_messages[0].content) diff --git a/surfsense_backend/tests/unit/agents/__init__.py b/surfsense_backend/tests/unit/agents/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/tests/unit/agents/new_chat/__init__.py b/surfsense_backend/tests/unit/agents/new_chat/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/tests/unit/agents/new_chat/prompts/__init__.py b/surfsense_backend/tests/unit/agents/new_chat/prompts/__init__.py new file mode 100644 index 000000000..a92d371bd --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/prompts/__init__.py @@ -0,0 +1 @@ +"""__init__ stub so pytest discovers the prompts test module.""" diff --git a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py new file mode 100644 index 000000000..d35b7aa8b --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py @@ -0,0 +1,201 @@ +"""Tests for the prompt fragment composer (Tier 3a).""" + +from __future__ import annotations + +from datetime import UTC, datetime + +import pytest + +from app.agents.new_chat.prompts.composer import ( + ALL_TOOL_NAMES_ORDERED, + compose_system_prompt, + detect_provider_variant, +) +from app.db import ChatVisibility + +pytestmark = pytest.mark.unit + + +@pytest.fixture +def fixed_today() -> datetime: + return datetime(2025, 6, 1, 12, 0, tzinfo=UTC) + + +class TestProviderVariantDetection: + @pytest.mark.parametrize( + "model_name,expected", + [ + ("openai:gpt-4o-mini", "openai_classic"), + ("openai:gpt-4-turbo", "openai_classic"), + ("openai:gpt-5", "openai_reasoning"), + ("openai:gpt-5-codex", "openai_reasoning"), + ("openai:o1-preview", "openai_reasoning"), + ("openai:o3-mini", "openai_reasoning"), + ("anthropic:claude-3-5-sonnet", "anthropic"), + ("anthropic/claude-opus-4", "anthropic"), + ("google:gemini-2.0-flash", "google"), + ("vertex:gemini-1.5-pro", "google"), + ("groq:mixtral-8x7b", "default"), + (None, "default"), + ("", "default"), + ], + ) + def test_detection(self, model_name: str | None, expected: str) -> None: + assert detect_provider_variant(model_name) == expected + + +class TestCompose: + def test_default_prompt_has_required_blocks(self, fixed_today: datetime) -> None: + prompt = compose_system_prompt(today=fixed_today) + # System instruction wrapper + assert "" in prompt + assert "" in prompt + # Date interpolated + assert "2025-06-01" in prompt + # Core policy blocks present + assert "" in prompt + assert "" in prompt + assert "" in prompt + assert "" in prompt + # Tools + assert "" in prompt + assert "" in prompt + # Citations on by default + assert "" in prompt + assert "[citation:chunk_id]" in prompt + + def test_team_visibility_uses_team_variants( + self, fixed_today: datetime + ) -> None: + prompt = compose_system_prompt( + today=fixed_today, + thread_visibility=ChatVisibility.SEARCH_SPACE, + ) + # Team-specific phrasing in the agent block + assert "team space" in prompt + # Memory protocol mentions team + assert "team" in prompt + # Should NOT mention the user-only memory phrasing + assert "personal knowledge base" not in prompt + + def test_private_visibility_uses_private_variants( + self, fixed_today: datetime + ) -> None: + prompt = compose_system_prompt( + today=fixed_today, + thread_visibility=ChatVisibility.PRIVATE, + ) + assert "personal knowledge base" in prompt + # Should NOT mention the team-specific phrasing about prefixed authors + assert "[DisplayName of the author]" not in prompt + + def test_citations_disabled_swaps_block(self, fixed_today: datetime) -> None: + prompt_on = compose_system_prompt(today=fixed_today, citations_enabled=True) + prompt_off = compose_system_prompt(today=fixed_today, citations_enabled=False) + assert "Citations are DISABLED" in prompt_off + assert "Citations are DISABLED" not in prompt_on + assert "[citation:chunk_id]" in prompt_on + + def test_enabled_tool_filter_only_includes_listed_tools( + self, fixed_today: datetime + ) -> None: + prompt = compose_system_prompt( + today=fixed_today, + enabled_tool_names={"web_search", "scrape_webpage"}, + ) + assert "web_search:" in prompt or "- web_search:" in prompt + assert "scrape_webpage:" in prompt or "- scrape_webpage:" in prompt + # Excluded tools should NOT appear in tool listing + assert "generate_podcast:" not in prompt + assert "generate_image:" not in prompt + + def test_disabled_tool_note_is_appended(self, fixed_today: datetime) -> None: + prompt = compose_system_prompt( + today=fixed_today, + enabled_tool_names={"web_search"}, + disabled_tool_names={"generate_image", "generate_podcast"}, + ) + assert "DISABLED TOOLS (by user):" in prompt + assert "Generate Image" in prompt + assert "Generate Podcast" in prompt + + def test_mcp_routing_block_emits_when_provided( + self, fixed_today: datetime + ) -> None: + prompt = compose_system_prompt( + today=fixed_today, + mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]}, + ) + assert "" in prompt + assert "My GitLab" in prompt + assert "gitlab_search" in prompt + + def test_mcp_routing_block_absent_when_no_servers( + self, fixed_today: datetime + ) -> None: + prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={}) + assert "" not in prompt + + def test_provider_block_renders_when_anthropic( + self, fixed_today: datetime + ) -> None: + prompt = compose_system_prompt( + today=fixed_today, model_name="anthropic:claude-3-5-sonnet" + ) + assert "" in prompt + assert "Anthropic" in prompt or "Claude" in prompt + + def test_provider_block_absent_for_default(self, fixed_today: datetime) -> None: + prompt = compose_system_prompt(today=fixed_today, model_name="custom:foo") + assert "" not in prompt + + def test_custom_system_instructions_override_default( + self, fixed_today: datetime + ) -> None: + custom = "You are a custom assistant. Today is {resolved_today}." + prompt = compose_system_prompt( + today=fixed_today, custom_system_instructions=custom + ) + assert "You are a custom assistant. Today is 2025-06-01." in prompt + # Default block should NOT be present + assert "" not in prompt + + def test_use_default_false_with_no_custom_yields_no_system_block( + self, fixed_today: datetime + ) -> None: + prompt = compose_system_prompt( + today=fixed_today, + use_default_system_instructions=False, + ) + # No system_instruction wrapper but tools/citations still emitted + assert "" not in prompt + assert "" in prompt + + def test_all_known_tools_have_fragments(self) -> None: + # Soft assertion: verify that every tool in the canonical order + # produces non-empty content for at least one variant. + for tool in ALL_TOOL_NAMES_ORDERED: + prompt = compose_system_prompt( + today=datetime(2025, 1, 1, tzinfo=UTC), + enabled_tool_names={tool}, + ) + assert tool in prompt, f"tool {tool!r} missing from composed prompt" + + +class TestStableOrderingForCacheStability: + """Regression guard: prompt cache hit-rate depends on byte-stable prefix.""" + + def test_composition_is_deterministic_given_same_inputs( + self, fixed_today: datetime + ) -> None: + a = compose_system_prompt( + today=fixed_today, + enabled_tool_names={"web_search", "scrape_webpage"}, + mcp_connector_tools={"X": ["x_a", "x_b"]}, + ) + b = compose_system_prompt( + today=fixed_today, + enabled_tool_names={"scrape_webpage", "web_search"}, # set order shouldn't matter + mcp_connector_tools={"X": ["x_a", "x_b"]}, + ) + assert a == b diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py b/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py new file mode 100644 index 000000000..6834b5be7 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py @@ -0,0 +1,311 @@ +"""Unit tests for ActionLogMiddleware (Tier 5.2).""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest +from langchain_core.messages import ToolMessage +from langchain_core.tools import tool + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.middleware.action_log import ActionLogMiddleware +from app.agents.new_chat.tools.registry import ToolDefinition + + +@dataclass +class _FakeRequest: + """Minimal stand-in for ToolCallRequest used in unit tests.""" + + tool_call: dict[str, Any] + tool: Any = None + state: Any = None + runtime: Any = None + + +@tool +def make_widget(color: str, size: int) -> str: + """Create a widget.""" + return f"made {color} {size}" + + +def _enabled_flags(**overrides: bool) -> AgentFeatureFlags: + return AgentFeatureFlags( + disable_new_agent_stack=False, + enable_action_log=True, + **overrides, + ) + + +def _disabled_flags() -> AgentFeatureFlags: + return AgentFeatureFlags(disable_new_agent_stack=False, enable_action_log=False) + + +@pytest.fixture +def patch_get_flags(): + def _patch(flags: AgentFeatureFlags): + return patch( + "app.agents.new_chat.middleware.action_log.get_flags", + return_value=flags, + ) + + return _patch + + +@pytest.fixture +def fake_session_factory(): + """Patch ``shielded_async_session`` with a recording fake.""" + captured: dict[str, list] = {"rows": []} + + class _FakeSession: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + def add(self, row): + captured["rows"].append(row) + + async def commit(self): + captured["committed"] = True + + def _factory(): + return _FakeSession() + + return captured, _factory + + +class TestActionLogMiddlewareDisabled: + @pytest.mark.asyncio + async def test_no_op_when_flag_off(self, patch_get_flags) -> None: + mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {"color": "red", "size": 1}, "id": "tc1"} + ) + handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1")) + with patch_get_flags(_disabled_flags()): + result = await mw.awrap_tool_call(request, handler) + handler.assert_awaited_once() + assert isinstance(result, ToolMessage) + + @pytest.mark.asyncio + async def test_no_op_when_thread_id_none(self, patch_get_flags) -> None: + mw = ActionLogMiddleware(thread_id=None, search_space_id=1, user_id=None) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {}, "id": "tc1"} + ) + handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1")) + with patch_get_flags(_enabled_flags()): + result = await mw.awrap_tool_call(request, handler) + assert isinstance(result, ToolMessage) + + +class TestActionLogMiddlewarePersistence: + @pytest.mark.asyncio + async def test_writes_row_on_success( + self, patch_get_flags, fake_session_factory + ) -> None: + captured, factory = fake_session_factory + mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1") + request = _FakeRequest( + tool_call={ + "name": "make_widget", + "args": {"color": "red", "size": 3}, + "id": "tc-abc", + }, + ) + result_msg = ToolMessage( + content="ok", tool_call_id="tc-abc", id="msg-1" + ) + handler = AsyncMock(return_value=result_msg) + + with patch_get_flags(_enabled_flags()), patch( + "app.db.shielded_async_session", side_effect=lambda: factory() + ): + result = await mw.awrap_tool_call(request, handler) + + assert result is result_msg + assert len(captured["rows"]) == 1 + row = captured["rows"][0] + assert row.thread_id == 42 + assert row.search_space_id == 7 + assert row.user_id == "u1" + assert row.tool_name == "make_widget" + assert row.args == {"color": "red", "size": 3} + assert row.result_id == "msg-1" + assert row.error is None + assert row.reverse_descriptor is None + assert row.reversible is False + + @pytest.mark.asyncio + async def test_writes_row_on_failure_and_reraises( + self, patch_get_flags, fake_session_factory + ) -> None: + captured, factory = fake_session_factory + mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1") + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {"color": "red"}, "id": "tc1"} + ) + handler = AsyncMock(side_effect=ValueError("boom")) + + with patch_get_flags(_enabled_flags()), patch( + "app.db.shielded_async_session", side_effect=lambda: factory() + ), pytest.raises(ValueError, match="boom"): + await mw.awrap_tool_call(request, handler) + + assert len(captured["rows"]) == 1 + row = captured["rows"][0] + assert row.tool_name == "make_widget" + assert row.error == {"type": "ValueError", "message": "boom"} + assert row.result_id is None + + @pytest.mark.asyncio + async def test_persistence_failure_does_not_break_tool_call( + self, patch_get_flags + ) -> None: + """Even if the DB write blows up, the tool's result must reach the model.""" + mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {}, "id": "tc1"} + ) + result_msg = ToolMessage(content="ok", tool_call_id="tc1") + handler = AsyncMock(return_value=result_msg) + + def _exploding_session(): + raise RuntimeError("DB is down") + + with patch_get_flags(_enabled_flags()), patch( + "app.db.shielded_async_session", side_effect=_exploding_session + ): + result = await mw.awrap_tool_call(request, handler) + assert result is result_msg + + +class TestReverseDescriptor: + @pytest.mark.asyncio + async def test_renders_reverse_descriptor_when_tool_declares_one( + self, patch_get_flags, fake_session_factory + ) -> None: + captured, factory = fake_session_factory + + def _reverse(args, result): + return {"tool": "delete_widget", "args": {"id": result["id"]}} + + tool_def = ToolDefinition( + name="make_widget", + description="Create a widget", + factory=lambda deps: make_widget, + reverse=_reverse, + ) + mw = ActionLogMiddleware( + thread_id=1, + search_space_id=1, + user_id="u", + tool_definitions={"make_widget": tool_def}, + ) + request = _FakeRequest( + tool_call={ + "name": "make_widget", + "args": {"color": "blue", "size": 1}, + "id": "tc-xyz", + }, + ) + result_msg = ToolMessage( + content='{"id": "widget-9"}', tool_call_id="tc-xyz", id="msg-9" + ) + handler = AsyncMock(return_value=result_msg) + + with patch_get_flags(_enabled_flags()), patch( + "app.db.shielded_async_session", side_effect=lambda: factory() + ): + await mw.awrap_tool_call(request, handler) + + row = captured["rows"][0] + assert row.reversible is True + assert row.reverse_descriptor == { + "tool": "delete_widget", + "args": {"id": "widget-9"}, + } + + @pytest.mark.asyncio + async def test_swallows_reverse_callable_errors( + self, patch_get_flags, fake_session_factory + ) -> None: + captured, factory = fake_session_factory + + def _bad_reverse(args, result): + raise RuntimeError("reverse blew up") + + tool_def = ToolDefinition( + name="make_widget", + description="Create a widget", + factory=lambda deps: make_widget, + reverse=_bad_reverse, + ) + mw = ActionLogMiddleware( + thread_id=1, + search_space_id=1, + user_id=None, + tool_definitions={"make_widget": tool_def}, + ) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {}, "id": "tc1"} + ) + result_msg = ToolMessage(content="ok", tool_call_id="tc1") + handler = AsyncMock(return_value=result_msg) + + with patch_get_flags(_enabled_flags()), patch( + "app.db.shielded_async_session", side_effect=lambda: factory() + ): + await mw.awrap_tool_call(request, handler) + + row = captured["rows"][0] + assert row.reversible is False + assert row.reverse_descriptor is None + + @pytest.mark.asyncio + async def test_no_reverse_when_tool_definition_missing( + self, patch_get_flags, fake_session_factory + ) -> None: + captured, factory = fake_session_factory + mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) + request = _FakeRequest( + tool_call={"name": "unknown_tool", "args": {}, "id": "tc1"} + ) + handler = AsyncMock( + return_value=ToolMessage(content="ok", tool_call_id="tc1") + ) + with patch_get_flags(_enabled_flags()), patch( + "app.db.shielded_async_session", side_effect=lambda: factory() + ): + await mw.awrap_tool_call(request, handler) + row = captured["rows"][0] + assert row.reversible is False + + +class TestArgsTruncation: + @pytest.mark.asyncio + async def test_huge_args_payload_is_truncated( + self, patch_get_flags, fake_session_factory + ) -> None: + captured, factory = fake_session_factory + mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) + # Build a > 32KB string so the persisted payload triggers the truncation path. + huge = "x" * (40 * 1024) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {"blob": huge}, "id": "tc1"}, + ) + handler = AsyncMock( + return_value=ToolMessage(content="ok", tool_call_id="tc1") + ) + with patch_get_flags(_enabled_flags()), patch( + "app.db.shielded_async_session", side_effect=lambda: factory() + ): + await mw.awrap_tool_call(request, handler) + row = captured["rows"][0] + assert row.args is not None + assert row.args.get("_truncated") is True + assert row.args.get("_size", 0) >= 40 * 1024 diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py new file mode 100644 index 000000000..0c7bf17f6 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py @@ -0,0 +1,90 @@ +"""Tests for BusyMutexMiddleware: per-thread lock + cancel event behavior.""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.errors import BusyError +from app.agents.new_chat.middleware.busy_mutex import ( + BusyMutexMiddleware, + get_cancel_event, + manager, + request_cancel, + reset_cancel, +) + +pytestmark = pytest.mark.unit + + +class _Runtime: + def __init__(self, thread_id: str | None) -> None: + self.config = {"configurable": {"thread_id": thread_id}} + + +@pytest.mark.asyncio +async def test_first_acquire_succeeds_and_release_unblocks() -> None: + mw = BusyMutexMiddleware() + runtime = _Runtime("t1") + await mw.abefore_agent({}, runtime) + + # Lock should now be held + lock = manager.lock_for("t1") + assert lock.locked() + + await mw.aafter_agent({}, runtime) + assert not lock.locked() + + +@pytest.mark.asyncio +async def test_second_concurrent_acquire_raises_busy() -> None: + mw_a = BusyMutexMiddleware() + mw_b = BusyMutexMiddleware() + runtime = _Runtime("t-conflict") + await mw_a.abefore_agent({}, runtime) + + with pytest.raises(BusyError) as excinfo: + await mw_b.abefore_agent({}, runtime) + assert excinfo.value.request_id == "t-conflict" + + await mw_a.aafter_agent({}, runtime) + # After release, mw_b can acquire + await mw_b.abefore_agent({}, runtime) + await mw_b.aafter_agent({}, runtime) + + +@pytest.mark.asyncio +async def test_cancel_event_lifecycle() -> None: + mw = BusyMutexMiddleware() + runtime = _Runtime("t-cancel") + + await mw.abefore_agent({}, runtime) + event = get_cancel_event("t-cancel") + assert not event.is_set() + + request_cancel("t-cancel") + assert event.is_set() + + # End of turn should reset + await mw.aafter_agent({}, runtime) + assert not event.is_set() + + +@pytest.mark.asyncio +async def test_no_thread_id_raises_when_required() -> None: + mw = BusyMutexMiddleware(require_thread_id=True) + runtime = _Runtime(None) + with pytest.raises(BusyError): + await mw.abefore_agent({}, runtime) + + +@pytest.mark.asyncio +async def test_no_thread_id_skipped_when_not_required() -> None: + mw = BusyMutexMiddleware(require_thread_id=False) + runtime = _Runtime(None) + await mw.abefore_agent({}, runtime) + await mw.aafter_agent({}, runtime) + + +def test_reset_cancel_idempotent() -> None: + # Should not raise even if event was never created + reset_cancel("never-seen") diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_compaction.py b/surfsense_backend/tests/unit/agents/new_chat/test_compaction.py new file mode 100644 index 000000000..4d8d6805c --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_compaction.py @@ -0,0 +1,107 @@ +"""Tests for SurfSenseCompactionMiddleware: protected SystemMessage handling and content sanitization.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) + +from app.agents.new_chat.middleware.compaction import ( + PROTECTED_SYSTEM_PREFIXES, + _is_protected_system_message, + _sanitize_message_content, +) + +pytestmark = pytest.mark.unit + + +class TestIsProtectedSystemMessage: + @pytest.mark.parametrize("prefix", PROTECTED_SYSTEM_PREFIXES) + def test_each_prefix_protected(self, prefix: str) -> None: + msg = SystemMessage(content=f"{prefix}\nbody\n") + assert _is_protected_system_message(msg) is True + + def test_unprotected_system_message(self) -> None: + assert _is_protected_system_message(SystemMessage(content="random instructions")) is False + + def test_human_message_never_protected(self) -> None: + assert _is_protected_system_message(HumanMessage(content="...")) is False + + def test_tolerates_leading_whitespace(self) -> None: + msg = SystemMessage(content=" \n\n...") + assert _is_protected_system_message(msg) is True + + +class TestSanitizeMessageContent: + def test_returns_same_message_when_content_present(self) -> None: + msg = AIMessage(content="hello") + assert _sanitize_message_content(msg) is msg + + def test_replaces_none_with_empty_string(self) -> None: + # Pydantic blocks ``content=None`` at construction; the real + # crash happens when the streaming layer mutates ``content`` + # after-the-fact. Replicate that by force-setting on a built + # message. + msg = AIMessage( + content="", + tool_calls=[{"name": "x", "args": {}, "id": "1"}], + ) + # Bypass pydantic validation to simulate the LiteLLM/Bedrock case + object.__setattr__(msg, "content", None) + sanitized = _sanitize_message_content(msg) + assert sanitized.content == "" + + +class TestPartitionMessages: + """Verify the partition override surfaces protected SystemMessages + into ``preserved_messages`` regardless of cutoff position. + """ + + def _build_partitioner(self): + # Construct a thin shim — we can't easily instantiate the full + # SurfSenseCompactionMiddleware without a real model, but the + # override path needs ``_lc_helper`` to delegate to. We mock + # that with a simple slicing partitioner equivalent to the real one. + from app.agents.new_chat.middleware.compaction import ( + SurfSenseCompactionMiddleware, + ) + + class _LcHelper: + @staticmethod + def _partition_messages(messages, cutoff): + return messages[:cutoff], messages[cutoff:] + + class _Stub(SurfSenseCompactionMiddleware): + def __init__(self): + self._lc_helper = _LcHelper() + + return _Stub() + + def test_protected_system_message_preserved_even_in_summarize_half(self) -> None: + partitioner = self._build_partitioner() + protected = SystemMessage(content="\n...") + msgs = [ + HumanMessage(content="old human"), + AIMessage(content="old ai"), + protected, + ToolMessage(content="tool 1", tool_call_id="t1"), + HumanMessage(content="new"), + ] + # Cutoff = 4 means everything before index 4 should be summarized + to_summary, preserved = partitioner._partition_messages(msgs, 4) + + assert protected not in to_summary + assert protected in preserved + # The non-protected old messages remain in to_summary + assert any(isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary) + + def test_unprotected_messages_unaffected(self) -> None: + partitioner = self._build_partitioner() + msgs = [HumanMessage(content="a"), HumanMessage(content="b"), HumanMessage(content="c")] + to_summary, preserved = partitioner._partition_messages(msgs, 2) + assert [m.content for m in to_summary] == ["a", "b"] + assert [m.content for m in preserved] == ["c"] diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py b/surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py new file mode 100644 index 000000000..3c31155d4 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py @@ -0,0 +1,107 @@ +"""Tests for SpillToBackendEdit and SpillingContextEditingMiddleware.""" + +from __future__ import annotations + +from typing import Any + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + +from app.agents.new_chat.middleware.context_editing import ( + SpillToBackendEdit, + _build_spill_placeholder, +) + +pytestmark = pytest.mark.unit + + +def _build_history(num_pairs: int = 6) -> list[Any]: + """Build a long history of (AIMessage with tool_call, ToolMessage) pairs.""" + msgs: list[Any] = [HumanMessage(content="please do many things")] + for i in range(num_pairs): + msgs.append( + AIMessage( + content="", + tool_calls=[ + {"name": f"tool_{i}", "args": {"i": i}, "id": f"call-{i}"}, + ], + ) + ) + msgs.append( + ToolMessage( + content="x" * 5000, + tool_call_id=f"call-{i}", + name=f"tool_{i}", + id=f"tool-msg-{i}", + ) + ) + return msgs + + +def _approx_count(messages: list[Any]) -> int: + """Trivial token counter: 1 token per 4 chars.""" + total = 0 + for msg in messages: + content = getattr(msg, "content", "") + if isinstance(content, str): + total += len(content) // 4 + return total + + +class TestSpillEdit: + def test_below_trigger_does_nothing(self) -> None: + edit = SpillToBackendEdit(trigger=1_000_000, keep=2) + msgs = _build_history(3) + original_lengths = [len(getattr(m, "content", "")) for m in msgs] + edit.apply(msgs, count_tokens=_approx_count) + new_lengths = [len(getattr(m, "content", "")) for m in msgs] + assert original_lengths == new_lengths + assert edit.pending_spills == [] + + def test_above_trigger_clears_and_records(self) -> None: + edit = SpillToBackendEdit(trigger=100, keep=1, path_prefix="/tool_outputs") + msgs = _build_history(4) + edit.apply(msgs, count_tokens=_approx_count) + + # The most-recent ToolMessage (keep=1) should remain intact + tool_messages = [m for m in msgs if isinstance(m, ToolMessage)] + intact = tool_messages[-1] + assert intact.content.startswith("x") # untouched + + # Earlier ToolMessages should now contain the placeholder text + cleared = [ + m for m in tool_messages + if isinstance(m.content, str) and m.content.startswith("[cleared") + ] + assert len(cleared) >= 1 + # And the spill list should match + assert len(edit.pending_spills) == len(cleared) + + def test_excluded_tools_not_cleared(self) -> None: + edit = SpillToBackendEdit( + trigger=100, + keep=0, + exclude_tools=("tool_0",), + ) + msgs = _build_history(4) + edit.apply(msgs, count_tokens=_approx_count) + + first_tool = next( + m for m in msgs if isinstance(m, ToolMessage) and m.name == "tool_0" + ) + # Excluded — untouched + assert first_tool.content.startswith("x") + + def test_drain_clears_pending(self) -> None: + edit = SpillToBackendEdit(trigger=100, keep=1) + msgs = _build_history(4) + edit.apply(msgs, count_tokens=_approx_count) + first_drain = edit.drain_pending() + assert len(first_drain) > 0 + assert edit.drain_pending() == [] + + def test_placeholder_format(self) -> None: + path = "/tool_outputs/thread-1/tool-msg-0.txt" + text = _build_spill_placeholder(path) + assert path in text + assert "explore" in text # mentions the recovery agent diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py b/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py new file mode 100644 index 000000000..95017d744 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py @@ -0,0 +1,132 @@ +"""Tests for declarative dedup_key on ToolDefinition (Tier 2.3 migration).""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage +from langchain_core.tools import StructuredTool + +from app.agents.new_chat.middleware.dedup_tool_calls import ( + DedupHITLToolCallsMiddleware, +) + +pytestmark = pytest.mark.unit + + +def _make_tool(name: str, *, dedup_key=None, hitl_dedup_key=None): + metadata = {} + if dedup_key is not None: + metadata["dedup_key"] = dedup_key + if hitl_dedup_key is not None: + metadata["hitl"] = True + metadata["hitl_dedup_key"] = hitl_dedup_key + + def _fn(**kwargs): + return "ok" + + return StructuredTool.from_function( + func=_fn, name=name, description="x", metadata=metadata + ) + + +def _msg(*calls: dict) -> AIMessage: + return AIMessage(content="", tool_calls=list(calls)) + + +class _Runtime: + pass + + +def test_callable_dedup_key_takes_priority() -> None: + tool = _make_tool( + "create_doc", + dedup_key=lambda args: f"{args.get('parent_id')}::{args.get('title')}", + ) + mw = DedupHITLToolCallsMiddleware(agent_tools=[tool]) + state = { + "messages": [ + _msg( + {"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "1"}, + {"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "2"}, + {"name": "create_doc", "args": {"parent_id": "x", "title": "z"}, "id": "3"}, + ) + ] + } + out = mw.after_model(state, _Runtime()) + assert out is not None + new_calls = out["messages"][0].tool_calls + assert len(new_calls) == 2 # one duplicate dropped + assert {c["id"] for c in new_calls} == {"1", "3"} + + +def test_string_hitl_dedup_key_still_works() -> None: + tool = _make_tool("send_x", hitl_dedup_key="subject") + mw = DedupHITLToolCallsMiddleware(agent_tools=[tool]) + state = { + "messages": [ + _msg( + {"name": "send_x", "args": {"subject": "Hello"}, "id": "1"}, + {"name": "send_x", "args": {"subject": "hello"}, "id": "2"}, # case + ) + ] + } + out = mw.after_model(state, _Runtime()) + assert out is not None + assert len(out["messages"][0].tool_calls) == 1 + + +def test_no_agent_tools_means_no_dedup() -> None: + """After the cleanup tier removed the legacy ``_NATIVE_HITL_TOOL_DEDUP_KEYS`` + map, dedup is purely declarative — no resolvers means no dedup runs. + + Coverage for the previously hardcoded native HITL tools now lives on + each :class:`ToolDefinition.dedup_key` in + :mod:`app.agents.new_chat.tools.registry`, which is wired through to + ``tool.metadata`` by :func:`build_tools`. + """ + mw = DedupHITLToolCallsMiddleware(agent_tools=None) + state = { + "messages": [ + _msg( + {"name": "create_notion_page", "args": {"title": "X"}, "id": "1"}, + {"name": "create_notion_page", "args": {"title": "x"}, "id": "2"}, + ) + ] + } + out = mw.after_model(state, _Runtime()) + assert out is None + + +def test_registry_propagates_dedup_key_to_tool_metadata() -> None: + """Smoke-check the wiring path that replaced the legacy native map. + + ``ToolDefinition.dedup_key`` set in the registry must be copied onto + the constructed tool's ``metadata`` so :class:`DedupHITLToolCallsMiddleware` + can pick it up at agent build time. + """ + from app.agents.new_chat.tools.registry import ( + BUILTIN_TOOLS, + wrap_dedup_key_by_arg_name, + ) + + notion_tool_defs = [t for t in BUILTIN_TOOLS if t.name == "create_notion_page"] + assert notion_tool_defs, "registry should still expose create_notion_page" + tool_def = notion_tool_defs[0] + assert tool_def.dedup_key is not None + # Same wrapping helper used in the registry — sanity check identity + sample = wrap_dedup_key_by_arg_name("title")({"title": "Plan"}) + assert sample == "plan" + + +def test_unknown_tool_passes_through() -> None: + mw = DedupHITLToolCallsMiddleware(agent_tools=None) + state = { + "messages": [ + _msg( + {"name": "anything_else", "args": {"x": 1}, "id": "1"}, + {"name": "anything_else", "args": {"x": 1}, "id": "2"}, + ) + ] + } + out = mw.after_model(state, _Runtime()) + assert out is None # no dedup configured -> kept diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py b/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py new file mode 100644 index 000000000..d49edbfec --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py @@ -0,0 +1,128 @@ +"""Lock in the default-allow layering used by ``chat_deepagent``. + +The agent factory wires ``PermissionMiddleware`` with three rulesets, +earliest -> latest: + +1. ``surfsense_defaults`` (single ``allow */*`` rule) +2. ``connector_synthesized`` (deny rules for tools whose required + connector is missing) +3. (future) user-defined rules from the Agent Permissions UI + +Without #1 every read-only built-in (``ls``, ``read_file``, ``grep``, +``glob``, ``web_search`` …) defaulted to ``ask`` because +``permissions.evaluate`` returns ``ask`` when no rule matches. That +caused two production-painful behaviors: + +* Resume payloads with a prior reject decision bled into innocent + read-only tool calls, raising ``RejectedError("ls")``. +* Mutating connector tools got *double* prompted — once via the + middleware ``ask`` and again via the per-tool ``interrupt()`` in + ``app.agents.new_chat.tools.hitl``. + +These tests pin the layering so a refactor that drops the default +ruleset fails loud. +""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.permissions import ( + Rule, + Ruleset, + aggregate_action, + evaluate_many, +) + +pytestmark = pytest.mark.unit + + +def _layered_rulesets(connector_denies: list[Rule]) -> list[Ruleset]: + """Replicate ``chat_deepagent`` layering for the test.""" + return [ + Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", + ), + Ruleset(rules=connector_denies, origin="connector_synthesized"), + ] + + +class TestReadOnlyToolsAllowed: + """Read-only built-ins must NOT default to ask.""" + + @pytest.mark.parametrize( + "tool_name", + [ + "ls", + "read_file", + "grep", + "glob", + "web_search", + "scrape_webpage", + "search_surfsense_docs", + "get_connected_accounts", + "write_todos", + "task", + "_noop", + "invalid", + "update_memory", + ], + ) + def test_default_allow_covers_safe_builtin(self, tool_name: str) -> None: + rulesets = _layered_rulesets(connector_denies=[]) + rules = evaluate_many(tool_name, [tool_name], *rulesets) + assert aggregate_action(rules) == "allow" + + +class TestConnectorDenyOverridesDefaultAllow: + """Connector-synthesized denies must beat the default-allow rule.""" + + def test_missing_connector_tool_is_denied(self) -> None: + rulesets = _layered_rulesets( + connector_denies=[ + Rule(permission="linear_create_issue", pattern="*", action="deny") + ] + ) + rules = evaluate_many( + "linear_create_issue", ["linear_create_issue"], *rulesets + ) + assert aggregate_action(rules) == "deny" + + def test_default_allow_still_applies_to_other_tools(self) -> None: + """A deny rule for one tool must not bleed onto unrelated calls.""" + rulesets = _layered_rulesets( + connector_denies=[ + Rule(permission="linear_create_issue", pattern="*", action="deny") + ] + ) + rules = evaluate_many("ls", ["ls"], *rulesets) + assert aggregate_action(rules) == "allow" + + +class TestUserRuleOverridesDefault: + """User rules layered last must override the default-allow rule.""" + + def test_user_ask_overrides_default_allow(self) -> None: + defaults = Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", + ) + user_ruleset = Ruleset( + rules=[Rule(permission="ls", pattern="*", action="ask")], + origin="user", + ) + rules = evaluate_many("ls", ["ls"], defaults, user_ruleset) + assert aggregate_action(rules) == "ask" + + def test_user_deny_overrides_default_allow(self) -> None: + defaults = Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", + ) + user_ruleset = Ruleset( + rules=[Rule(permission="send_*", pattern="*", action="deny")], + origin="user", + ) + rules = evaluate_many("send_gmail_email", ["send_gmail_email"], defaults, user_ruleset) + assert aggregate_action(rules) == "deny" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py b/surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py new file mode 100644 index 000000000..c54163dc3 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py @@ -0,0 +1,99 @@ +"""Tests for DoomLoopMiddleware signature equality detection.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage + +from app.agents.new_chat.middleware.doom_loop import DoomLoopMiddleware, _signature + +pytestmark = pytest.mark.unit + + +def test_signature_is_stable_for_identical_args() -> None: + a = _signature("search", {"q": "hello", "n": 10}) + b = _signature("search", {"n": 10, "q": "hello"}) + assert a == b + + +def test_signature_changes_with_args() -> None: + a = _signature("search", {"q": "hello"}) + b = _signature("search", {"q": "world"}) + assert a != b + + +def test_signature_changes_with_name() -> None: + a = _signature("search", {"q": "x"}) + b = _signature("read", {"q": "x"}) + assert a != b + + +class _FakeRuntime: + def __init__(self, thread_id: str | None = "thread-1") -> None: + self.config = {"configurable": {"thread_id": thread_id}} + + +def _msg_calling(name: str, args: dict, call_id: str) -> AIMessage: + return AIMessage( + content="", + tool_calls=[{"name": name, "args": args, "id": call_id}], + ) + + +def test_threshold_triggers_after_n_identical_calls() -> None: + mw = DoomLoopMiddleware(threshold=3) + runtime = _FakeRuntime() + + # First two calls — under threshold + for i in range(2): + out = mw.after_model( + {"messages": [_msg_calling("repeat", {"x": 1}, f"call-{i}")]}, + runtime, + ) + assert out is None + + # Third identical call should trigger ``langgraph.types.interrupt``. + # In a unit-test context (no runnable graph), ``interrupt`` raises + # ``RuntimeError`` because ``get_config`` has nothing to bind to — + # we accept that as proof the interrupt path was taken (the + # alternative would be no exception, which would mean the loop + # detection never fired). + with pytest.raises(Exception) as excinfo: + mw.after_model( + {"messages": [_msg_calling("repeat", {"x": 1}, "call-3")]}, + runtime, + ) + name = type(excinfo.value).__name__.lower() + assert ( + "interrupt" in name + or "runtimeerror" in name + ), f"Expected an interrupt-style exception, got {name}" + + +def test_does_not_trigger_when_args_differ() -> None: + mw = DoomLoopMiddleware(threshold=2) + runtime = _FakeRuntime() + out = mw.after_model( + {"messages": [_msg_calling("repeat", {"x": 1}, "1")]}, runtime + ) + assert out is None + out = mw.after_model( + {"messages": [_msg_calling("repeat", {"x": 2}, "2")]}, runtime + ) + assert out is None + + +def test_separate_threads_have_independent_windows() -> None: + mw = DoomLoopMiddleware(threshold=2) + rt_a = _FakeRuntime(thread_id="A") + rt_b = _FakeRuntime(thread_id="B") + + mw.after_model({"messages": [_msg_calling("foo", {}, "1")]}, rt_a) + # thread B should NOT count thread A's call + out = mw.after_model({"messages": [_msg_calling("foo", {}, "1")]}, rt_b) + assert out is None # not yet at threshold for B + + +def test_invalid_threshold_rejected() -> None: + with pytest.raises(ValueError): + DoomLoopMiddleware(threshold=1) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py new file mode 100644 index 000000000..38a70a443 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py @@ -0,0 +1,120 @@ +"""Tests for the agent feature-flag system.""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.feature_flags import ( + AgentFeatureFlags, + reload_for_tests, +) + +pytestmark = pytest.mark.unit + + +def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None: + for name in [ + "SURFSENSE_DISABLE_NEW_AGENT_STACK", + "SURFSENSE_ENABLE_CONTEXT_EDITING", + "SURFSENSE_ENABLE_COMPACTION_V2", + "SURFSENSE_ENABLE_RETRY_AFTER", + "SURFSENSE_ENABLE_MODEL_FALLBACK", + "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", + "SURFSENSE_ENABLE_TOOL_CALL_LIMIT", + "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", + "SURFSENSE_ENABLE_DOOM_LOOP", + "SURFSENSE_ENABLE_PERMISSION", + "SURFSENSE_ENABLE_BUSY_MUTEX", + "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", + "SURFSENSE_ENABLE_SKILLS", + "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", + "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", + "SURFSENSE_ENABLE_ACTION_LOG", + "SURFSENSE_ENABLE_REVERT_ROUTE", + "SURFSENSE_ENABLE_PLUGIN_LOADER", + "SURFSENSE_ENABLE_OTEL", + ]: + monkeypatch.delenv(name, raising=False) + + +def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None: + _clear_all(monkeypatch) + flags = reload_for_tests() + assert isinstance(flags, AgentFeatureFlags) + assert flags.disable_new_agent_stack is False + assert flags.any_new_middleware_enabled() is False + + +def test_master_kill_switch_overrides_individual_flags( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _clear_all(monkeypatch) + monkeypatch.setenv("SURFSENSE_DISABLE_NEW_AGENT_STACK", "true") + monkeypatch.setenv("SURFSENSE_ENABLE_CONTEXT_EDITING", "true") + monkeypatch.setenv("SURFSENSE_ENABLE_PERMISSION", "true") + + flags = reload_for_tests() + assert flags.disable_new_agent_stack is True + assert flags.enable_context_editing is False + assert flags.enable_permission is False + assert flags.any_new_middleware_enabled() is False + + +@pytest.mark.parametrize("truthy", ["1", "true", "TRUE", "yes", "on"]) +def test_individual_flags_truthy_values( + monkeypatch: pytest.MonkeyPatch, truthy: str +) -> None: + _clear_all(monkeypatch) + monkeypatch.setenv("SURFSENSE_ENABLE_RETRY_AFTER", truthy) + flags = reload_for_tests() + assert flags.enable_retry_after is True + assert flags.any_new_middleware_enabled() is True + + +@pytest.mark.parametrize("falsy", ["0", "false", "no", "off", "", "garbage"]) +def test_individual_flags_falsy_values( + monkeypatch: pytest.MonkeyPatch, falsy: str +) -> None: + _clear_all(monkeypatch) + monkeypatch.setenv("SURFSENSE_ENABLE_RETRY_AFTER", falsy) + flags = reload_for_tests() + assert flags.enable_retry_after is False + + +def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) -> None: + _clear_all(monkeypatch) + flag_to_env = { + "enable_context_editing": "SURFSENSE_ENABLE_CONTEXT_EDITING", + "enable_compaction_v2": "SURFSENSE_ENABLE_COMPACTION_V2", + "enable_retry_after": "SURFSENSE_ENABLE_RETRY_AFTER", + "enable_model_fallback": "SURFSENSE_ENABLE_MODEL_FALLBACK", + "enable_model_call_limit": "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", + "enable_tool_call_limit": "SURFSENSE_ENABLE_TOOL_CALL_LIMIT", + "enable_tool_call_repair": "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", + "enable_doom_loop": "SURFSENSE_ENABLE_DOOM_LOOP", + "enable_permission": "SURFSENSE_ENABLE_PERMISSION", + "enable_busy_mutex": "SURFSENSE_ENABLE_BUSY_MUTEX", + "enable_llm_tool_selector": "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", + "enable_skills": "SURFSENSE_ENABLE_SKILLS", + "enable_specialized_subagents": "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", + "enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", + "enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG", + "enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE", + "enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER", + "enable_otel": "SURFSENSE_ENABLE_OTEL", + } + + # `enable_otel` is intentionally orthogonal — it does NOT count toward + # ``any_new_middleware_enabled`` because OTel is observability-only and + # ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement. + counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"} + + for attr, env_name in flag_to_env.items(): + _clear_all(monkeypatch) + monkeypatch.setenv(env_name, "true") + flags = reload_for_tests() + assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}" + if attr in counts_toward_middleware: + assert flags.any_new_middleware_enabled() is True + else: + assert flags.any_new_middleware_enabled() is False diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py b/surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py new file mode 100644 index 000000000..8555eea76 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py @@ -0,0 +1,119 @@ +"""Tests for NoopInjectionMiddleware provider-compat logic.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage, HumanMessage + +from app.agents.new_chat.middleware.noop_injection import ( + NOOP_TOOL_NAME, + NoopInjectionMiddleware, + _last_ai_has_tool_calls, + _provider_needs_noop, +) + +pytestmark = pytest.mark.unit + + +class _LiteLLMModel: + def _get_ls_params(self): + return {"ls_provider": "litellm"} + + +class _BedrockModel: + def _get_ls_params(self): + return {"ls_provider": "bedrock"} + + +class _OpenAIModel: + def _get_ls_params(self): + return {"ls_provider": "openai"} + + +class _ChatLiteLLM: # name-only fallback + pass + + +class TestProviderDetection: + def test_litellm(self) -> None: + assert _provider_needs_noop(_LiteLLMModel()) is True + + def test_bedrock(self) -> None: + assert _provider_needs_noop(_BedrockModel()) is True + + def test_openai_does_not_need(self) -> None: + assert _provider_needs_noop(_OpenAIModel()) is False + + def test_class_name_fallback(self) -> None: + assert _provider_needs_noop(_ChatLiteLLM()) is True + + +class TestHistoryDetection: + def test_last_ai_has_tool_calls(self) -> None: + msgs = [ + HumanMessage(content="hi"), + AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]), + ] + assert _last_ai_has_tool_calls(msgs) is True + + def test_last_ai_no_tool_calls(self) -> None: + msgs = [ + HumanMessage(content="hi"), + AIMessage(content="hello"), + ] + assert _last_ai_has_tool_calls(msgs) is False + + def test_no_ai_in_history(self) -> None: + assert _last_ai_has_tool_calls([HumanMessage(content="hi")]) is False + + +class _FakeRequest: + def __init__(self, *, tools, messages, model) -> None: + self.tools = tools + self.messages = messages + self.model = model + + def override(self, *, tools): + return _FakeRequest(tools=tools, messages=self.messages, model=self.model) + + +class TestShouldInject: + def test_injects_when_all_conditions_met(self) -> None: + mw = NoopInjectionMiddleware() + msgs = [ + HumanMessage(content="hi"), + AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]), + ] + req = _FakeRequest(tools=[], messages=msgs, model=_LiteLLMModel()) + assert mw._should_inject(req) is True + + def test_skips_when_tools_present(self) -> None: + mw = NoopInjectionMiddleware() + req = _FakeRequest( + tools=[object()], + messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])], + model=_LiteLLMModel(), + ) + assert mw._should_inject(req) is False + + def test_skips_when_no_history_tool_calls(self) -> None: + mw = NoopInjectionMiddleware() + req = _FakeRequest( + tools=[], + messages=[HumanMessage(content="hi")], + model=_LiteLLMModel(), + ) + assert mw._should_inject(req) is False + + def test_skips_for_openai(self) -> None: + mw = NoopInjectionMiddleware() + req = _FakeRequest( + tools=[], + messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])], + model=_OpenAIModel(), + ) + assert mw._should_inject(req) is False + + +def test_noop_tool_name_is_underscore_noop() -> None: + assert NOOP_TOOL_NAME == "_noop" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py b/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py new file mode 100644 index 000000000..e5b171612 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py @@ -0,0 +1,195 @@ +"""Tests for the OtelSpanMiddleware adapter (Tier 3b).""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest +from langchain_core.messages import AIMessage, ToolMessage + +from app.agents.new_chat.middleware.otel_span import ( + OtelSpanMiddleware, + _annotate_model_response, + _annotate_tool_result, + _resolve_input_size, + _resolve_model_attrs, + _resolve_tool_name, +) + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def _disable_otel(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true") + from app.observability import otel as ot + + ot.reload_for_tests() + yield + ot.reload_for_tests() + + +class TestResolveModelAttrs: + def test_extracts_model_name_and_provider(self) -> None: + request = MagicMock() + request.model = MagicMock(spec=["model_name", "provider"]) + request.model.model_name = "gpt-4o-mini" + request.model.provider = "openai" + assert _resolve_model_attrs(request) == ("gpt-4o-mini", "openai") + + def test_handles_missing_model(self) -> None: + request = MagicMock() + request.model = None + assert _resolve_model_attrs(request) == (None, None) + + def test_falls_back_through_attribute_chain(self) -> None: + request = MagicMock() + request.model = MagicMock(spec=["model_id", "_llm_type"]) + request.model.model_id = "claude-3-5-sonnet" + request.model._llm_type = "anthropic-chat" + model_id, provider = _resolve_model_attrs(request) + assert model_id == "claude-3-5-sonnet" + assert provider == "anthropic-chat" + + +class TestResolveToolName: + def test_prefers_request_tool_name(self) -> None: + request = MagicMock() + request.tool = MagicMock(name="ToolStub") + request.tool.name = "scrape_webpage" + assert _resolve_tool_name(request) == "scrape_webpage" + + def test_falls_back_to_tool_call_name(self) -> None: + request = MagicMock() + request.tool = None + request.tool_call = {"name": "web_search", "args": {}} + assert _resolve_tool_name(request) == "web_search" + + def test_unknown_when_nothing_resolves(self) -> None: + request = MagicMock() + request.tool = None + request.tool_call = {} + assert _resolve_tool_name(request) == "unknown" + + +class TestResolveInputSize: + def test_returns_repr_length_of_args(self) -> None: + request = MagicMock() + request.tool_call = {"args": {"query": "hello world"}} + size = _resolve_input_size(request) + assert isinstance(size, int) + assert size > 0 + + def test_handles_no_tool_call(self) -> None: + request = MagicMock() + request.tool_call = None + assert _resolve_input_size(request) is None + + +class TestAnnotateModelResponse: + def test_attaches_token_counts_when_present(self) -> None: + sp = MagicMock() + msg = AIMessage( + content="hello", + usage_metadata={ + "input_tokens": 100, + "output_tokens": 50, + "total_tokens": 150, + }, + ) + _annotate_model_response(sp, msg) + sp.set_attribute.assert_any_call("tokens.prompt", 100) + sp.set_attribute.assert_any_call("tokens.completion", 50) + sp.set_attribute.assert_any_call("tokens.total", 150) + + def test_handles_response_with_no_metadata(self) -> None: + sp = MagicMock() + msg = AIMessage(content="hello") + # Should not raise even when usage_metadata is missing + _annotate_model_response(sp, msg) + + +class TestAnnotateToolResult: + def test_records_size_and_status(self) -> None: + sp = MagicMock() + result = ToolMessage( + content="result text", + tool_call_id="abc", + status="success", + ) + _annotate_tool_result(sp, result) + sp.set_attribute.assert_any_call("tool.output.size", len("result text")) + sp.set_attribute.assert_any_call("tool.status", "success") + + def test_marks_errors(self) -> None: + sp = MagicMock() + result = ToolMessage( + content="oops", + tool_call_id="abc", + additional_kwargs={"error": {"code": "x"}}, + ) + _annotate_tool_result(sp, result) + sp.set_attribute.assert_any_call("tool.error", True) + + +@pytest.mark.asyncio +class TestMiddlewareIntegration: + async def test_awrap_model_call_passes_through_when_disabled(self) -> None: + mw = OtelSpanMiddleware() + called: dict[str, Any] = {} + + async def handler(req): + called["req"] = req + return AIMessage(content="ok") + + request = MagicMock() + result = await mw.awrap_model_call(request, handler) + assert called["req"] is request + assert isinstance(result, AIMessage) + assert result.content == "ok" + + async def test_awrap_tool_call_passes_through_when_disabled(self) -> None: + mw = OtelSpanMiddleware() + + async def handler(req): + return ToolMessage(content="result", tool_call_id="abc") + + request = MagicMock() + result = await mw.awrap_tool_call(request, handler) + assert isinstance(result, ToolMessage) + assert result.content == "result" + + async def test_awrap_model_call_propagates_exceptions(self) -> None: + mw = OtelSpanMiddleware() + + async def handler(req): + raise ValueError("boom") + + with pytest.raises(ValueError): + await mw.awrap_model_call(MagicMock(), handler) + + async def test_with_otel_enabled_does_not_alter_result( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + from app.observability import otel as ot + + ot.reload_for_tests() + try: + mw = OtelSpanMiddleware() + + async def handler(req): + return AIMessage(content="enabled") + + request = MagicMock() + request.model = MagicMock() + request.model.model_name = "gpt-4o" + request.model.provider = "openai" + result = await mw.awrap_model_call(request, handler) + assert isinstance(result, AIMessage) + assert result.content == "enabled" + finally: + ot.reload_for_tests() diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py b/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py new file mode 100644 index 000000000..194a6eb27 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py @@ -0,0 +1,116 @@ +"""Tests for PermissionMiddleware end-to-end behavior.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage, ToolMessage + +from app.agents.new_chat.errors import CorrectedError, RejectedError +from app.agents.new_chat.middleware.permission import PermissionMiddleware +from app.agents.new_chat.permissions import Rule, Ruleset + +pytestmark = pytest.mark.unit + + +class _FakeRuntime: + config: dict = {"configurable": {"thread_id": "test"}} + + +def _msg(*tool_calls: dict) -> AIMessage: + return AIMessage(content="", tool_calls=list(tool_calls)) + + +class TestAllow: + def test_passthrough_when_allow(self) -> None: + rs = Ruleset(rules=[Rule("send_email", "*", "allow")]) + mw = PermissionMiddleware(rulesets=[rs]) + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + out = mw.after_model(state, _FakeRuntime()) + assert out is None # no change + + +class TestDeny: + def test_replaces_with_deny_tool_message(self) -> None: + rs = Ruleset(rules=[Rule("send_email", "*", "deny")]) + mw = PermissionMiddleware(rulesets=[rs]) + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + out = mw.after_model(state, _FakeRuntime()) + assert out is not None + msgs = out["messages"] + # Find the deny ToolMessage + deny_msgs = [m for m in msgs if isinstance(m, ToolMessage)] + assert len(deny_msgs) == 1 + assert deny_msgs[0].status == "error" + assert "permission_denied" in str(deny_msgs[0].additional_kwargs) + # AIMessage's tool_calls should now be empty (denied call removed) + ai_msg = next(m for m in msgs if isinstance(m, AIMessage)) + assert ai_msg.tool_calls == [] + + def test_mixed_allow_deny(self) -> None: + rs = Ruleset( + rules=[ + Rule("send_email", "*", "deny"), + Rule("read", "*", "allow"), + ] + ) + mw = PermissionMiddleware(rulesets=[rs]) + state = { + "messages": [ + _msg( + {"name": "send_email", "args": {}, "id": "1"}, + {"name": "read", "args": {}, "id": "2"}, + ) + ] + } + out = mw.after_model(state, _FakeRuntime()) + assert out is not None + ai_msg = next(m for m in out["messages"] if isinstance(m, AIMessage)) + assert len(ai_msg.tool_calls) == 1 + assert ai_msg.tool_calls[0]["name"] == "read" + + +class TestAsk: + def test_reject_without_feedback_raises(self) -> None: + # Default: nothing matches -> ask + rs = Ruleset(rules=[]) + mw = PermissionMiddleware(rulesets=[rs]) + + # Bypass real interrupt — patch the helper + mw._raise_interrupt = lambda **kw: {"decision_type": "reject"} # type: ignore[assignment] + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + with pytest.raises(RejectedError): + mw.after_model(state, _FakeRuntime()) + + def test_reject_with_feedback_raises_corrected(self) -> None: + rs = Ruleset(rules=[]) + mw = PermissionMiddleware(rulesets=[rs]) + mw._raise_interrupt = lambda **kw: { # type: ignore[assignment] + "decision_type": "reject", + "feedback": "use a different subject line", + } + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + with pytest.raises(CorrectedError) as excinfo: + mw.after_model(state, _FakeRuntime()) + assert excinfo.value.feedback == "use a different subject line" + + def test_once_proceeds_without_persisting(self) -> None: + mw = PermissionMiddleware(rulesets=[]) + mw._raise_interrupt = lambda **kw: {"decision_type": "once"} # type: ignore[assignment] + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + out = mw.after_model(state, _FakeRuntime()) + # No state change because all calls kept + assert out is None + # No new rule persisted + assert mw._runtime_ruleset.rules == [] + + def test_always_persists_runtime_rule(self) -> None: + mw = PermissionMiddleware(rulesets=[]) + mw._raise_interrupt = lambda **kw: {"decision_type": "always"} # type: ignore[assignment] + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + out = mw.after_model(state, _FakeRuntime()) + assert out is None # call kept + # Runtime ruleset got the always-allow rule + new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"] + assert any( + r.permission == "send_email" for r in new_rules + ) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_permissions.py b/surfsense_backend/tests/unit/agents/new_chat/test_permissions.py new file mode 100644 index 000000000..4924f2aee --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_permissions.py @@ -0,0 +1,111 @@ +"""Tests for the wildcard matcher and rule evaluator (opencode evaluate.ts parity).""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.permissions import ( + Rule, + Ruleset, + aggregate_action, + evaluate, + evaluate_many, + wildcard_match, +) + +pytestmark = pytest.mark.unit + + +class TestWildcardMatch: + @pytest.mark.parametrize( + "value,pattern,expected", + [ + ("edit", "edit", True), + ("edit", "*", True), + ("read", "edit", False), + ("/documents/secrets/x", "/documents/secrets/**", True), + # Single-segment glob: '*' does not cross '/' + ("/documents/secrets/x", "/documents/*/x", True), + ("/documents/foo/bar/x", "/documents/*/x", False), + ("/documents/foo/x", "/documents/*/x", True), + ("linear_create", "linear_*", True), + ("notion_create", "linear_*", False), + # ':' is not a separator, so '*' matches it + ("mcp:notion:create_page", "mcp:*", True), + ("mcp:notion:create_page", "mcp:**", True), + # But '/' IS a separator + ("foo/bar", "foo/*", True), + ("foo/bar/baz", "foo/*", False), + ], + ) + def test_match(self, value: str, pattern: str, expected: bool) -> None: + assert wildcard_match(value, pattern) is expected + + +class TestEvaluate: + def test_default_action_is_ask(self) -> None: + rule = evaluate("edit", "/foo/bar") + assert rule.action == "ask" + assert rule.permission == "edit" + + def test_last_match_wins(self) -> None: + rs = Ruleset( + rules=[ + Rule("edit", "*", "allow"), + Rule("edit", "/secrets/**", "deny"), + ] + ) + # Second rule (deny) is more specific AND specified later + assert evaluate("edit", "/secrets/x", rs).action == "deny" + # First rule (allow) covers the rest + assert evaluate("edit", "/public/x", rs).action == "allow" + + def test_layered_rulesets_later_overrides_earlier(self) -> None: + defaults = Ruleset(rules=[Rule("edit", "*", "ask")], origin="defaults") + space = Ruleset(rules=[Rule("edit", "*", "allow")], origin="space") + thread = Ruleset(rules=[Rule("edit", "*", "deny")], origin="thread") + # All three layered: thread wins + assert evaluate("edit", "x", defaults, space, thread).action == "deny" + # Without thread: space wins + assert evaluate("edit", "x", defaults, space).action == "allow" + + def test_permission_wildcard(self) -> None: + rs = Ruleset(rules=[Rule("linear_*", "*", "allow")]) + assert evaluate("linear_create_issue", "x", rs).action == "allow" + assert evaluate("notion_create", "x", rs).action == "ask" + + def test_pattern_wildcard(self) -> None: + rs = Ruleset(rules=[Rule("edit", "/documents/secrets/**", "deny")]) + assert evaluate("edit", "/documents/secrets/foo", rs).action == "deny" + assert evaluate("edit", "/documents/public/foo", rs).action == "ask" + + def test_evaluate_many(self) -> None: + rs = Ruleset( + rules=[ + Rule("edit", "*", "allow"), + Rule("edit", "/secrets/*", "deny"), + ] + ) + results = evaluate_many("edit", ["/public/x", "/secrets/y"], rs) + assert [r.action for r in results] == ["allow", "deny"] + + +class TestAggregateAction: + def test_any_deny_means_deny(self) -> None: + rules = [ + Rule("a", "*", "allow"), + Rule("a", "*", "deny"), + Rule("a", "*", "ask"), + ] + assert aggregate_action(rules) == "deny" + + def test_any_ask_means_ask_when_no_deny(self) -> None: + rules = [Rule("a", "*", "allow"), Rule("a", "*", "ask")] + assert aggregate_action(rules) == "ask" + + def test_all_allow_means_allow(self) -> None: + rules = [Rule("a", "*", "allow"), Rule("a", "*", "allow")] + assert aggregate_action(rules) == "allow" + + def test_empty_means_ask(self) -> None: + assert aggregate_action([]) == "ask" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py b/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py new file mode 100644 index 000000000..8d98e1328 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py @@ -0,0 +1,187 @@ +"""Unit tests for the SurfSense plugin entry-point loader (Tier 6).""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from langchain.agents.middleware import AgentMiddleware + +from app.agents.new_chat.plugin_loader import ( + PLUGIN_ENTRY_POINT_GROUP, + PluginContext, + load_allowed_plugin_names_from_env, + load_plugin_middlewares, +) +from app.agents.new_chat.plugins.year_substituter import ( + _YearSubstituterMiddleware, + make_middleware as year_substituter_factory, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _DummyMiddleware(AgentMiddleware): + """Trivial middleware used as the success-path return value.""" + + tools = () + + +def _ctx() -> PluginContext: + return PluginContext.build( + search_space_id=1, + user_id="u", + thread_visibility="PRIVATE", # type: ignore[arg-type] + llm=MagicMock(), + ) + + +class _FakeEntryPoint: + """Stand-in for ``importlib.metadata.EntryPoint``.""" + + def __init__(self, name: str, factory) -> None: + self.name = name + self._factory = factory + + def load(self): + return self._factory + + +# --------------------------------------------------------------------------- +# Loader behaviour +# --------------------------------------------------------------------------- + + +class TestPluginLoaderBasics: + def test_returns_empty_when_allowlist_is_empty(self) -> None: + assert load_plugin_middlewares(_ctx(), allowed_plugin_names=[]) == [] + + def test_skips_non_allowlisted_plugin(self) -> None: + called = [] + + def factory(_): # would be an obvious bug if called + called.append(True) + return _DummyMiddleware() + + ep = _FakeEntryPoint("dangerous_plugin", factory) + with patch( + "app.agents.new_chat.plugin_loader.entry_points", + return_value=[ep], + ): + result = load_plugin_middlewares(_ctx(), allowed_plugin_names=["allowed_only"]) + assert result == [] + assert not called + + def test_loads_allowlisted_plugin(self) -> None: + ep = _FakeEntryPoint("year_substituter", year_substituter_factory) + with patch( + "app.agents.new_chat.plugin_loader.entry_points", + return_value=[ep], + ): + result = load_plugin_middlewares( + _ctx(), allowed_plugin_names={"year_substituter"} + ) + assert len(result) == 1 + assert isinstance(result[0], _YearSubstituterMiddleware) + + +class TestPluginLoaderIsolation: + def test_factory_exception_is_isolated(self) -> None: + def crashing_factory(_): + raise RuntimeError("boom") + + ep = _FakeEntryPoint("buggy", crashing_factory) + with patch( + "app.agents.new_chat.plugin_loader.entry_points", + return_value=[ep], + ): + result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"buggy"}) + assert result == [] # construction continued without the plugin + + def test_non_middleware_return_is_rejected(self) -> None: + def bad_factory(_): + return "not a middleware" + + ep = _FakeEntryPoint("liar", bad_factory) + with patch( + "app.agents.new_chat.plugin_loader.entry_points", + return_value=[ep], + ): + result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"liar"}) + assert result == [] + + def test_load_phase_exception_is_isolated(self) -> None: + class _BrokenEP: + name = "broken" + + def load(self): + raise ImportError("cannot import") + + with patch( + "app.agents.new_chat.plugin_loader.entry_points", + return_value=[_BrokenEP()], + ): + result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"broken"}) + assert result == [] + + def test_one_failure_does_not_block_others(self) -> None: + """Two plugins; one crashes during factory; the other still loads.""" + + def crashing_factory(_): + raise RuntimeError("boom") + + eps = [ + _FakeEntryPoint("crashing", crashing_factory), + _FakeEntryPoint("ok", year_substituter_factory), + ] + with patch( + "app.agents.new_chat.plugin_loader.entry_points", return_value=eps + ): + result = load_plugin_middlewares( + _ctx(), allowed_plugin_names={"crashing", "ok"} + ) + assert len(result) == 1 + assert isinstance(result[0], _YearSubstituterMiddleware) + + +class TestAllowlistEnv: + def test_empty_env_returns_empty_set(self, monkeypatch) -> None: + monkeypatch.delenv("SURFSENSE_ALLOWED_PLUGINS", raising=False) + assert load_allowed_plugin_names_from_env() == set() + + def test_parses_comma_separated_value(self, monkeypatch) -> None: + monkeypatch.setenv( + "SURFSENSE_ALLOWED_PLUGINS", " year_substituter , noisy , " + ) + assert load_allowed_plugin_names_from_env() == { + "year_substituter", + "noisy", + } + + +class TestPluginContext: + def test_build_includes_required_fields(self) -> None: + llm = MagicMock() + ctx = PluginContext.build( + search_space_id=42, + user_id="user-1", + thread_visibility="PRIVATE", # type: ignore[arg-type] + llm=llm, + ) + assert ctx["search_space_id"] == 42 + assert ctx["user_id"] == "user-1" + assert ctx["llm"] is llm + + def test_does_not_carry_secrets_or_db_session(self) -> None: + ctx = _ctx() + # If a future change tries to add these keys, this test will fail loudly. + for forbidden in ("api_key", "secret", "db_session", "session"): + assert forbidden not in ctx + + +class TestEntryPointGroup: + def test_group_name_matches_pyproject_convention(self) -> None: + # Plugins register under `surfsense.plugins`; this is part of our + # public contract for plugin authors. + assert PLUGIN_ENTRY_POINT_GROUP == "surfsense.plugins" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py b/surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py new file mode 100644 index 000000000..39dd9bf00 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py @@ -0,0 +1,107 @@ +"""Tests for RetryAfterMiddleware Retry-After parsing and retry decision logic.""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.middleware.retry_after import ( + RetryAfterMiddleware, + _extract_retry_after_seconds, + _is_non_retryable, +) + +pytestmark = pytest.mark.unit + + +class _FakeResponse: + def __init__(self, headers: dict[str, str]) -> None: + self.headers = headers + + +class _FakeRateLimit(Exception): + def __init__(self, msg: str, headers: dict[str, str] | None = None) -> None: + super().__init__(msg) + if headers is not None: + self.response = _FakeResponse(headers) + + +class TestExtractRetryAfter: + def test_seconds_header(self) -> None: + exc = _FakeRateLimit("rate", {"Retry-After": "30"}) + assert _extract_retry_after_seconds(exc) == 30.0 + + def test_milliseconds_header_overrides_seconds(self) -> None: + exc = _FakeRateLimit("rate", {"retry-after-ms": "1500"}) + assert _extract_retry_after_seconds(exc) == 1.5 + + def test_case_insensitive(self) -> None: + exc = _FakeRateLimit("rate", {"RETRY-AFTER": "12"}) + assert _extract_retry_after_seconds(exc) == 12.0 + + def test_falls_back_to_message_regex(self) -> None: + exc = Exception("Please retry after 7 seconds") + assert _extract_retry_after_seconds(exc) == 7.0 + + def test_returns_none_when_no_hint(self) -> None: + exc = Exception("oops") + assert _extract_retry_after_seconds(exc) is None + + def test_handles_missing_headers_attr(self) -> None: + exc = ValueError("no headers") + assert _extract_retry_after_seconds(exc) is None + + +class TestIsNonRetryable: + @pytest.mark.parametrize( + "name", + ["ContextWindowExceededError", "AuthenticationError", "InvalidRequestError"], + ) + def test_non_retryable_classes(self, name: str) -> None: + cls = type(name, (Exception,), {}) + assert _is_non_retryable(cls("x")) is True + + def test_generic_exception_is_retryable(self) -> None: + assert _is_non_retryable(RuntimeError("transient")) is False + + +class TestDelayCalculation: + def test_takes_max_of_backoff_and_header(self) -> None: + mw = RetryAfterMiddleware(max_retries=3, initial_delay=1.0, jitter=False) + exc = _FakeRateLimit("rl", {"retry-after": "10"}) + delay = mw._delay_for_attempt(0, exc) + assert delay == pytest.approx(10.0) + + def test_uses_backoff_when_no_header(self) -> None: + mw = RetryAfterMiddleware( + max_retries=3, initial_delay=2.0, backoff_factor=2.0, jitter=False + ) + delay = mw._delay_for_attempt(2, RuntimeError("transient")) + # 2 * 2^2 = 8 + assert delay == pytest.approx(8.0) + + def test_caps_at_max_delay(self) -> None: + mw = RetryAfterMiddleware( + max_retries=3, + initial_delay=10.0, + backoff_factor=10.0, + max_delay=15.0, + jitter=False, + ) + delay = mw._delay_for_attempt(5, RuntimeError("x")) + assert delay <= 15.0 + + +class TestShouldRetry: + def test_default_retries_generic(self) -> None: + mw = RetryAfterMiddleware() + assert mw._should_retry(RuntimeError("transient")) is True + + def test_default_skips_non_retryable(self) -> None: + mw = RetryAfterMiddleware() + cls = type("ContextWindowExceededError", (Exception,), {}) + assert mw._should_retry(cls("too big")) is False + + def test_custom_retry_on(self) -> None: + mw = RetryAfterMiddleware(retry_on=lambda exc: isinstance(exc, ValueError)) + assert mw._should_retry(ValueError()) is True + assert mw._should_retry(KeyError()) is False diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_skills_backends.py b/surfsense_backend/tests/unit/agents/new_chat/test_skills_backends.py new file mode 100644 index 000000000..eb9cf396c --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_skills_backends.py @@ -0,0 +1,242 @@ +"""Tests for the skills backends used by SurfSense's SkillsMiddleware.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from app.agents.new_chat.middleware.skills_backends import ( + SKILLS_BUILTIN_PREFIX, + SKILLS_SPACE_PREFIX, + BuiltinSkillsBackend, + SearchSpaceSkillsBackend, + build_skills_backend_factory, + default_skills_sources, +) + + +@pytest.fixture +def skills_root(tmp_path: Path) -> Path: + """Build a small synthetic skill-tree used by the tests.""" + root = tmp_path / "skills" + (root / "alpha").mkdir(parents=True) + (root / "alpha" / "SKILL.md").write_text( + "---\nname: alpha\ndescription: alpha skill\n---\n# Alpha\n" + ) + (root / "beta").mkdir(parents=True) + (root / "beta" / "SKILL.md").write_text( + "---\nname: beta\ndescription: beta skill\n---\n# Beta\n" + ) + (root / "_orphan_file.md").write_text("not a skill, just a stray file") + return root + + +class TestBuiltinSkillsBackendListing: + def test_lists_skill_directories_at_root(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + infos = backend.ls_info("/") + names = {info["path"] for info in infos} + assert "/alpha" in names + assert "/beta" in names + assert "/_orphan_file.md" in names + for info in infos: + if info["path"] in {"/alpha", "/beta"}: + assert info["is_dir"] is True + + def test_lists_skill_md_under_skill_directory(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + infos = backend.ls_info("/alpha") + paths = {info["path"] for info in infos} + assert paths == {"/alpha/SKILL.md"} + assert infos[0]["is_dir"] is False + assert infos[0]["size"] > 0 + + def test_returns_empty_for_missing_path(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + assert backend.ls_info("/nonexistent") == [] + + def test_returns_empty_when_root_missing(self, tmp_path: Path) -> None: + backend = BuiltinSkillsBackend(tmp_path / "definitely-missing") + assert backend.ls_info("/") == [] + assert backend.download_files(["/x/SKILL.md"])[0].error == "file_not_found" + + def test_refuses_path_traversal(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + assert backend.ls_info("/../../../etc") == [] + responses = backend.download_files(["/../../../etc/passwd"]) + assert responses[0].error == "invalid_path" + + +class TestBuiltinSkillsBackendDownload: + def test_downloads_skill_md_content(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + responses = backend.download_files(["/alpha/SKILL.md", "/beta/SKILL.md"]) + assert len(responses) == 2 + assert responses[0].path == "/alpha/SKILL.md" + assert responses[0].content is not None + assert b"name: alpha" in responses[0].content + assert responses[1].error is None + + def test_marks_directory_as_is_directory_error(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + responses = backend.download_files(["/alpha"]) + assert responses[0].error == "is_directory" + + def test_marks_missing_file_as_file_not_found(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + responses = backend.download_files(["/alpha/missing.md"]) + assert responses[0].error == "file_not_found" + assert responses[0].content is None + + def test_response_path_matches_input_for_correlation( + self, skills_root: Path + ) -> None: + backend = BuiltinSkillsBackend(skills_root) + inputs = ["/alpha/SKILL.md", "/missing.md", "/beta/SKILL.md"] + responses = backend.download_files(inputs) + assert [r.path for r in responses] == inputs + + +class TestBuiltinSkillsBackendIntegration: + """Mirror the call sequence the SkillsMiddleware actually uses.""" + + def test_skills_middleware_call_pattern(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + + infos = asyncio.run(backend.als_info("/")) + skill_dirs = [i["path"] for i in infos if i.get("is_dir")] + assert sorted(skill_dirs) == ["/alpha", "/beta"] + + skill_md_paths = [f"{p}/SKILL.md" for p in skill_dirs] + responses = asyncio.run(backend.adownload_files(skill_md_paths)) + assert all(r.error is None for r in responses) + assert all(r.content is not None for r in responses) + + +class TestBundledSkills: + def test_default_root_resolves_to_repo_skills_dir(self) -> None: + backend = BuiltinSkillsBackend() + assert backend.root.name == "builtin" + assert backend.root.parent.name == "skills" + + def test_bundled_starter_skills_are_present(self) -> None: + backend = BuiltinSkillsBackend() + infos = backend.ls_info("/") + names = {info["path"].lstrip("/") for info in infos if info.get("is_dir")} + # Five starter skills required by the Tier 4 plan. + for required in ( + "kb-research", + "report-writing", + "meeting-prep", + "slack-summary", + "email-drafting", + ): + assert required in names, f"missing starter skill: {required}" + + def test_each_starter_skill_has_valid_skill_md(self) -> None: + backend = BuiltinSkillsBackend() + infos = backend.ls_info("/") + skill_dirs = [info["path"] for info in infos if info.get("is_dir")] + for skill_dir in skill_dirs: + md_path = f"{skill_dir}/SKILL.md" + response = backend.download_files([md_path])[0] + assert response.error is None, f"missing SKILL.md in {skill_dir}" + content = response.content.decode("utf-8").replace("\r\n", "\n") + assert content.startswith("---\n"), f"missing frontmatter in {skill_dir}" + assert "\nname:" in content + assert "\ndescription:" in content + + +class _FakeKBBackend: + """Stand-in for :class:`KBPostgresBackend` with the two methods we need.""" + + def __init__(self, listing: list[dict], file_contents: dict[str, bytes]) -> None: + self._listing = listing + self._file_contents = file_contents + self.last_ls_path: str | None = None + self.last_download_paths: list[str] | None = None + + async def als_info(self, path: str): + self.last_ls_path = path + return self._listing + + async def adownload_files(self, paths): + from deepagents.backends.protocol import FileDownloadResponse + + self.last_download_paths = list(paths) + out: list[FileDownloadResponse] = [] + for p in paths: + content = self._file_contents.get(p) + if content is None: + out.append(FileDownloadResponse(path=p, error="file_not_found")) + else: + out.append(FileDownloadResponse(path=p, content=content)) + return out + + +class TestSearchSpaceSkillsBackend: + def test_remaps_paths_when_listing(self) -> None: + listing = [ + {"path": "/documents/_skills/policy", "is_dir": True}, + {"path": "/documents/_skills/policy/SKILL.md", "is_dir": False}, + {"path": "/documents/other-folder/x.md", "is_dir": False}, + ] + kb = _FakeKBBackend(listing=listing, file_contents={}) + backend = SearchSpaceSkillsBackend(kb) + infos = asyncio.run(backend.als_info("/")) + assert kb.last_ls_path == "/documents/_skills" + paths = [info["path"] for info in infos] + assert "/policy" in paths + assert "/policy/SKILL.md" in paths + # Unrelated KB documents must NOT leak into the skills namespace. + assert all(not p.startswith("/documents") for p in paths) + + def test_remaps_paths_when_downloading(self) -> None: + kb = _FakeKBBackend( + listing=[], + file_contents={ + "/documents/_skills/policy/SKILL.md": b"---\nname: policy\n---\n", + }, + ) + backend = SearchSpaceSkillsBackend(kb) + responses = asyncio.run(backend.adownload_files(["/policy/SKILL.md"])) + assert kb.last_download_paths == ["/documents/_skills/policy/SKILL.md"] + assert responses[0].path == "/policy/SKILL.md" + assert responses[0].error is None + assert responses[0].content is not None + + def test_sync_methods_raise_not_implemented(self) -> None: + backend = SearchSpaceSkillsBackend(_FakeKBBackend([], {})) + with pytest.raises(NotImplementedError): + backend.ls_info("/") + with pytest.raises(NotImplementedError): + backend.download_files(["/x"]) + + def test_custom_kb_root_is_honored(self) -> None: + kb = _FakeKBBackend( + listing=[ + {"path": "/skills_admin/x", "is_dir": True}, + ], + file_contents={}, + ) + backend = SearchSpaceSkillsBackend(kb, kb_root="/skills_admin") + infos = asyncio.run(backend.als_info("/")) + assert kb.last_ls_path == "/skills_admin" + assert infos[0]["path"] == "/x" + + +class TestBackendFactory: + def test_builtin_only_factory_returns_composite(self) -> None: + factory = build_skills_backend_factory() + backend = factory(runtime=None) # type: ignore[arg-type] + from deepagents.backends.composite import CompositeBackend + + assert isinstance(backend, CompositeBackend) + assert SKILLS_BUILTIN_PREFIX in backend.routes + assert SKILLS_SPACE_PREFIX not in backend.routes + + def test_default_skills_sources_lists_builtin_then_space(self) -> None: + sources = default_skills_sources() + assert sources == [SKILLS_BUILTIN_PREFIX, SKILLS_SPACE_PREFIX] diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py b/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py new file mode 100644 index 000000000..3819b4605 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py @@ -0,0 +1,338 @@ +"""Tests for the specialized subagents (explore / report_writer / connector_negotiator).""" + +from __future__ import annotations + +from langchain_core.tools import tool + +from app.agents.new_chat.middleware.permission import PermissionMiddleware +from app.agents.new_chat.subagents import ( + build_connector_negotiator_subagent, + build_explore_subagent, + build_report_writer_subagent, + build_specialized_subagents, +) +from app.agents.new_chat.subagents.config import ( + EXPLORE_READ_TOOLS, + REPORT_WRITER_TOOLS, + WRITE_TOOL_DENY_PATTERNS, +) + +# --------------------------------------------------------------------------- +# Fake tools used to verify filtering & permission behavior +# --------------------------------------------------------------------------- + + +@tool +def search_surfsense_docs(query: str) -> str: + """Search the user's KB.""" + return "" + + +@tool +def web_search(query: str) -> str: + """Search the public web.""" + return "" + + +@tool +def scrape_webpage(url: str) -> str: + """Scrape a single webpage.""" + return "" + + +@tool +def read_file(path: str) -> str: + """Read a file.""" + return "" + + +@tool +def ls_tree(path: str) -> str: + """List a tree.""" + return "" + + +@tool +def grep(pattern: str) -> str: + """Grep.""" + return "" + + +@tool +def update_memory(content: str) -> str: + """Update the user's memory.""" + return "" + + +@tool +def edit_file(path: str, old: str, new: str) -> str: + """Edit a file.""" + return "" + + +@tool +def linear_create_issue(title: str) -> str: + """Create a Linear issue.""" + return "" + + +@tool +def slack_send_message(channel: str, text: str) -> str: + """Send a Slack message.""" + return "" + + +@tool +def get_connected_accounts() -> str: + """List connected accounts.""" + return "" + + +@tool +def generate_report(topic: str) -> str: + """Generate a report artifact.""" + return "" + + +ALL_TOOLS = [ + search_surfsense_docs, + web_search, + scrape_webpage, + read_file, + ls_tree, + grep, + update_memory, + edit_file, + linear_create_issue, + slack_send_message, + get_connected_accounts, + generate_report, +] + + +class TestExploreSubagent: + def test_only_read_tools_are_exposed(self) -> None: + spec = build_explore_subagent(tools=ALL_TOOLS) + names = {t.name for t in spec["tools"]} # type: ignore[index] + assert names == EXPLORE_READ_TOOLS & {t.name for t in ALL_TOOLS} + assert "update_memory" not in names + assert "linear_create_issue" not in names + assert "edit_file" not in names + + def test_includes_permission_middleware_with_deny_rules(self) -> None: + spec = build_explore_subagent(tools=ALL_TOOLS) + permission_mws = [ + m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index] + ] + assert len(permission_mws) == 1 + ruleset = permission_mws[0]._static_rulesets[0] + assert ruleset.origin == "subagent_explore" + deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"} + assert "update_memory" in deny_patterns + assert "edit_file" in deny_patterns + assert "*create*" in deny_patterns + assert "*send*" in deny_patterns + + def test_skills_inherits_default_sources(self) -> None: + spec = build_explore_subagent(tools=ALL_TOOLS) + assert spec["skills"] == ["/skills/builtin/", "/skills/space/"] # type: ignore[index] + + def test_name_and_description_match_contract(self) -> None: + spec = build_explore_subagent(tools=ALL_TOOLS) + assert spec["name"] == "explore" + assert "read-only" in spec["description"].lower() + + def test_includes_dedup_and_patch_middleware(self) -> None: + from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware + + from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware + + spec = build_explore_subagent(tools=ALL_TOOLS) + types = {type(m) for m in spec["middleware"]} # type: ignore[index] + assert PatchToolCallsMiddleware in types + assert DedupHITLToolCallsMiddleware in types + + +class TestReportWriterSubagent: + def test_exposes_only_report_writing_tools(self) -> None: + spec = build_report_writer_subagent(tools=ALL_TOOLS) + names = {t.name for t in spec["tools"]} # type: ignore[index] + assert names == REPORT_WRITER_TOOLS & {t.name for t in ALL_TOOLS} + assert "generate_report" in names + assert "search_surfsense_docs" in names + + def test_deny_rules_block_writes_but_allow_generate_report(self) -> None: + spec = build_report_writer_subagent(tools=ALL_TOOLS) + permission_mws = [ + m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index] + ] + ruleset = permission_mws[0]._static_rulesets[0] + deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"} + assert "update_memory" in deny_patterns + # generate_report MUST not be denied — it's the whole point of the subagent. + assert "generate_report" not in deny_patterns + # No deny pattern should match `generate_report` either. + assert all( + not _wildcard_matches(pattern, "generate_report") + for pattern in deny_patterns + ) + + +class TestConnectorNegotiatorSubagent: + def test_inherits_all_parent_tools(self) -> None: + spec = build_connector_negotiator_subagent(tools=ALL_TOOLS) + names = {t.name for t in spec["tools"]} # type: ignore[index] + # Every parent tool is inherited; the deny ruleset enforces behavior + # at execution time instead of trimming the tool list. + assert names == {t.name for t in ALL_TOOLS} + + def test_get_connected_accounts_is_present(self) -> None: + spec = build_connector_negotiator_subagent(tools=ALL_TOOLS) + names = {t.name for t in spec["tools"]} # type: ignore[index] + assert "get_connected_accounts" in names + + def test_deny_ruleset_blocks_mutating_connector_tools(self) -> None: + spec = build_connector_negotiator_subagent(tools=ALL_TOOLS) + permission_mws = [ + m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index] + ] + ruleset = permission_mws[0]._static_rulesets[0] + deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"} + # `linear_create_issue` matches the `*_create` deny pattern. + assert any( + _wildcard_matches(p, "linear_create_issue") for p in deny_patterns + ) + assert any( + _wildcard_matches(p, "slack_send_message") for p in deny_patterns + ) + + +class TestBuildSpecializedSubagents: + def test_returns_three_specs(self) -> None: + specs = build_specialized_subagents(tools=ALL_TOOLS) + names = [s["name"] for s in specs] # type: ignore[index] + assert names == ["explore", "report_writer", "connector_negotiator"] + + def test_all_specs_have_unique_names(self) -> None: + specs = build_specialized_subagents(tools=ALL_TOOLS) + names = [s["name"] for s in specs] # type: ignore[index] + assert len(set(names)) == len(names) + + def test_extra_middleware_is_prepended_to_each_spec(self) -> None: + """Sentinel middleware passed via ``extra_middleware`` must appear + in each subagent's ``middleware`` list, before the local rules. + + This guards against the regression where specialized subagents + promised filesystem tools (``read_file``, ``ls``, ``grep``) in + their system prompts but had no filesystem middleware mounted. + """ + + class _Sentinel: + pass + + sentinel = _Sentinel() + specs = build_specialized_subagents( + tools=ALL_TOOLS, extra_middleware=[sentinel] + ) + for spec in specs: + mws = spec["middleware"] # type: ignore[index] + assert sentinel in mws + # The sentinel must appear *before* the permission middleware + # (subagent-local rules), preserving the documented composition + # order: extra → custom → patch → dedup. + sentinel_idx = mws.index(sentinel) + perm_idx = next( + (i for i, m in enumerate(mws) + if isinstance(m, PermissionMiddleware)), + None, + ) + assert perm_idx is not None + assert sentinel_idx < perm_idx + + +class TestFilterToolsWarningSuppression: + """Names provided by middleware (read_file, ls, grep, …) must not + trigger the spurious "missing" warning in :func:`_filter_tools`.""" + + def test_middleware_provided_names_are_silent(self, caplog) -> None: + import logging + + from app.agents.new_chat.subagents.config import _filter_tools + + with caplog.at_level(logging.INFO, logger="app.agents.new_chat.subagents.config"): + # Allowed set asks for two registry tools (one present, one + # not) plus a bunch of middleware-provided names. + _filter_tools( + [search_surfsense_docs], + allowed_names={ + "search_surfsense_docs", + "scrape_webpage", # legitimately missing → should warn + "read_file", # mw-provided → suppressed + "ls", + "grep", + "glob", + "write_todos", + }, + ) + + warnings = [ + r.message for r in caplog.records if r.levelno >= logging.INFO + ] + # Exactly one warning, and it should mention scrape_webpage but not + # any middleware-provided name. Inspect the rendered "missing" + # list (between the brackets) so we don't false-match substrings + # like ``ls`` inside ``available``. + assert len(warnings) == 1, warnings + msg = warnings[0] + assert "scrape_webpage" in msg + bracket_section = msg.split("missing: ", 1)[1] + for noisy in ("read_file", "ls", "grep", "glob", "write_todos"): + assert f"'{noisy}'" not in bracket_section, msg + + +class TestDenyPatternsCoverage: + def test_deny_patterns_cover_canonical_write_tools(self) -> None: + canonical_writes = [ + "update_memory", + "edit_file", + "write_file", + "move_file", + "mkdir", + "linear_create_issue", + "linear_update_issue", + "linear_delete_issue", + "slack_send_message", + "create_index", + "update_account", + "delete_record", + "send_email", + ] + for tool_name in canonical_writes: + assert any( + _wildcard_matches(pattern, tool_name) + for pattern in WRITE_TOOL_DENY_PATTERNS + ), f"no deny pattern matches {tool_name!r}" + + def test_deny_patterns_do_not_match_safe_read_tools(self) -> None: + canonical_reads = [ + "search_surfsense_docs", + "read_file", + "ls_tree", + "grep", + "web_search", + "scrape_webpage", + "get_connected_accounts", + "generate_report", + ] + for tool_name in canonical_reads: + assert not any( + _wildcard_matches(pattern, tool_name) + for pattern in WRITE_TOOL_DENY_PATTERNS + ), f"deny pattern incorrectly matches read tool {tool_name!r}" + + +def _wildcard_matches(pattern: str, value: str) -> bool: + """Helper using the same matcher the rule evaluator does.""" + from app.agents.new_chat.permissions import wildcard_match + + return wildcard_match(value, pattern) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py b/surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py new file mode 100644 index 000000000..f792aef60 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py @@ -0,0 +1,103 @@ +"""Tests for ToolCallNameRepairMiddleware.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage + +from app.agents.new_chat.middleware.tool_call_repair import ( + ToolCallNameRepairMiddleware, +) +from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME + +pytestmark = pytest.mark.unit + + +def _make_state(message: AIMessage) -> dict: + return {"messages": [message]} + + +class _FakeRuntime: + def __init__(self, context: object | None = None) -> None: + self.context = context + + +class TestRepair: + def test_passthrough_when_name_matches(self) -> None: + mw = ToolCallNameRepairMiddleware( + registered_tool_names={"echo"}, fuzzy_match_threshold=None + ) + msg = AIMessage(content="", tool_calls=[ + {"name": "echo", "args": {}, "id": "1"}, + ]) + out = mw.after_model(_make_state(msg), _FakeRuntime()) + assert out is None # no change + + def test_lowercase_repair(self) -> None: + mw = ToolCallNameRepairMiddleware( + registered_tool_names={"echo"}, fuzzy_match_threshold=None + ) + msg = AIMessage(content="", tool_calls=[ + {"name": "Echo", "args": {"x": 1}, "id": "1"}, + ]) + out = mw.after_model(_make_state(msg), _FakeRuntime()) + assert out is not None + repaired = out["messages"][0] + assert repaired.tool_calls[0]["name"] == "echo" + + def test_invalid_fallback_when_no_match(self) -> None: + mw = ToolCallNameRepairMiddleware( + registered_tool_names={"echo", INVALID_TOOL_NAME}, + fuzzy_match_threshold=None, + ) + msg = AIMessage(content="", tool_calls=[ + {"name": "totally_different_name", "args": {"k": "v"}, "id": "1"}, + ]) + out = mw.after_model(_make_state(msg), _FakeRuntime()) + assert out is not None + repaired_call = out["messages"][0].tool_calls[0] + assert repaired_call["name"] == INVALID_TOOL_NAME + assert repaired_call["args"]["tool"] == "totally_different_name" + assert "totally_different_name" in repaired_call["args"]["error"] + + def test_no_invalid_means_skip_when_unknown(self) -> None: + mw = ToolCallNameRepairMiddleware( + registered_tool_names={"echo"}, fuzzy_match_threshold=None + ) + msg = AIMessage(content="", tool_calls=[ + {"name": "unknown", "args": {}, "id": "1"}, + ]) + out = mw.after_model(_make_state(msg), _FakeRuntime()) + # No repair available; original returned unchanged (no update) + assert out is None + + def test_fuzzy_match_works_when_enabled(self) -> None: + mw = ToolCallNameRepairMiddleware( + registered_tool_names={"search_documents"}, + fuzzy_match_threshold=0.7, + ) + msg = AIMessage(content="", tool_calls=[ + {"name": "search_docments", "args": {}, "id": "1"}, + ]) + out = mw.after_model(_make_state(msg), _FakeRuntime()) + assert out is not None + assert out["messages"][0].tool_calls[0]["name"] == "search_documents" + + def test_skips_when_no_messages(self) -> None: + mw = ToolCallNameRepairMiddleware(registered_tool_names={"echo"}) + out = mw.after_model({"messages": []}, _FakeRuntime()) + assert out is None + + def test_runtime_context_extends_registered(self) -> None: + from types import SimpleNamespace + + mw = ToolCallNameRepairMiddleware( + registered_tool_names={"echo"}, fuzzy_match_threshold=None + ) + msg = AIMessage(content="", tool_calls=[ + {"name": "DynamicTool", "args": {}, "id": "1"}, + ]) + runtime = _FakeRuntime(SimpleNamespace(registered_tool_names=["dynamictool"])) + out = mw.after_model(_make_state(msg), runtime) + assert out is not None + assert out["messages"][0].tool_calls[0]["name"] == "dynamictool" diff --git a/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py b/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py index add0105e4..467ba6d5f 100644 --- a/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py +++ b/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py @@ -1,8 +1,10 @@ import pytest from langchain_core.messages import AIMessage +from langchain_core.tools import StructuredTool from app.agents.new_chat.middleware.dedup_tool_calls import ( DedupHITLToolCallsMiddleware, + wrap_dedup_key_by_arg_name, ) pytestmark = pytest.mark.unit @@ -14,9 +16,34 @@ def _make_state(tool_calls: list[dict]) -> dict: return {"messages": [msg]} +def _hitl_tool(name: str, *, dedup_arg: str) -> StructuredTool: + """Build a tool with declarative ``dedup_key`` metadata. + + Mirrors the ``ToolDefinition.dedup_key`` -> ``tool.metadata["dedup_key"]`` + propagation done by :func:`build_tools` after the cleanup tier. + """ + + def _fn(**kwargs): + return "ok" + + return StructuredTool.from_function( + func=_fn, + name=name, + description="x", + metadata={"dedup_key": wrap_dedup_key_by_arg_name(dedup_arg)}, + ) + + def test_duplicate_hitl_calls_reduced_to_first(): - """When the LLM emits the same HITL tool call twice, only the first is kept.""" - mw = DedupHITLToolCallsMiddleware() + """When the LLM emits the same HITL tool call twice, only the first is kept. + + After the cleanup tier removed ``_NATIVE_HITL_TOOL_DEDUP_KEYS``, the + resolver is sourced from ``ToolDefinition.dedup_key`` propagated onto + ``tool.metadata`` — which the registry does at agent build time. The + test mirrors that wiring with an in-memory tool. + """ + tool = _hitl_tool("delete_calendar_event", dedup_arg="event_title_or_id") + mw = DedupHITLToolCallsMiddleware(agent_tools=[tool]) state = _make_state( [ diff --git a/surfsense_backend/tests/unit/observability/__init__.py b/surfsense_backend/tests/unit/observability/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/tests/unit/observability/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/tests/unit/observability/test_otel.py b/surfsense_backend/tests/unit/observability/test_otel.py new file mode 100644 index 000000000..583142098 --- /dev/null +++ b/surfsense_backend/tests/unit/observability/test_otel.py @@ -0,0 +1,84 @@ +"""Tests for the SurfSense OpenTelemetry shim (Tier 3b).""" + +from __future__ import annotations + +import pytest + +from app.observability import otel + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def _reset_otel_state(monkeypatch: pytest.MonkeyPatch): + """Force a clean OTel disabled state per test, then restore after.""" + for env in ("OTEL_EXPORTER_OTLP_ENDPOINT", "SURFSENSE_DISABLE_OTEL"): + monkeypatch.delenv(env, raising=False) + monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true") + otel.reload_for_tests() + yield + otel.reload_for_tests() + + +def test_disabled_by_default_when_no_endpoint() -> None: + assert otel.is_enabled() is False + + +def test_enabled_when_endpoint_configured(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + assert otel.reload_for_tests() is True + + +def test_kill_switch_overrides_endpoint(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true") + assert otel.reload_for_tests() is False + + +class TestNoopSpansWhenDisabled: + def test_generic_span_yields_noop(self) -> None: + with otel.span("any.thing", attributes={"x": 1}) as sp: + sp.set_attribute("y", 2) + sp.set_attributes({"a": "b"}) + sp.add_event("evt") + sp.record_exception(RuntimeError("ignored")) + sp.set_status("ignored") + # Reaching here without raising means the no-op is well-formed + + def test_exception_propagates_through_span(self) -> None: + with pytest.raises(ValueError), otel.span("err"): + raise ValueError("boom") + + def test_each_helper_is_a_no_op_when_disabled(self) -> None: + helpers = [ + otel.tool_call_span("write_file", input_size=42), + otel.model_call_span(model_id="openai:gpt-4o", provider="openai"), + otel.kb_search_span(search_space_id=1, query_chars=99), + otel.kb_persist_span(document_type="NOTE", document_id=7), + otel.compaction_span(reason="overflow", messages_in=120), + otel.interrupt_span(interrupt_type="permission_ask"), + otel.permission_asked_span(permission="edit", pattern="/x/**"), + ] + for cm in helpers: + with cm as sp: + assert sp is not None + sp.set_attribute("ok", True) + + +class TestEnabledIntegration: + """When OTel is wired but no SDK exporter is bound, the API still works.""" + + def test_span_attaches_attributes(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Use the API tracer (no-op-ish but real Span objects). + monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + assert otel.reload_for_tests() is True + + # Should not raise even when set_attributes/record_exception fall through + # to an SDK that isn't actually installed. + with otel.tool_call_span("scrape_webpage", input_size=10) as sp: + sp.set_attribute("tool.output.size", 200) + sp.set_attribute("tool.truncated", False) + with otel.model_call_span(model_id="m", provider="p") as sp: + sp.set_attribute("retry.count", 3) diff --git a/surfsense_backend/tests/unit/services/test_revert_service.py b/surfsense_backend/tests/unit/services/test_revert_service.py new file mode 100644 index 000000000..cb8443291 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_revert_service.py @@ -0,0 +1,56 @@ +"""Unit tests for the agent revert service (Tier 5.3).""" + +from __future__ import annotations + +from typing import Any + +from app.services.revert_service import can_revert + + +class _FakeAction: + def __init__(self, *, user_id: Any, tool_name: str = "edit_file") -> None: + self.user_id = user_id + self.tool_name = tool_name + + +class TestCanRevert: + def test_owner_can_revert_their_own_action(self) -> None: + action = _FakeAction(user_id="user-123") + assert can_revert( + requester_user_id="user-123", action=action, is_admin=False + ) + + def test_other_user_cannot_revert(self) -> None: + action = _FakeAction(user_id="user-123") + assert not can_revert( + requester_user_id="someone-else", action=action, is_admin=False + ) + + def test_admin_always_allowed(self) -> None: + action = _FakeAction(user_id="user-123") + assert can_revert( + requester_user_id="anybody", action=action, is_admin=True + ) + + def test_admin_can_revert_anonymous_action(self) -> None: + action = _FakeAction(user_id=None) + assert can_revert( + requester_user_id="admin", action=action, is_admin=True + ) + + def test_anonymous_action_blocks_non_admin(self) -> None: + action = _FakeAction(user_id=None) + assert not can_revert( + requester_user_id="user-1", action=action, is_admin=False + ) + + def test_uuid_string_normalization(self) -> None: + """``user_id`` may be a UUID object; comparison should still work.""" + import uuid + + u = uuid.uuid4() + action = _FakeAction(user_id=u) + # Same UUID, passed as string from the requesting side. + assert can_revert( + requester_user_id=str(u), action=action, is_admin=False + ) diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentPermissionsContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentPermissionsContent.tsx new file mode 100644 index 000000000..b01f556ad --- /dev/null +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentPermissionsContent.tsx @@ -0,0 +1,451 @@ +"use client"; + +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { useAtomValue } from "jotai"; +import { AlertTriangle, Check, Plus, ShieldCheck, Trash2, X } from "lucide-react"; +import { useCallback, useMemo, useState } from "react"; +import { toast } from "sonner"; +import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; +import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Spinner } from "@/components/ui/spinner"; +import { + type AgentPermissionAction, + type AgentPermissionRule, + type AgentPermissionRuleCreate, + agentPermissionsApiService, +} from "@/lib/apis/agent-permissions-api.service"; +import { AppError } from "@/lib/error"; +import { formatRelativeDate } from "@/lib/format-date"; +import { cn } from "@/lib/utils"; + +const ACTION_DESCRIPTIONS: Record = { + allow: "Always run without prompting", + deny: "Block silently", + ask: "Pause and ask for approval", +}; + +const ACTION_BADGE: Record = { + allow: { label: "Allow", className: "bg-emerald-500/10 text-emerald-600 border-emerald-500/30" }, + deny: { label: "Deny", className: "bg-destructive/10 text-destructive border-destructive/30" }, + ask: { label: "Ask", className: "bg-amber-500/10 text-amber-600 border-amber-500/30" }, +}; + +const EMPTY_FORM: AgentPermissionRuleCreate = { + permission: "", + pattern: "*", + action: "ask", + user_id: null, + thread_id: null, +}; + +function permissionRulesQueryKey(searchSpaceId: number) { + return ["agent-permission-rules", searchSpaceId] as const; +} + +function ScopeBadge({ rule }: { rule: AgentPermissionRule }) { + if (rule.thread_id !== null) { + return ( + + Thread #{rule.thread_id} + + ); + } + if (rule.user_id !== null) { + return ( + + User-specific + + ); + } + return ( + + Search space + + ); +} + +export function AgentPermissionsContent() { + const searchSpaceIdRaw = useAtomValue(activeSearchSpaceIdAtom); + const searchSpaceId = searchSpaceIdRaw ? Number(searchSpaceIdRaw) : null; + + const { data: flags } = useAtomValue(agentFlagsAtom); + const featureEnabled = !!flags?.enable_permission && !flags?.disable_new_agent_stack; + + const queryClient = useQueryClient(); + + const { + data: rules, + isLoading, + isError, + error, + } = useQuery({ + queryKey: searchSpaceId + ? permissionRulesQueryKey(searchSpaceId) + : ["agent-permission-rules", "none"], + queryFn: () => agentPermissionsApiService.list(searchSpaceId as number), + enabled: !!searchSpaceId && featureEnabled, + staleTime: 60 * 1000, + }); + + const createMutation = useMutation({ + mutationFn: (payload: AgentPermissionRuleCreate) => + agentPermissionsApiService.create(searchSpaceId as number, payload), + onSuccess: () => { + toast.success("Rule created."); + queryClient.invalidateQueries({ + queryKey: permissionRulesQueryKey(searchSpaceId as number), + }); + }, + onError: (err: unknown) => { + toast.error(err instanceof Error ? err.message : "Failed to create rule."); + }, + }); + + const updateMutation = useMutation({ + mutationFn: (params: { ruleId: number; action: AgentPermissionAction; pattern?: string }) => + agentPermissionsApiService.update(searchSpaceId as number, params.ruleId, { + action: params.action, + pattern: params.pattern, + }), + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: permissionRulesQueryKey(searchSpaceId as number), + }); + }, + onError: (err: unknown) => { + toast.error(err instanceof Error ? err.message : "Failed to update rule."); + }, + }); + + const deleteMutation = useMutation({ + mutationFn: (ruleId: number) => + agentPermissionsApiService.remove(searchSpaceId as number, ruleId), + onSuccess: () => { + toast.success("Rule deleted."); + queryClient.invalidateQueries({ + queryKey: permissionRulesQueryKey(searchSpaceId as number), + }); + }, + onError: (err: unknown) => { + toast.error(err instanceof Error ? err.message : "Failed to delete rule."); + }, + }); + + const [showForm, setShowForm] = useState(false); + const [formData, setFormData] = useState(EMPTY_FORM); + const [deleteTarget, setDeleteTarget] = useState(null); + + const sortedRules = useMemo(() => rules ?? [], [rules]); + + const handleCreate = useCallback(async () => { + if (!formData.permission.trim()) { + toast.error("Permission is required."); + return; + } + try { + await createMutation.mutateAsync({ + ...formData, + permission: formData.permission.trim(), + pattern: formData.pattern.trim() || "*", + }); + setShowForm(false); + setFormData(EMPTY_FORM); + } catch (err) { + if (err instanceof AppError && err.message) { + // already toasted by onError + } + } + }, [createMutation, formData]); + + const handleConfirmDelete = useCallback(async () => { + if (deleteTarget === null) return; + try { + await deleteMutation.mutateAsync(deleteTarget); + } finally { + setDeleteTarget(null); + } + }, [deleteMutation, deleteTarget]); + + if (!featureEnabled) { + return ( + + + Permission middleware is disabled + + Flip{" "} + SURFSENSE_ENABLE_PERMISSION on + the backend to manage allow/deny/ask rules from this panel. + + + ); + } + + if (!searchSpaceId) { + return ( +

Open a search space to manage agent rules.

+ ); + } + + if (isLoading) { + return ( +
+ +
+ ); + } + + if (isError) { + return ( +
+ +

Failed to load rules

+

+ {error instanceof Error ? error.message : "Unknown error."} +

+
+ ); + } + + return ( +
+
+
+

+ Tell the agent which tools to allow, deny, or ask before running. Rules use wildcard + patterns and are evaluated at the most specific scope first. +

+
+ {!showForm && ( + + )} +
+ + {showForm && ( +
+
+

New permission rule

+ +
+
+ + setFormData((p) => ({ ...p, permission: e.target.value }))} + /> +

+ Match a tool capability. Use * for wildcards. +

+
+ +
+ + setFormData((p) => ({ ...p, pattern: e.target.value }))} + /> +

+ Wildcard against the canonical argument (e.g. prod-*). +

+
+
+ +
+ + +

+ {ACTION_DESCRIPTIONS[formData.action]} +

+
+ +
+ + +
+
+
+ )} + + {sortedRules.length === 0 && !showForm && ( +
+ +

No rules yet

+

+ Without rules the agent uses the deployment default for every tool. +

+
+ )} + + {sortedRules.length > 0 && ( +
+ {sortedRules.map((rule) => { + const badge = ACTION_BADGE[rule.action]; + const isUpdating = + updateMutation.isPending && updateMutation.variables?.ruleId === rule.id; + const isDeleting = deleteMutation.isPending && deleteMutation.variables === rule.id; + + return ( +
+
+
+
+ + {rule.permission} + + {rule.pattern !== "*" && ( + + → {rule.pattern} + + )} + +
+

+ Created {formatRelativeDate(rule.created_at)} +

+
+ +
+ + + +
+
+
+ ); + })} +
+ )} + + !open && setDeleteTarget(null)} + > + + + Delete this rule? + + The agent will fall back to deployment defaults for matching tool calls. + + + + Cancel + { + e.preventDefault(); + handleConfirmDelete(); + }} + disabled={deleteMutation.isPending} + > + {deleteMutation.isPending ? "Deleting…" : "Delete"} + + + + +
+ ); +} diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx new file mode 100644 index 000000000..bd8f03a70 --- /dev/null +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx @@ -0,0 +1,309 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { CircleCheck, CircleSlash, Cog, RotateCcw } from "lucide-react"; +import { useMemo } from "react"; +import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { Badge } from "@/components/ui/badge"; +import { Separator } from "@/components/ui/separator"; +import { Skeleton } from "@/components/ui/skeleton"; +import type { AgentFeatureFlags } from "@/lib/apis/agent-flags-api.service"; +import { cn } from "@/lib/utils"; + +type FlagKey = keyof AgentFeatureFlags; + +interface FlagDef { + key: FlagKey; + label: string; + description: string; + envVar: string; +} + +interface FlagGroup { + id: string; + title: string; + subtitle: string; + flags: FlagDef[]; +} + +const FLAG_GROUPS: FlagGroup[] = [ + { + id: "tier1", + title: "Tier 1 — Agent quality", + subtitle: "Context editing, retries, fallbacks, doom-loop, tool-call repair.", + flags: [ + { + key: "enable_context_editing", + label: "Context editing", + description: "Trim tool outputs and spill old text into backend storage.", + envVar: "SURFSENSE_ENABLE_CONTEXT_EDITING", + }, + { + key: "enable_compaction_v2", + label: "Compaction v2", + description: "SurfSense-aware compaction replacing safe summarization.", + envVar: "SURFSENSE_ENABLE_COMPACTION_V2", + }, + { + key: "enable_retry_after", + label: "Retry-After", + description: "Honour rate-limit retry-after headers automatically.", + envVar: "SURFSENSE_ENABLE_RETRY_AFTER", + }, + { + key: "enable_model_fallback", + label: "Model fallback", + description: "Fail over to a backup model on persistent errors.", + envVar: "SURFSENSE_ENABLE_MODEL_FALLBACK", + }, + { + key: "enable_model_call_limit", + label: "Model call limit", + description: "Cap total model calls per turn to prevent budget run-aways.", + envVar: "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", + }, + { + key: "enable_tool_call_limit", + label: "Tool call limit", + description: "Cap total tool calls per turn.", + envVar: "SURFSENSE_ENABLE_TOOL_CALL_LIMIT", + }, + { + key: "enable_tool_call_repair", + label: "Tool-call name repair", + description: "Recover from lower-cased / fuzzy tool names emitted by smaller models.", + envVar: "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", + }, + { + key: "enable_doom_loop", + label: "Doom-loop detection", + description: "Detect repeated identical tool calls and ask the user to confirm.", + envVar: "SURFSENSE_ENABLE_DOOM_LOOP", + }, + ], + }, + { + id: "tier2", + title: "Tier 2 — Safety", + subtitle: "Permission rules, busy-mutex, smarter tool selection.", + flags: [ + { + key: "enable_permission", + label: "Permission middleware", + description: "Apply allow/deny/ask rules from the Agent Permissions tab.", + envVar: "SURFSENSE_ENABLE_PERMISSION", + }, + { + key: "enable_busy_mutex", + label: "Busy mutex", + description: "Prevent two concurrent runs from corrupting the same thread.", + envVar: "SURFSENSE_ENABLE_BUSY_MUTEX", + }, + { + key: "enable_llm_tool_selector", + label: "LLM tool selector", + description: "Use a smaller model to pre-filter the tool list per turn.", + envVar: "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", + }, + ], + }, + { + id: "tier4", + title: "Tier 4 — Skills + subagents", + subtitle: "Built-in skills, specialized subagents, KB planner runnable.", + flags: [ + { + key: "enable_skills", + label: "Skills", + description: "Load on-demand skill packs (kb-research, report-writing, …).", + envVar: "SURFSENSE_ENABLE_SKILLS", + }, + { + key: "enable_specialized_subagents", + label: "Specialized subagents", + description: "Spin up explore / report_writer / connector_negotiator subagents.", + envVar: "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", + }, + { + key: "enable_kb_planner_runnable", + label: "KB planner runnable", + description: "Compile a private planner sub-agent for KB search.", + envVar: "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", + }, + ], + }, + { + id: "tier5", + title: "Tier 5 — Audit + revert", + subtitle: "Action log + revert route used by the Agent Actions sheet.", + flags: [ + { + key: "enable_action_log", + label: "Action log", + description: "Persist every tool call to agent_action_log.", + envVar: "SURFSENSE_ENABLE_ACTION_LOG", + }, + { + key: "enable_revert_route", + label: "Revert route", + description: "Allow reverting reversible actions from the action log.", + envVar: "SURFSENSE_ENABLE_REVERT_ROUTE", + }, + ], + }, + { + id: "tier6", + title: "Tier 6 — Plugins", + subtitle: "Optional middleware loaded from entry points.", + flags: [ + { + key: "enable_plugin_loader", + label: "Plugin loader", + description: "Load surfsense.plugins entry-point middleware.", + envVar: "SURFSENSE_ENABLE_PLUGIN_LOADER", + }, + ], + }, + { + id: "obs", + title: "Observability", + subtitle: "Telemetry pipelines (orthogonal to feature gating).", + flags: [ + { + key: "enable_otel", + label: "OpenTelemetry", + description: "Emit OTel spans (also requires OTEL_EXPORTER_OTLP_ENDPOINT).", + envVar: "SURFSENSE_ENABLE_OTEL", + }, + ], + }, +]; + +function FlagRow({ def, value }: { def: FlagDef; value: boolean }) { + return ( +
+
+
+ {def.label} + + {def.envVar} + +
+

{def.description}

+
+ + {value ? : } + {value ? "On" : "Off"} + +
+ ); +} + +export function AgentStatusContent() { + const { data: flags, isLoading, isError, error, refetch } = useAtomValue(agentFlagsAtom); + + const enabledCount = useMemo(() => { + if (!flags) return 0; + return Object.entries(flags).filter(([k, v]) => k !== "disable_new_agent_stack" && v === true) + .length; + }, [flags]); + + if (isLoading) { + return ( +
+ + + +
+ ); + } + + if (isError || !flags) { + return ( + + Failed to load agent status + + {error instanceof Error ? error.message : "Unknown error."} + + + + ); + } + + const masterOff = flags.disable_new_agent_stack; + + return ( +
+ {masterOff ? ( + + + Master kill-switch is on + + + SURFSENSE_DISABLE_NEW_AGENT_STACK=true + + forces every new middleware off, regardless of the individual flags below. Restart the + backend after changing it. + + + ) : ( + + + + Agent stack + + {enabledCount} on + + + + Read-only mirror of the backend's AgentFeatureFlags. Flip an env var and + restart the backend to change a value. + + + )} + + {FLAG_GROUPS.map((group, groupIdx) => { + const allOff = group.flags.every((f) => !flags[f.key]); + return ( +
+ {groupIdx > 0 && } +
+
+
+

{group.title}

+

{group.subtitle}

+
+ {allOff && ( + + all off + + )} +
+
+ {group.flags.map((def) => ( + + ))} +
+
+
+ ); + })} +
+ ); +} diff --git a/surfsense_web/atoms/agent/action-log-sheet.atom.ts b/surfsense_web/atoms/agent/action-log-sheet.atom.ts new file mode 100644 index 000000000..f88d3ed1e --- /dev/null +++ b/surfsense_web/atoms/agent/action-log-sheet.atom.ts @@ -0,0 +1,19 @@ +import { atom } from "jotai"; + +interface ActionLogSheetState { + open: boolean; + threadId: number | null; +} + +export const actionLogSheetAtom = atom({ + open: false, + threadId: null, +}); + +export const openActionLogSheetAtom = atom(null, (_get, set, threadId: number) => { + set(actionLogSheetAtom, { open: true, threadId }); +}); + +export const closeActionLogSheetAtom = atom(null, (_get, set) => { + set(actionLogSheetAtom, { open: false, threadId: null }); +}); diff --git a/surfsense_web/atoms/agent/agent-flags-query.atom.ts b/surfsense_web/atoms/agent/agent-flags-query.atom.ts new file mode 100644 index 000000000..30158deaa --- /dev/null +++ b/surfsense_web/atoms/agent/agent-flags-query.atom.ts @@ -0,0 +1,17 @@ +import { atomWithQuery } from "jotai-tanstack-query"; +import { agentFlagsApiService } from "@/lib/apis/agent-flags-api.service"; +import { getBearerToken } from "@/lib/auth-utils"; + +export const AGENT_FLAGS_QUERY_KEY = ["agent", "flags"] as const; + +/** + * Reads the backend agent feature flags. Cached for the lifetime of the + * page (flags only change on backend restart) so we can drive UI gating + * without re-hitting the API. + */ +export const agentFlagsAtom = atomWithQuery(() => ({ + queryKey: AGENT_FLAGS_QUERY_KEY, + staleTime: 10 * 60 * 1000, + enabled: !!getBearerToken(), + queryFn: () => agentFlagsApiService.get(), +})); diff --git a/surfsense_web/components/agent-action-log/action-log-button.tsx b/surfsense_web/components/agent-action-log/action-log-button.tsx new file mode 100644 index 000000000..1c0383136 --- /dev/null +++ b/surfsense_web/components/agent-action-log/action-log-button.tsx @@ -0,0 +1,50 @@ +"use client"; + +import { useAtomValue, useSetAtom } from "jotai"; +import { Activity } from "lucide-react"; +import { useCallback } from "react"; +import { openActionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom"; +import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; +import { Button } from "@/components/ui/button"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; + +interface ActionLogButtonProps { + threadId: number | null; +} + +/** + * Header button that opens the agent action log sheet for the current + * thread. Renders nothing when: + * - the action log feature flag is off (graceful no-op for older + * deployments), OR + * - there is no active thread (lazy-created chats haven't started). + */ +export function ActionLogButton({ threadId }: ActionLogButtonProps) { + const { data: flags } = useAtomValue(agentFlagsAtom); + const open = useSetAtom(openActionLogSheetAtom); + + const enabled = !!flags?.enable_action_log && !flags?.disable_new_agent_stack; + + const handleClick = useCallback(() => { + if (threadId !== null) open(threadId); + }, [open, threadId]); + + if (!enabled || threadId === null) return null; + + return ( + + + + + Agent actions + + ); +} diff --git a/surfsense_web/components/agent-action-log/action-log-item.tsx b/surfsense_web/components/agent-action-log/action-log-item.tsx new file mode 100644 index 000000000..425714c1f --- /dev/null +++ b/surfsense_web/components/agent-action-log/action-log-item.tsx @@ -0,0 +1,215 @@ +"use client"; + +import { ChevronRight, RotateCcw, ShieldOff, Undo2 } from "lucide-react"; +import { useState } from "react"; +import { toast } from "sonner"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alert-dialog"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Separator } from "@/components/ui/separator"; +import { getToolIcon } from "@/contracts/enums/toolIcons"; +import { type AgentAction, agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; +import { AppError } from "@/lib/error"; +import { formatRelativeDate } from "@/lib/format-date"; +import { cn } from "@/lib/utils"; + +function formatToolName(name: string): string { + return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); +} + +interface ActionLogItemProps { + action: AgentAction; + threadId: number; + onRevertSuccess: () => void; +} + +export function ActionLogItem({ action, threadId, onRevertSuccess }: ActionLogItemProps) { + const [isExpanded, setIsExpanded] = useState(false); + const [isReverting, setIsReverting] = useState(false); + const [confirmOpen, setConfirmOpen] = useState(false); + + const isAlreadyReverted = action.reverted_by_action_id !== null; + const isRevertAction = action.is_revert_action; + const hasError = action.error !== null && action.error !== undefined; + + const Icon = getToolIcon(action.tool_name); + const displayName = formatToolName(action.tool_name); + + const argsPreview = action.args ? JSON.stringify(action.args, null, 2) : null; + const truncatedArgs = + argsPreview && argsPreview.length > 600 ? `${argsPreview.slice(0, 600)}…` : argsPreview; + + const canRevert = action.reversible && !isAlreadyReverted && !isRevertAction && !hasError; + + const handleRevert = async () => { + setIsReverting(true); + try { + const response = await agentActionsApiService.revert(threadId, action.id); + toast.success(response.message || "Action reverted successfully."); + onRevertSuccess(); + } catch (err) { + const message = + err instanceof AppError + ? err.message + : err instanceof Error + ? err.message + : "Failed to revert action."; + toast.error(message); + } finally { + setIsReverting(false); + setConfirmOpen(false); + } + }; + + return ( +
+ + + {isExpanded && ( +
+ {truncatedArgs && ( +
+

+ Arguments +

+
+								{truncatedArgs}
+							
+
+ )} + {action.error && ( +
+

+ Error +

+
+								{JSON.stringify(action.error, null, 2)}
+							
+
+ )} + {action.reverse_descriptor && ( +
+

+ Reverse plan +

+
+								{JSON.stringify(action.reverse_descriptor, null, 2)}
+							
+
+ )} + + + +
+

+ Action ID: {action.id} +

+ {canRevert ? ( + + + + + + + Revert this action? + + This will undo {displayName} and append a + new audit entry. The agent's chat history is preserved — only the tool's + effects on your knowledge base or connectors will be reversed where possible. + + + + Cancel + { + e.preventDefault(); + handleRevert(); + }} + disabled={isReverting} + > + {isReverting ? "Reverting…" : "Revert"} + + + + + ) : ( +
+ + {isAlreadyReverted + ? "Already reverted" + : isRevertAction + ? "Revert entry" + : hasError + ? "Cannot revert errored action" + : "Not reversible"} +
+ )} +
+
+ )} +
+ ); +} diff --git a/surfsense_web/components/agent-action-log/action-log-sheet.tsx b/surfsense_web/components/agent-action-log/action-log-sheet.tsx new file mode 100644 index 000000000..68d2ffef3 --- /dev/null +++ b/surfsense_web/components/agent-action-log/action-log-sheet.tsx @@ -0,0 +1,185 @@ +"use client"; + +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { useAtom, useAtomValue } from "jotai"; +import { Activity, RefreshCcw } from "lucide-react"; +import { useCallback, useMemo } from "react"; +import { actionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom"; +import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Separator } from "@/components/ui/separator"; +import { + Sheet, + SheetContent, + SheetDescription, + SheetHeader, + SheetTitle, +} from "@/components/ui/sheet"; +import { Skeleton } from "@/components/ui/skeleton"; +import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; +import { ActionLogItem } from "./action-log-item"; + +const ACTION_LOG_PAGE_SIZE = 50; + +function actionLogQueryKey(threadId: number) { + return ["agent-actions", threadId] as const; +} + +function EmptyState() { + return ( +
+
+ +
+
+

No actions logged yet

+

+ Once the agent calls a tool in this thread, it will show up here. From the log you can + inspect arguments and revert reversible actions. +

+
+
+ ); +} + +function DisabledState() { + return ( +
+
+ +
+
+

Action log is disabled

+

+ This deployment hasn't enabled the agent action log. An admin can flip + + SURFSENSE_ENABLE_ACTION_LOG + + . +

+
+
+ ); +} + +const SKELETON_KEYS = ["s1", "s2", "s3", "s4"] as const; + +function LoadingState() { + return ( +
+ {SKELETON_KEYS.map((key) => ( + + ))} +
+ ); +} + +export function ActionLogSheet() { + const [state, setState] = useAtom(actionLogSheetAtom); + const queryClient = useQueryClient(); + + const { data: flags } = useAtomValue(agentFlagsAtom); + const actionLogEnabled = !!flags?.enable_action_log && !flags?.disable_new_agent_stack; + const revertEnabled = !!flags?.enable_revert_route && !flags?.disable_new_agent_stack; + + const threadId = state.threadId; + + const { data, isLoading, isFetching, isError, error, refetch } = useQuery({ + queryKey: threadId !== null ? actionLogQueryKey(threadId) : ["agent-actions", "none"], + queryFn: () => + agentActionsApiService.listForThread(threadId as number, { + page: 0, + pageSize: ACTION_LOG_PAGE_SIZE, + }), + enabled: state.open && threadId !== null && actionLogEnabled, + staleTime: 15 * 1000, + }); + + const handleRevertSuccess = useCallback(() => { + if (threadId !== null) { + queryClient.invalidateQueries({ queryKey: actionLogQueryKey(threadId) }); + } + }, [queryClient, threadId]); + + const items = useMemo(() => data?.items ?? [], [data]); + + return ( + setState((s) => ({ ...s, open }))}> + + +
+
+ + Agent actions + {data?.total !== undefined && data.total > 0 && ( + + {data.total} + + )} +
+ +
+ + Audit trail of every tool call the agent made in this thread. + {revertEnabled + ? " Reversible actions can be undone in place." + : " Reverts are read-only on this deployment."} + +
+ + + +
+ {!actionLogEnabled ? ( + + ) : threadId === null ? ( + + ) : isLoading ? ( + + ) : isError ? ( +
+

Failed to load actions

+

+ {error instanceof Error ? error.message : "Unknown error"} +

+ +
+ ) : items.length === 0 ? ( + + ) : ( +
+ {items.map((action) => ( + + ))} + {data?.has_more && ( +

+ Showing {items.length} of {data.total}. Older actions are paginated. +

+ )} +
+ )} +
+
+
+ ); +} diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index 8bb228580..7655e10cc 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -85,10 +85,13 @@ function preprocessMarkdown(content: string): string { } ); + // All math forms are normalised to $$...$$ so we can disable single-dollar + // inline math in remark-math (otherwise currency like "$3,120.00 and $0.00" + // gets parsed as a LaTeX expression). // 1. Block math: \[...\] → $$...$$ content = content.replace(/\\\[([\s\S]*?)\\\]/g, (_, inner) => `$$${inner}$$`); - // 2. Inline math: \(...\) → $...$ - content = content.replace(/\\\(([\s\S]*?)\\\)/g, (_, inner) => `$${inner}$`); + // 2. Inline math: \(...\) → $$...$$ + content = content.replace(/\\\(([\s\S]*?)\\\)/g, (_, inner) => `$$${inner}$$`); // 3. Block: \begin{equation}...\end{equation} → $$...$$ content = content.replace( /\\begin\{equation\}([\s\S]*?)\\end\{equation\}/g, @@ -99,8 +102,11 @@ function preprocessMarkdown(content: string): string { /\\begin\{displaymath\}([\s\S]*?)\\end\{displaymath\}/g, (_, inner) => `$$${inner}$$` ); - // 5. Inline: \begin{math}...\end{math} → $...$ - content = content.replace(/\\begin\{math\}([\s\S]*?)\\end\{math\}/g, (_, inner) => `$${inner}$`); + // 5. Inline: \begin{math}...\end{math} → $$...$$ + content = content.replace( + /\\begin\{math\}([\s\S]*?)\\end\{math\}/g, + (_, inner) => `$$${inner}$$` + ); // 6. Strip backtick wrapping around math: `$$...$$` → $$...$$ and `$...$` → $...$ content = content.replace(/`(\${1,2})((?:(?!\1).)+)\1`/g, "$1$2$1"); @@ -180,7 +186,7 @@ const MarkdownTextImpl = () => { return ( { if (isInterruptResult(props.result)) { + if (isDoomLoopInterrupt(props.result)) { + return ; + } return ; } return ; diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index aecf55a27..3efdab03b 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -28,6 +28,7 @@ import { import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { SearchSpaceSettingsDialog } from "@/components/settings/search-space-settings-dialog"; import { TeamDialog } from "@/components/settings/team-dialog"; +import { ActionLogSheet } from "@/components/agent-action-log/action-log-sheet"; import { UserSettingsDialog } from "@/components/settings/user-settings-dialog"; import { AlertDialog, @@ -909,6 +910,9 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid + + {/* Agent action log + revert sheet */} + ); } diff --git a/surfsense_web/components/layout/ui/header/Header.tsx b/surfsense_web/components/layout/ui/header/Header.tsx index ec54cb901..f49d7fb88 100644 --- a/surfsense_web/components/layout/ui/header/Header.tsx +++ b/surfsense_web/components/layout/ui/header/Header.tsx @@ -5,6 +5,7 @@ import { usePathname } from "next/navigation"; import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { activeTabAtom, tabsAtom } from "@/atoms/tabs/tabs.atom"; +import { ActionLogButton } from "@/components/agent-action-log/action-log-button"; import { ChatHeader } from "@/components/new-chat/chat-header"; import { ChatShareButton } from "@/components/new-chat/chat-share-button"; import { useIsMobile } from "@/hooks/use-mobile"; @@ -69,6 +70,7 @@ export function Header({ mobileMenuTrigger }: HeaderProps) { {/* Right side - Actions */}
+ {hasThread && } {hasThread && ( )} diff --git a/surfsense_web/components/markdown-viewer.tsx b/surfsense_web/components/markdown-viewer.tsx index 5775fe083..c4d73e30b 100644 --- a/surfsense_web/components/markdown-viewer.tsx +++ b/surfsense_web/components/markdown-viewer.tsx @@ -10,7 +10,11 @@ const code = createCodePlugin({ }); const math = createMathPlugin({ - singleDollarTextMath: true, + // Disabled so currency like "$3,120.00 and ... $0.00" isn't parsed as + // inline LaTeX. convertLatexDelimiters() below normalises any genuine + // inline math (\(...\), $...$ starting with a LaTeX command, etc.) to + // $$...$$, so this flip doesn't lose any math rendering. + singleDollarTextMath: false, }); interface MarkdownViewerProps { diff --git a/surfsense_web/components/settings/user-settings-dialog.tsx b/surfsense_web/components/settings/user-settings-dialog.tsx index 6740aad92..a04ce16dd 100644 --- a/surfsense_web/components/settings/user-settings-dialog.tsx +++ b/surfsense_web/components/settings/user-settings-dialog.tsx @@ -2,6 +2,7 @@ import { useAtom } from "jotai"; import { + Activity, Brain, CircleUser, Globe, @@ -9,6 +10,7 @@ import { KeyRound, Monitor, ReceiptText, + ShieldCheck, Sparkles, } from "lucide-react"; import dynamic from "next/dynamic"; @@ -74,6 +76,20 @@ const MemoryContent = dynamic( ), { ssr: false } ); +const AgentPermissionsContent = dynamic( + () => + import( + "@/app/dashboard/[search_space_id]/user-settings/components/AgentPermissionsContent" + ).then((m) => ({ default: m.AgentPermissionsContent })), + { ssr: false } +); +const AgentStatusContent = dynamic( + () => + import("@/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent").then( + (m) => ({ default: m.AgentStatusContent }) + ), + { ssr: false } +); export function UserSettingsDialog() { const t = useTranslations("userSettings"); @@ -103,6 +119,16 @@ export function UserSettingsDialog() { label: "Memory", icon: , }, + { + value: "agent-permissions", + label: "Agent Permissions", + icon: , + }, + { + value: "agent-status", + label: "Agent Status", + icon: , + }, { value: "purchases", label: "Purchase History", @@ -141,6 +167,8 @@ export function UserSettingsDialog() { {state.initialTab === "prompts" && } {state.initialTab === "community-prompts" && } {state.initialTab === "memory" && } + {state.initialTab === "agent-permissions" && } + {state.initialTab === "agent-status" && } {state.initialTab === "purchases" && } {state.initialTab === "desktop" && } {state.initialTab === "desktop-shortcuts" && } diff --git a/surfsense_web/components/tool-ui/doom-loop-approval.tsx b/surfsense_web/components/tool-ui/doom-loop-approval.tsx new file mode 100644 index 000000000..6132a71ed --- /dev/null +++ b/surfsense_web/components/tool-ui/doom-loop-approval.tsx @@ -0,0 +1,187 @@ +"use client"; + +import type { ToolCallMessagePartComponent } from "@assistant-ui/react"; +import { CornerDownLeftIcon, OctagonAlert } from "lucide-react"; +import { useCallback, useEffect, useMemo } from "react"; +import { TextShimmerLoader } from "@/components/prompt-kit/loader"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Separator } from "@/components/ui/separator"; +import { useHitlPhase } from "@/hooks/use-hitl-phase"; +import type { HitlDecision, InterruptResult } from "@/lib/hitl"; +import { isInterruptResult, useHitlDecision } from "@/lib/hitl"; + +/** + * Specialized HITL card for ``DoomLoopMiddleware`` interrupts. The + * backend signals these by setting ``context.permission === "doom_loop"`` + * on the ``permission_ask`` interrupt. + * + * The card replaces the generic "approve/reject" framing with a + * "continue/stop" affordance that better matches the user's mental + * model: the agent is stuck repeating itself, not asking permission + * for a destructive action. + */ +function DoomLoopCard({ + toolName, + args, + interruptData, + onDecision, +}: { + toolName: string; + args: Record; + interruptData: InterruptResult; + onDecision: (decision: HitlDecision) => void; +}) { + const { phase, setProcessing, setRejected } = useHitlPhase(interruptData); + + const context = (interruptData.context ?? {}) as Record; + const threshold = typeof context.threshold === "number" ? context.threshold : 3; + const stuckTool = (typeof context.tool === "string" && context.tool) || toolName; + const recentSignatures = Array.isArray(context.recent_signatures) + ? (context.recent_signatures as string[]) + : []; + const displayName = stuckTool.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); + + const argPreview = useMemo(() => { + if (!args || Object.keys(args).length === 0) return null; + try { + const json = JSON.stringify(args, null, 2); + return json.length > 600 ? `${json.slice(0, 600)}…` : json; + } catch { + return null; + } + }, [args]); + + const handleContinue = useCallback(() => { + if (phase !== "pending") return; + setProcessing(); + onDecision({ type: "approve" }); + }, [phase, setProcessing, onDecision]); + + const handleStop = useCallback(() => { + if (phase !== "pending") return; + setRejected(); + onDecision({ type: "reject", message: "Doom loop: user requested stop." }); + }, [phase, setRejected, onDecision]); + + useEffect(() => { + const handler = (e: KeyboardEvent) => { + if (phase !== "pending") return; + if (e.key === "Enter" && !e.shiftKey && !e.ctrlKey && !e.metaKey) { + e.preventDefault(); + handleStop(); + } + }; + window.addEventListener("keydown", handler); + return () => window.removeEventListener("keydown", handler); + }, [phase, handleStop]); + + const isResolved = phase === "complete" || phase === "rejected"; + + return ( + + + + + {phase === "rejected" + ? "Stopped" + : phase === "processing" + ? "Continuing…" + : phase === "complete" + ? "Continued" + : "I might be stuck"} + + {!isResolved && ( + + doom-loop + + )} + + + {phase === "processing" ? ( + + ) : phase === "rejected" ? ( +

+ I stopped retrying {displayName} as you asked. +

+ ) : phase === "complete" ? ( +

+ Continuing to call {displayName} as you asked. +

+ ) : ( +

+ I called {displayName} {threshold} times in a row + with similar arguments. Should I keep going or stop and rethink? +

+ )} + + {argPreview && phase === "pending" && ( + <> + +
+

+ Last arguments +

+
+								{argPreview}
+							
+
+ + )} + + {recentSignatures.length > 0 && phase === "pending" && ( +
+ + Show repeated signatures ({recentSignatures.length}) + +
    + {recentSignatures.map((sig) => ( +
  • + {sig} +
  • + ))} +
+
+ )} + + {phase === "pending" && ( +
+ + +
+ )} +
+
+ ); +} + +export const DoomLoopApprovalToolUI: ToolCallMessagePartComponent = ({ + toolName, + args, + result, +}) => { + const { dispatch } = useHitlDecision(); + + if (!result || !isInterruptResult(result)) return null; + + return ( + } + interruptData={result} + onDecision={(decision) => dispatch([decision])} + /> + ); +}; + +export function isDoomLoopInterrupt(result: unknown): boolean { + if (!isInterruptResult(result)) return false; + const ctx = (result.context ?? {}) as Record; + return ctx.permission === "doom_loop"; +} diff --git a/surfsense_web/lib/apis/agent-actions-api.service.ts b/surfsense_web/lib/apis/agent-actions-api.service.ts new file mode 100644 index 000000000..007bb131e --- /dev/null +++ b/surfsense_web/lib/apis/agent-actions-api.service.ts @@ -0,0 +1,64 @@ +import { z } from "zod"; +import { baseApiService } from "./base-api.service"; + +const AgentActionReadSchema = z.object({ + id: z.number(), + thread_id: z.number(), + user_id: z.string().nullable(), + search_space_id: z.number(), + tool_name: z.string(), + args: z.record(z.string(), z.unknown()).nullable(), + result_id: z.string().nullable(), + reversible: z.boolean(), + reverse_descriptor: z.record(z.string(), z.unknown()).nullable(), + error: z.record(z.string(), z.unknown()).nullable(), + reverse_of: z.number().nullable(), + reverted_by_action_id: z.number().nullable(), + is_revert_action: z.boolean(), + created_at: z.string(), +}); + +export type AgentAction = z.infer; + +const AgentActionListResponseSchema = z.object({ + items: z.array(AgentActionReadSchema), + total: z.number(), + page: z.number(), + page_size: z.number(), + has_more: z.boolean(), +}); + +export type AgentActionListResponse = z.infer; + +const RevertResponseSchema = z.object({ + status: z.literal("ok"), + message: z.string(), + new_action_id: z.number().nullable().optional(), +}); + +export type RevertResponse = z.infer; + +class AgentActionsApiService { + listForThread = async ( + threadId: number, + opts: { page?: number; pageSize?: number } = {} + ): Promise => { + const params = new URLSearchParams(); + params.set("page", String(opts.page ?? 0)); + params.set("page_size", String(opts.pageSize ?? 50)); + return baseApiService.get( + `/api/v1/threads/${threadId}/actions?${params.toString()}`, + AgentActionListResponseSchema + ); + }; + + revert = async (threadId: number, actionId: number): Promise => { + return baseApiService.post( + `/api/v1/threads/${threadId}/revert/${actionId}`, + RevertResponseSchema, + { body: {} } + ); + }; +} + +export const agentActionsApiService = new AgentActionsApiService(); diff --git a/surfsense_web/lib/apis/agent-flags-api.service.ts b/surfsense_web/lib/apis/agent-flags-api.service.ts new file mode 100644 index 000000000..87332ca9f --- /dev/null +++ b/surfsense_web/lib/apis/agent-flags-api.service.ts @@ -0,0 +1,40 @@ +import { z } from "zod"; +import { baseApiService } from "./base-api.service"; + +const AgentFeatureFlagsSchema = z.object({ + disable_new_agent_stack: z.boolean(), + + enable_context_editing: z.boolean(), + enable_compaction_v2: z.boolean(), + enable_retry_after: z.boolean(), + enable_model_fallback: z.boolean(), + enable_model_call_limit: z.boolean(), + enable_tool_call_limit: z.boolean(), + enable_tool_call_repair: z.boolean(), + enable_doom_loop: z.boolean(), + + enable_permission: z.boolean(), + enable_busy_mutex: z.boolean(), + enable_llm_tool_selector: z.boolean(), + + enable_skills: z.boolean(), + enable_specialized_subagents: z.boolean(), + enable_kb_planner_runnable: z.boolean(), + + enable_action_log: z.boolean(), + enable_revert_route: z.boolean(), + + enable_plugin_loader: z.boolean(), + + enable_otel: z.boolean(), +}); + +export type AgentFeatureFlags = z.infer; + +class AgentFlagsApiService { + get = async (): Promise => { + return baseApiService.get(`/api/v1/agent/flags`, AgentFeatureFlagsSchema); + }; +} + +export const agentFlagsApiService = new AgentFlagsApiService(); diff --git a/surfsense_web/lib/apis/agent-permissions-api.service.ts b/surfsense_web/lib/apis/agent-permissions-api.service.ts new file mode 100644 index 000000000..6927c55d0 --- /dev/null +++ b/surfsense_web/lib/apis/agent-permissions-api.service.ts @@ -0,0 +1,90 @@ +import { z } from "zod"; +import { ValidationError } from "@/lib/error"; +import { baseApiService } from "./base-api.service"; + +const ActionEnum = z.enum(["allow", "deny", "ask"]); +export type AgentPermissionAction = z.infer; + +const AgentPermissionRuleSchema = z.object({ + id: z.number(), + search_space_id: z.number(), + user_id: z.string().nullable(), + thread_id: z.number().nullable(), + permission: z.string(), + pattern: z.string(), + action: ActionEnum, + created_at: z.string(), +}); + +export type AgentPermissionRule = z.infer; + +const AgentPermissionRuleListSchema = z.array(AgentPermissionRuleSchema); + +const AgentPermissionRuleCreateSchema = z.object({ + permission: z + .string() + .min(1, "Permission is required") + .max(255) + .regex(/^[a-zA-Z0-9_:.\-*]+$/, "Use letters, digits, '.', '_', ':', '-', or '*' wildcards."), + pattern: z.string().min(1).max(255).default("*"), + action: ActionEnum, + user_id: z.string().nullable().optional(), + thread_id: z.number().nullable().optional(), +}); + +export type AgentPermissionRuleCreate = z.infer; + +const AgentPermissionRuleUpdateSchema = z.object({ + pattern: z.string().min(1).max(255).optional(), + action: ActionEnum.optional(), +}); + +export type AgentPermissionRuleUpdate = z.infer; + +class AgentPermissionsApiService { + list = async (searchSpaceId: number): Promise => { + return baseApiService.get( + `/api/v1/searchspaces/${searchSpaceId}/agent/permissions/rules`, + AgentPermissionRuleListSchema + ); + }; + + create = async ( + searchSpaceId: number, + payload: AgentPermissionRuleCreate + ): Promise => { + const parsed = AgentPermissionRuleCreateSchema.safeParse(payload); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((i) => i.message).join(", ")); + } + return baseApiService.post( + `/api/v1/searchspaces/${searchSpaceId}/agent/permissions/rules`, + AgentPermissionRuleSchema, + { body: parsed.data } + ); + }; + + update = async ( + searchSpaceId: number, + ruleId: number, + payload: AgentPermissionRuleUpdate + ): Promise => { + const parsed = AgentPermissionRuleUpdateSchema.safeParse(payload); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((i) => i.message).join(", ")); + } + return baseApiService.patch( + `/api/v1/searchspaces/${searchSpaceId}/agent/permissions/rules/${ruleId}`, + AgentPermissionRuleSchema, + { body: parsed.data } + ); + }; + + remove = async (searchSpaceId: number, ruleId: number): Promise => { + await baseApiService.delete( + `/api/v1/searchspaces/${searchSpaceId}/agent/permissions/rules/${ruleId}` + ); + }; +} + +export const agentPermissionsApiService = new AgentPermissionsApiService(); From 76c91adebc0b30102e0d6df026f62b47716d3ac2 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 04:12:42 +0530 Subject: [PATCH 216/299] refactor(mentions): replace sidebarSelectedDocumentsAtom with mentionedDocumentsAtom and introduce getMentionDocKey utility for consistent document key generation --- .../atoms/chat/mentioned-documents.atom.ts | 23 ------ .../assistant-ui/inline-mention-editor.tsx | 82 ++++++++++++------- .../components/assistant-ui/thread.tsx | 53 ++++++------ .../layout/ui/sidebar/DocumentsSidebar.tsx | 29 ++++--- surfsense_web/lib/chat/mention-doc-key.ts | 8 ++ 5 files changed, 102 insertions(+), 93 deletions(-) create mode 100644 surfsense_web/lib/chat/mention-doc-key.ts diff --git a/surfsense_web/atoms/chat/mentioned-documents.atom.ts b/surfsense_web/atoms/chat/mentioned-documents.atom.ts index 47401995d..9c4546237 100644 --- a/surfsense_web/atoms/chat/mentioned-documents.atom.ts +++ b/surfsense_web/atoms/chat/mentioned-documents.atom.ts @@ -9,29 +9,6 @@ import type { Document } from "@/contracts/types/document.types"; */ export const mentionedDocumentsAtom = atom[]>([]); -/** - * Back-compat alias for sidebar checkbox selection. - * This now points to mentionedDocumentsAtom so the app has a single source - * of truth for mentioned/selected documents. - */ -export const sidebarSelectedDocumentsAtom = atom< - Pick[], - [ - | Pick[] - | (( - prev: Pick[] - ) => Pick[]), - ], - void ->( - (get) => get(mentionedDocumentsAtom), - (get, set, update) => { - const prev = get(mentionedDocumentsAtom); - const next = typeof update === "function" ? update(prev) : update; - set(mentionedDocumentsAtom, next); - } -); - /** * Derived read-only atom that maps deduplicated mentioned docs * into backend payload fields. diff --git a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx index e75a840c0..05277f508 100644 --- a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx +++ b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx @@ -14,6 +14,7 @@ import { import { renderToStaticMarkup } from "react-dom/server"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { Document } from "@/contracts/types/document.types"; +import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { cn } from "@/lib/utils"; function renderElementToHTML(element: ReactElement): string { @@ -57,7 +58,6 @@ interface InlineMentionEditorProps { onKeyDown?: (e: React.KeyboardEvent) => void; disabled?: boolean; className?: string; - initialDocuments?: MentionedDocument[]; initialText?: string; } @@ -109,7 +109,6 @@ export const InlineMentionEditor = forwardRef(null); const [isEmpty, setIsEmpty] = useState(true); const [mentionedDocs, setMentionedDocs] = useState>( - () => new Map(initialDocuments.map((d) => [`${d.document_type ?? "UNKNOWN"}:${d.id}`, d])) + () => new Map() ); const isComposingRef = useRef(false); const lastSelectionRangeRef = useRef(null); + const isRangeInsideEditor = useCallback((range: Range | null): range is Range => { + if (!range || !editorRef.current) return false; + return ( + editorRef.current.contains(range.startContainer) && + editorRef.current.contains(range.endContainer) + ); + }, []); const isSelectionInsideEditor = useCallback( (selection: Selection | null): selection is Selection => { if (!selection || selection.rangeCount === 0 || !editorRef.current) return false; const range = selection.getRangeAt(0); - return editorRef.current.contains(range.startContainer); + return isRangeInsideEditor(range); }, - [] + [isRangeInsideEditor] ); const rememberSelection = useCallback(() => { @@ -139,11 +145,11 @@ export const InlineMentionEditor = forwardRef { const selection = window.getSelection(); if (!selection) return null; - if (!lastSelectionRangeRef.current) return selection; + if (!isRangeInsideEditor(lastSelectionRangeRef.current)) return null; selection.removeAllRanges(); selection.addRange(lastSelectionRangeRef.current.cloneRange()); return selection; - }, []); + }, [isRangeInsideEditor]); useEffect(() => { const handleSelectionChange = () => { @@ -154,23 +160,13 @@ export const InlineMentionEditor = forwardRef document.removeEventListener("selectionchange", handleSelectionChange); }, [rememberSelection]); - - // Sync initial documents - useEffect(() => { - if (initialDocuments.length > 0) { - setMentionedDocs( - new Map(initialDocuments.map((d) => [`${d.document_type ?? "UNKNOWN"}:${d.id}`, d])) - ); - } - }, [initialDocuments]); - useEffect(() => { if (!initialText || !editorRef.current) return; editorRef.current.innerText = initialText; editorRef.current.appendChild(document.createElement("br")); editorRef.current.appendChild(document.createElement("br")); setIsEmpty(false); - onChange?.(initialText, initialDocuments); + onChange?.(initialText, []); editorRef.current.focus(); const sel = window.getSelection(); const range = document.createRange(); @@ -182,7 +178,7 @@ export const InlineMentionEditor = forwardRef { @@ -284,7 +280,7 @@ export const InlineMentionEditor = forwardRef { const next = new Map(prev); next.delete(docKey); @@ -358,7 +354,7 @@ export const InlineMentionEditor = forwardRef new Map(prev).set(docKey, mentionDoc)); const nextDocs = new Map(mentionedDocs); nextDocs.set(docKey, mentionDoc); @@ -367,12 +363,33 @@ export const InlineMentionEditor = forwardRef { if (!editorRef.current) return; - const chipKey = `${docType ?? "UNKNOWN"}:${docId}`; + const chipKey = getMentionDocKey({ id: docId, document_type: docType }); const chips = editorRef.current.querySelectorAll( `span[${CHIP_DATA_ATTR}="true"]` ); @@ -696,7 +712,10 @@ export const InlineMentionEditor = forwardRef { const next = new Map(prev); next.delete(chipKey); @@ -734,7 +753,10 @@ export const InlineMentionEditor = forwardRef { const next = new Map(prev); next.delete(chipKey); diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index dcc068bd1..f9e5ca7fb 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -87,6 +87,7 @@ import { useBatchCommentsPreload } from "@/hooks/use-comments"; import { useCommentsSync } from "@/hooks/use-comments-sync"; import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI } from "@/hooks/use-platform"; +import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/lib/layout-events"; import { cn } from "@/lib/utils"; @@ -338,6 +339,9 @@ const Composer: FC = () => { const [mentionQuery, setMentionQuery] = useState(""); const [actionQuery, setActionQuery] = useState(""); const editorRef = useRef(null); + const prevMentionedDocsRef = useRef< + Map> + >(new Map()); const documentPickerRef = useRef(null); const promptPickerRef = useRef(null); const viewportRef = useRef(null); @@ -633,51 +637,50 @@ const Composer: FC = () => { const handleDocumentsMention = useCallback( (documents: Pick[]) => { - const existingKeys = new Set(mentionedDocuments.map((d) => `${d.document_type}:${d.id}`)); - const newDocs = documents.filter( - (doc) => !existingKeys.has(`${doc.document_type}:${doc.id}`) - ); + const editorMentionedDocs = editorRef.current?.getMentionedDocuments() ?? []; + const editorDocKeys = new Set(editorMentionedDocs.map((doc) => getMentionDocKey(doc))); - for (const doc of newDocs) { + for (const doc of documents) { + const key = getMentionDocKey(doc); + if (editorDocKeys.has(key)) continue; editorRef.current?.insertDocumentChip(doc); } setMentionedDocuments((prev) => { - const existingKeySet = new Set(prev.map((d) => `${d.document_type}:${d.id}`)); - const uniqueNewDocs = documents.filter( - (doc) => !existingKeySet.has(`${doc.document_type}:${doc.id}`) - ); + const existingKeySet = new Set(prev.map((d) => getMentionDocKey(d))); + const uniqueNewDocs = documents.filter((doc) => !existingKeySet.has(getMentionDocKey(doc))); return [...prev, ...uniqueNewDocs]; }); setMentionQuery(""); }, - [mentionedDocuments, setMentionedDocuments] + [setMentionedDocuments] ); useEffect(() => { const editor = editorRef.current; - if (!editor) return; + const nextDocsMap = new Map(mentionedDocuments.map((doc) => [getMentionDocKey(doc), doc])); + const prevDocsMap = prevMentionedDocsRef.current; - const toKey = (doc: { id: number; document_type?: string }) => - `${doc.document_type ?? "UNKNOWN"}:${doc.id}`; - - const atomDocs = mentionedDocuments; - const editorDocs = editor.getMentionedDocuments(); - const atomKeys = new Set(atomDocs.map(toKey)); - const editorKeys = new Set(editorDocs.map(toKey)); - - for (const doc of atomDocs) { - if (!editorKeys.has(toKey(doc))) { - editor.insertDocumentChip(doc, { removeTriggerText: false }); - } + if (!editor) { + prevMentionedDocsRef.current = nextDocsMap; + return; } - for (const doc of editorDocs) { - if (!atomKeys.has(toKey(doc))) { + const editorKeys = new Set(editor.getMentionedDocuments().map(getMentionDocKey)); + + for (const [key, doc] of nextDocsMap) { + if (prevDocsMap.has(key) || editorKeys.has(key)) continue; + editor.insertDocumentChip(doc, { removeTriggerText: false }); + } + + for (const [key, doc] of prevDocsMap) { + if (!nextDocsMap.has(key)) { editor.removeDocumentChip(doc.id, doc.document_type); } } + + prevMentionedDocsRef.current = nextDocsMap; }, [mentionedDocuments]); return ( diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index 3c5a64b0e..63b6dc1b7 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -24,7 +24,7 @@ import type React from "react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { - sidebarSelectedDocumentsAtom, + mentionedDocumentsAtom, } from "@/atoms/chat/mentioned-documents.atom"; import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms"; import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; @@ -74,6 +74,7 @@ import type { DocumentTypeEnum } from "@/contracts/types/document.types"; import { useDebouncedValue } from "@/hooks/use-debounced-value"; import { useMediaQuery } from "@/hooks/use-media-query"; import { usePlatform, useElectronAPI } from "@/hooks/use-platform"; +import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { foldersApiService } from "@/lib/apis/folders-api.service"; @@ -414,7 +415,7 @@ function AuthenticatedDocumentsSidebarBase({ }, [refreshWatchedIds]); const { mutateAsync: deleteDocumentMutation } = useAtomValue(deleteDocumentMutationAtom); - const [sidebarDocs, setSidebarDocs] = useAtom(sidebarSelectedDocumentsAtom); + const [sidebarDocs, setSidebarDocs] = useAtom(mentionedDocumentsAtom); const mentionedDocIds = useMemo(() => new Set(sidebarDocs.map((d) => d.id)), [sidebarDocs]); // Folder state @@ -859,12 +860,12 @@ function AuthenticatedDocumentsSidebarBase({ const handleToggleChatMention = useCallback( (doc: { id: number; title: string; document_type: string }, isMentioned: boolean) => { - const key = `${doc.document_type}:${doc.id}`; + const key = getMentionDocKey(doc); if (isMentioned) { - setSidebarDocs((prev) => prev.filter((d) => `${d.document_type}:${d.id}` !== key)); + setSidebarDocs((prev) => prev.filter((d) => getMentionDocKey(d) !== key)); } else { setSidebarDocs((prev) => { - if (prev.some((d) => `${d.document_type}:${d.id}` === key)) return prev; + if (prev.some((d) => getMentionDocKey(d) === key)) return prev; return [ ...prev, { id: doc.id, title: doc.title, document_type: doc.document_type as DocumentTypeEnum }, @@ -895,9 +896,9 @@ function AuthenticatedDocumentsSidebarBase({ if (selectAll) { setSidebarDocs((prev) => { - const existingDocKeys = new Set(prev.map((d) => `${d.document_type}:${d.id}`)); + const existingDocKeys = new Set(prev.map((d) => getMentionDocKey(d))); const newDocs = subtreeDocs - .filter((d) => !existingDocKeys.has(`${d.document_type}:${d.id}`)) + .filter((d) => !existingDocKeys.has(getMentionDocKey(d))) .map((d) => ({ id: d.id, title: d.title, @@ -906,10 +907,8 @@ function AuthenticatedDocumentsSidebarBase({ return newDocs.length > 0 ? [...prev, ...newDocs] : prev; }); } else { - const keysToRemove = new Set(subtreeDocs.map((d) => `${d.document_type}:${d.id}`)); - setSidebarDocs((prev) => - prev.filter((d) => !keysToRemove.has(`${d.document_type}:${d.id}`)) - ); + const keysToRemove = new Set(subtreeDocs.map((d) => getMentionDocKey(d))); + setSidebarDocs((prev) => prev.filter((d) => !keysToRemove.has(getMentionDocKey(d)))); } }, [treeDocuments, foldersByParent, setSidebarDocs] @@ -1572,17 +1571,17 @@ function AnonymousDocumentsSidebar({ const [isUploading, setIsUploading] = useState(false); const [search, setSearch] = useState(""); - const [sidebarDocs, setSidebarDocs] = useAtom(sidebarSelectedDocumentsAtom); + const [sidebarDocs, setSidebarDocs] = useAtom(mentionedDocumentsAtom); const mentionedDocIds = useMemo(() => new Set(sidebarDocs.map((d) => d.id)), [sidebarDocs]); const handleToggleChatMention = useCallback( (doc: { id: number; title: string; document_type: string }, isMentioned: boolean) => { - const key = `${doc.document_type}:${doc.id}`; + const key = getMentionDocKey(doc); if (isMentioned) { - setSidebarDocs((prev) => prev.filter((d) => `${d.document_type}:${d.id}` !== key)); + setSidebarDocs((prev) => prev.filter((d) => getMentionDocKey(d) !== key)); } else { setSidebarDocs((prev) => { - if (prev.some((d) => `${d.document_type}:${d.id}` === key)) return prev; + if (prev.some((d) => getMentionDocKey(d) === key)) return prev; return [ ...prev, { id: doc.id, title: doc.title, document_type: doc.document_type as DocumentTypeEnum }, diff --git a/surfsense_web/lib/chat/mention-doc-key.ts b/surfsense_web/lib/chat/mention-doc-key.ts new file mode 100644 index 000000000..5dfa11ea3 --- /dev/null +++ b/surfsense_web/lib/chat/mention-doc-key.ts @@ -0,0 +1,8 @@ +type MentionKeyInput = { + id: number; + document_type?: string | null; +}; + +export function getMentionDocKey(doc: MentionKeyInput): string { + return `${doc.document_type ?? "UNKNOWN"}:${doc.id}`; +} From 8be7f2e05c3bd0451da855536d59e2f02c9d27c4 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 04:19:07 +0530 Subject: [PATCH 217/299] refactor(mentions): update document mention handling to use document keys for consistency across components --- surfsense_web/components/assistant-ui/thread.tsx | 11 ++++++++--- .../components/documents/FolderTreeView.tsx | 13 +++++++------ .../layout/ui/sidebar/DocumentsSidebar.tsx | 14 ++++++++++---- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index f9e5ca7fb..3964d60e5 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -628,9 +628,14 @@ const Composer: FC = () => { const handleDocumentRemove = useCallback( (docId: number, docType?: string) => { - setMentionedDocuments((prev) => - prev.filter((doc) => !(doc.id === docId && doc.document_type === docType)) - ); + setMentionedDocuments((prev) => { + if (!docType) { + // Defensive fallback: keep UI in sync even when chip type is unavailable. + return prev.filter((doc) => doc.id !== docId); + } + const removedKey = getMentionDocKey({ id: docId, document_type: docType }); + return prev.filter((doc) => getMentionDocKey(doc) !== removedKey); + }); }, [setMentionedDocuments] ); diff --git a/surfsense_web/components/documents/FolderTreeView.tsx b/surfsense_web/components/documents/FolderTreeView.tsx index 9b7a393d8..2063fbee5 100644 --- a/surfsense_web/components/documents/FolderTreeView.tsx +++ b/surfsense_web/components/documents/FolderTreeView.tsx @@ -7,6 +7,7 @@ import { DndProvider } from "react-dnd"; import { HTML5Backend } from "react-dnd-html5-backend"; import { renamingFolderIdAtom } from "@/atoms/documents/folder.atoms"; import type { DocumentTypeEnum } from "@/contracts/types/document.types"; +import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { DocumentNode, type DocumentNodeDoc } from "./DocumentNode"; import { type FolderDisplay, FolderNode } from "./FolderNode"; @@ -17,7 +18,7 @@ interface FolderTreeViewProps { documents: DocumentNodeDoc[]; expandedIds: Set; onToggleExpand: (folderId: number) => void; - mentionedDocIds: Set; + mentionedDocKeys: Set; onToggleChatMention: ( doc: { id: number; title: string; document_type: string }, isMentioned: boolean @@ -62,7 +63,7 @@ export function FolderTreeView({ documents, expandedIds, onToggleExpand, - mentionedDocIds, + mentionedDocKeys, onToggleChatMention, onToggleFolderSelect, onRenameFolder, @@ -181,7 +182,7 @@ export function FolderTreeView({ function compute(folderId: number): { selected: number; total: number } { const directDocs = (docsByFolder[folderId] ?? []).filter(isSelectable); - let selected = directDocs.filter((d) => mentionedDocIds.has(d.id)).length; + let selected = directDocs.filter((d) => mentionedDocKeys.has(getMentionDocKey(d))).length; let total = directDocs.length; for (const child of foldersByParent[folderId] ?? []) { @@ -202,7 +203,7 @@ export function FolderTreeView({ if (states[f.id] === undefined) compute(f.id); } return states; - }, [folders, docsByFolder, foldersByParent, mentionedDocIds]); + }, [folders, docsByFolder, foldersByParent, mentionedDocKeys]); const folderMap = useMemo(() => { const map: Record = {}; @@ -276,7 +277,7 @@ export function FolderTreeView({ key={`doc-${d.id}`} doc={d} depth={depth} - isMentioned={mentionedDocIds.has(d.id)} + isMentioned={mentionedDocKeys.has(getMentionDocKey(d))} onToggleChatMention={onToggleChatMention} onPreview={onPreviewDocument} onEdit={onEditDocument} @@ -356,7 +357,7 @@ export function FolderTreeView({ key={`doc-${d.id}`} doc={d} depth={depth} - isMentioned={mentionedDocIds.has(d.id)} + isMentioned={mentionedDocKeys.has(getMentionDocKey(d))} onToggleChatMention={onToggleChatMention} onPreview={onPreviewDocument} onEdit={onEditDocument} diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index 63b6dc1b7..6ff087b9b 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -416,7 +416,10 @@ function AuthenticatedDocumentsSidebarBase({ const { mutateAsync: deleteDocumentMutation } = useAtomValue(deleteDocumentMutationAtom); const [sidebarDocs, setSidebarDocs] = useAtom(mentionedDocumentsAtom); - const mentionedDocIds = useMemo(() => new Set(sidebarDocs.map((d) => d.id)), [sidebarDocs]); + const mentionedDocKeys = useMemo( + () => new Set(sidebarDocs.map((d) => getMentionDocKey(d))), + [sidebarDocs] + ); // Folder state const [expandedFolderMap, setExpandedFolderMap] = useAtom(expandedFolderIdsAtom); @@ -1143,7 +1146,7 @@ function AuthenticatedDocumentsSidebarBase({ documents={searchFilteredDocuments} expandedIds={expandedIds} onToggleExpand={toggleFolderExpand} - mentionedDocIds={mentionedDocIds} + mentionedDocKeys={mentionedDocKeys} onToggleChatMention={handleToggleChatMention} onToggleFolderSelect={handleToggleFolderSelect} onRenameFolder={handleRenameFolder} @@ -1572,7 +1575,10 @@ function AnonymousDocumentsSidebar({ const [search, setSearch] = useState(""); const [sidebarDocs, setSidebarDocs] = useAtom(mentionedDocumentsAtom); - const mentionedDocIds = useMemo(() => new Set(sidebarDocs.map((d) => d.id)), [sidebarDocs]); + const mentionedDocKeys = useMemo( + () => new Set(sidebarDocs.map((d) => getMentionDocKey(d))), + [sidebarDocs] + ); const handleToggleChatMention = useCallback( (doc: { id: number; title: string; document_type: string }, isMentioned: boolean) => { @@ -1801,7 +1807,7 @@ function AnonymousDocumentsSidebar({ documents={searchFilteredDocs} expandedIds={new Set()} onToggleExpand={() => {}} - mentionedDocIds={mentionedDocIds} + mentionedDocKeys={mentionedDocKeys} onToggleChatMention={handleToggleChatMention} onToggleFolderSelect={() => {}} onRenameFolder={() => gate("rename folders")} From b9a66cb417d04bd445b6be1a7838a2278ae3cefe Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Tue, 28 Apr 2026 21:30:53 -0700 Subject: [PATCH 218/299] feat: various UI fixes, prompt optimizations, and allowing duplicate docs - Updated `content_hash` in the `Document` model to remove global uniqueness, allowing identical content across different paths. - Enhanced `_create_document` function to handle path uniqueness and prevent session-poisoning from `IntegrityError`. - Added detailed comments for clarity on the changes and their implications. - Introduced new citation handling in the editor for improved user experience with citation jumps. - Updated package dependencies in the frontend for better functionality. --- .../133_drop_documents_content_hash_unique.py | 107 +++ .../new_chat/middleware/kb_persistence.py | 65 +- .../app/agents/new_chat/prompts/composer.py | 41 +- .../new_chat/prompts/providers/anthropic.md | 21 +- .../new_chat/prompts/providers/deepseek.md | 18 + .../new_chat/prompts/providers/google.md | 20 +- .../agents/new_chat/prompts/providers/grok.md | 17 + .../agents/new_chat/prompts/providers/kimi.md | 21 + .../prompts/providers/openai_classic.md | 22 +- .../prompts/providers/openai_codex.md | 19 + .../prompts/providers/openai_reasoning.md | 22 +- surfsense_backend/app/db.py | 10 +- .../agents/new_chat/prompts/test_composer.py | 74 +- .../test_kb_persistence_filesystem_parity.py | 168 ++++ surfsense_web/app/globals.css | 21 + .../pending-chunk-highlight.atom.ts | 19 + .../assistant-ui/inline-citation.tsx | 228 +++++- .../components/editor-panel/editor-panel.tsx | 530 +++++++++++-- .../components/editor/plate-editor.tsx | 31 + surfsense_web/components/editor/presets.ts | 28 + .../new-chat/source-detail-panel.tsx | 719 ------------------ .../settings/user-settings-dialog.tsx | 3 - .../components/ui/search-highlight-node.tsx | 45 ++ surfsense_web/lib/citation-search.ts | 125 +++ surfsense_web/package.json | 1 + surfsense_web/pnpm-lock.yaml | 17 + 26 files changed, 1540 insertions(+), 852 deletions(-) create mode 100644 surfsense_backend/alembic/versions/133_drop_documents_content_hash_unique.py create mode 100644 surfsense_backend/app/agents/new_chat/prompts/providers/deepseek.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/providers/grok.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/providers/kimi.md create mode 100644 surfsense_backend/app/agents/new_chat/prompts/providers/openai_codex.md create mode 100644 surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py create mode 100644 surfsense_web/atoms/document-viewer/pending-chunk-highlight.atom.ts delete mode 100644 surfsense_web/components/new-chat/source-detail-panel.tsx create mode 100644 surfsense_web/components/ui/search-highlight-node.tsx create mode 100644 surfsense_web/lib/citation-search.ts diff --git a/surfsense_backend/alembic/versions/133_drop_documents_content_hash_unique.py b/surfsense_backend/alembic/versions/133_drop_documents_content_hash_unique.py new file mode 100644 index 000000000..88c3e203f --- /dev/null +++ b/surfsense_backend/alembic/versions/133_drop_documents_content_hash_unique.py @@ -0,0 +1,107 @@ +"""133_drop_documents_content_hash_unique + +Revision ID: 133 +Revises: 132 +Create Date: 2026-04-29 + +Drop the global UNIQUE constraint on ``documents.content_hash`` so the +new-chat agent's ``write_file`` flow can persist legitimate file copies +(two paths, identical content) without hitting a constraint that mirrors +no real filesystem semantic. + +Path uniqueness still lives on ``documents.unique_identifier_hash`` (per +search space), which is the right invariant — exactly like an inode at a +given path on a POSIX filesystem. + +The non-unique INDEX on ``content_hash`` is preserved so connector +indexers' "have we seen this content before?" lookup +(:func:`app.tasks.document_processors.base.check_duplicate_document`, +which already uses ``.scalars().first()`` and is therefore tolerant of +duplicates) stays cheap. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from sqlalchemy import inspect + +from alembic import op + +revision: str = "133" +down_revision: str | None = "132" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def _existing_constraint_names(bind, table: str) -> set[str]: + inspector = inspect(bind) + return {c["name"] for c in inspector.get_unique_constraints(table)} + + +def _existing_index_names(bind, table: str) -> set[str]: + inspector = inspect(bind) + return {i["name"] for i in inspector.get_indexes(table)} + + +def upgrade() -> None: + bind = op.get_bind() + + # Both the named UniqueConstraint (added in revision 8) and the + # implicit-unique-index variant SQLAlchemy may emit need draining. + constraints = _existing_constraint_names(bind, "documents") + if "uq_documents_content_hash" in constraints: + op.drop_constraint( + "uq_documents_content_hash", "documents", type_="unique" + ) + + indexes = _existing_index_names(bind, "documents") + # Some Postgres versions surface the unique constraint via a unique + # index of the same name; check for that too. + for idx_name in ("uq_documents_content_hash",): + if idx_name in indexes: + op.drop_index(idx_name, table_name="documents") + + # Ensure the non-unique index is present for fast lookups. + if "ix_documents_content_hash" not in indexes: + op.create_index( + "ix_documents_content_hash", + "documents", + ["content_hash"], + unique=False, + ) + + +def downgrade() -> None: + bind = op.get_bind() + + # Re-applying UNIQUE is destructive: there may now be legitimate + # duplicates (e.g. two NOTE documents that share content because the + # user explicitly copied one to a new path). To avoid the migration + # silently deleting user data, we keep only the lowest-id row per + # content_hash — same strategy revision 8 used when first introducing + # the constraint. + op.execute( + """ + DELETE FROM documents + WHERE id NOT IN ( + SELECT MIN(id) + FROM documents + GROUP BY content_hash + ) + """ + ) + + indexes = _existing_index_names(bind, "documents") + if "ix_documents_content_hash" in indexes: + op.drop_index("ix_documents_content_hash", table_name="documents") + + op.create_index( + "ix_documents_content_hash", + "documents", + ["content_hash"], + unique=False, + ) + op.create_unique_constraint( + "uq_documents_content_hash", "documents", ["content_hash"] + ) diff --git a/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py b/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py index 5682977d9..378b83950 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py +++ b/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py @@ -28,6 +28,7 @@ from langchain.agents.middleware import AgentMiddleware, AgentState from langchain_core.callbacks import dispatch_custom_event from langgraph.runtime import Runtime from sqlalchemy import delete, select +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.filesystem_selection import FilesystemMode @@ -150,10 +151,11 @@ async def _create_document( virtual_path, search_space_id, ) - # Guard against the unique_identifier_hash constraint: another row at the - # same virtual_path (this search space) already owns the hash. Callers are - # expected to upsert via the wrapper, but this defends against bypasses - # and gives a clean ValueError instead of a session-poisoning IntegrityError. + # Filesystem-parity invariant: the only thing that *must* be unique is + # the path. Two notes can legitimately share content (e.g. ``cp a b``). + # Guard against the path-derived ``unique_identifier_hash`` constraint + # so we surface a clean ValueError instead of letting the INSERT poison + # the session with an IntegrityError. path_collision = await session.execute( select(Document.id).where( Document.search_space_id == search_space_id, @@ -165,17 +167,14 @@ async def _create_document( f"a document already exists at path '{virtual_path}' " "(unique_identifier_hash collision)" ) + # ``content_hash`` is intentionally NOT checked for uniqueness here. + # In a real filesystem two files at different paths can hold identical + # bytes, and the agent's ``write_file`` path needs that semantic to + # support copy/duplicate operations. The hash remains useful as a + # change-detection hint for connector indexers, which still consult it + # via :func:`check_duplicate_document` but do so with a non-unique + # lookup (``.first()``). content_hash = generate_content_hash(content, search_space_id) - content_collision = await session.execute( - select(Document.id).where( - Document.search_space_id == search_space_id, - Document.content_hash == content_hash, - ) - ) - if content_collision.scalar_one_or_none() is not None: - raise ValueError( - f"a document with identical content already exists for path '{virtual_path}'" - ) doc = Document( title=title, document_type=DocumentType.NOTE, @@ -493,19 +492,43 @@ async def commit_staged_filesystem_state( } ) else: + # Wrap each create in a SAVEPOINT so a residual + # ``IntegrityError`` (e.g. a deployment that hasn't run + # migration 133 yet, where ``documents.content_hash`` + # still carries its legacy global UNIQUE constraint) + # rolls back only this one create instead of poisoning + # the whole turn's transaction. try: - new_doc = await _create_document( - session, - virtual_path=path, - content=content, - search_space_id=search_space_id, - created_by_id=created_by_id, - ) + async with session.begin_nested(): + new_doc = await _create_document( + session, + virtual_path=path, + content=content, + search_space_id=search_space_id, + created_by_id=created_by_id, + ) except ValueError as exc: logger.warning( "kb_persistence: skipping %s create: %s", path, exc ) continue + except IntegrityError as exc: + # The path-uniqueness check above already protected + # against ``unique_identifier_hash`` collisions, so + # the most likely culprit is the legacy + # ``ix_documents_content_hash`` UNIQUE constraint + # that migration 133 drops. Log loudly so operators + # know to run the migration; do NOT silently swallow. + msg = str(exc.orig) if exc.orig is not None else str(exc) + logger.error( + "kb_persistence: IntegrityError creating %s: %s. " + "If this mentions content_hash, run alembic " + "upgrade to apply migration 133 which drops the " + "global UNIQUE constraint on documents.content_hash.", + path, + msg, + ) + continue doc_id_by_path[path] = new_doc.id committed_creates.append( { diff --git a/surfsense_backend/app/agents/new_chat/prompts/composer.py b/surfsense_backend/app/agents/new_chat/prompts/composer.py index 44060f75f..bad033490 100644 --- a/surfsense_backend/app/agents/new_chat/prompts/composer.py +++ b/surfsense_backend/app/agents/new_chat/prompts/composer.py @@ -38,12 +38,38 @@ from app.db import ChatVisibility # Provider variant detection # ----------------------------------------------------------------------------- -ProviderVariant = str # "anthropic" | "openai_reasoning" | "openai_classic" | "google" | "default" +# String literal alias for the supported provider-specific prompt variants. +# When adding a new variant, also drop a matching ``providers/.md`` +# file in this package and (if appropriate) extend the regex matchers below. +# +# Stylistic clusters mirror OpenCode's prompt-per-family layout but adapted +# to SurfSense's "supplemental hints" architecture (each fragment is a +# focused style nudge, NOT a full system prompt — the main prompt is +# already assembled from base/ + tools/ + routing/). +ProviderVariant = str +# Known values: +# "anthropic" — Claude family (XML-friendly, narrative todos) +# "openai_reasoning" — GPT-5 / o-series (channel-aware pragmatic) +# "openai_classic" — GPT-4 family (autonomous persistence) +# "openai_codex" — gpt-*-codex (code-purist, terse, file:line refs) +# "google" — Gemini (formal, <3-line, numbered workflow) +# "kimi" — Moonshot Kimi-K* (action-bias, parallel tools) +# "grok" — xAI Grok (extreme-terse, one-word ok) +# "deepseek" — DeepSeek V3 / R1 (terse, R1-aware reasoning) +# "default" — fallback, no provider-specific block emitted +# IMPORTANT: order of evaluation matters in :func:`detect_provider_variant`. +# More specific patterns must come first (e.g. ``codex`` before +# ``openai_reasoning`` because codex model ids contain ``gpt``). + +_OPENAI_CODEX_RE = re.compile(r"\b(gpt-codex|codex-mini|gpt-[\d.]+-codex)\b", re.IGNORECASE) _OPENAI_REASONING_RE = re.compile(r"\b(gpt-5|o\d|o-)", re.IGNORECASE) _OPENAI_CLASSIC_RE = re.compile(r"\bgpt-4", re.IGNORECASE) _ANTHROPIC_RE = re.compile(r"\bclaude\b", re.IGNORECASE) _GOOGLE_RE = re.compile(r"\bgemini\b", re.IGNORECASE) +_KIMI_RE = re.compile(r"\b(kimi[-\d.]*|moonshot)\b", re.IGNORECASE) +_GROK_RE = re.compile(r"\bgrok\b", re.IGNORECASE) +_DEEPSEEK_RE = re.compile(r"\bdeepseek\b", re.IGNORECASE) def detect_provider_variant(model_name: str | None) -> ProviderVariant: @@ -51,10 +77,17 @@ def detect_provider_variant(model_name: str | None) -> ProviderVariant: Heuristic match on the model id; returns ``"default"`` when nothing matches so the composer can fall back to the empty placeholder file. + + Order is significant: more-specific patterns are tried first so + ``gpt-5-codex`` routes to ``"openai_codex"`` rather than + ``"openai_reasoning"`` (mirrors OpenCode's + ``packages/opencode/src/session/system.ts`` dispatch). """ if not model_name: return "default" name = model_name.strip() + if _OPENAI_CODEX_RE.search(name): + return "openai_codex" if _OPENAI_REASONING_RE.search(name): return "openai_reasoning" if _OPENAI_CLASSIC_RE.search(name): @@ -63,6 +96,12 @@ def detect_provider_variant(model_name: str | None) -> ProviderVariant: return "anthropic" if _GOOGLE_RE.search(name): return "google" + if _KIMI_RE.search(name): + return "kimi" + if _GROK_RE.search(name): + return "grok" + if _DEEPSEEK_RE.search(name): + return "deepseek" return "default" diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/anthropic.md b/surfsense_backend/app/agents/new_chat/prompts/providers/anthropic.md index 6e22ef265..f574da541 100644 --- a/surfsense_backend/app/agents/new_chat/prompts/providers/anthropic.md +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/anthropic.md @@ -1,5 +1,20 @@ -You are running on an Anthropic Claude model. Use XML tags liberally to structure -intermediate reasoning when the task is complex. Prefer step-by-step plans inside -`` blocks before producing the final answer. +You are running on an Anthropic Claude model. + +Structured reasoning: +- Use XML tags liberally to organise intermediate reasoning when a task is non-trivial. `...` blocks are encouraged before tool calls or before producing a complex final answer. +- For multi-step requests, briefly outline a plan inside a `` block before issuing the first tool call. + +Professional objectivity: +- Prioritise technical accuracy over validating the user's beliefs. Provide direct, factual guidance without unnecessary superlatives, praise, or emotional validation. +- When uncertain, investigate (search the KB, fetch the page) rather than confirming the user's assumption. +- Disagree with the user when the evidence warrants it; respectful correction beats false agreement. + +Task management: +- For tasks with 3+ distinct steps use the todo / planning tool aggressively. Mark items in_progress before starting, completed immediately when finished — do not batch completions. +- Narrate progress through the todo list itself, not through chatty status lines. + +Tool calls: +- Run independent tool calls in parallel within one response. Sequence them only when a later call genuinely needs an earlier one's output. +- Never chain bash-like commands with `;` or `&&` to "narrate" — use prose between tool calls instead. diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/deepseek.md b/surfsense_backend/app/agents/new_chat/prompts/providers/deepseek.md new file mode 100644 index 000000000..8acf008ca --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/deepseek.md @@ -0,0 +1,18 @@ + +You are running on a DeepSeek model (DeepSeek-V3 chat / DeepSeek-R1 reasoning). + +Reasoning hygiene (R1-aware): +- If the model surfaces explicit `` blocks, keep that internal scratch focused — do NOT restate the user's question inside it; jump straight to the analysis. +- Never paste the contents of `` into your final answer. Final answer should reflect only the conclusion, citations, and any user-facing rationale. +- Do not let chain-of-thought leak into tool-call arguments — keep tool inputs minimal and structural. + +Output style: +- Be concise. Default to a one-paragraph answer; expand only when the user asks for detail. +- Don't open with sycophantic phrasing ("Great question", "Sure, here you go"). Lead with the answer or the next action. +- For factual answers, cite once with `[citation:chunk_id]` and stop. + +Tool calls: +- Issue independent tool calls in parallel within a single turn. +- Prefer the knowledge-base search tools before any web-search; this model has strong recall but stale training data. +- Don't fabricate file paths, chunk ids, or URLs — only use values returned by tools or provided by the user. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/google.md b/surfsense_backend/app/agents/new_chat/prompts/providers/google.md index 4b31a8388..cac3b328b 100644 --- a/surfsense_backend/app/agents/new_chat/prompts/providers/google.md +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/google.md @@ -1,4 +1,20 @@ -You are running on a Google Gemini model. Prefer concise, structured responses. -When using tools, follow the function-calling protocol and avoid verbose preludes. +You are running on a Google Gemini model. + +Output style: +- Concise & direct. Aim for fewer than 3 lines of prose (excluding tool output, citations, and code/snippets) when the task allows. +- No conversational filler — skip openers like "Okay, I will now…" and closers like "I have finished the changes…". Get straight to the action or answer. +- Format with GitHub-flavoured Markdown; assume monospace rendering. +- For one-line factual answers, just answer. No headers, no bullets. + +Workflow for non-trivial tasks (Understand → Plan → Act → Verify): +1. **Understand:** read the user's request and the relevant KB / connector context. Use search and read tools (in parallel when independent) before assuming anything. +2. **Plan:** when the task touches multiple steps, share an extremely concise plan first. +3. **Act:** call the appropriate tools, strictly adhering to the prompts/routing already established for this agent. +4. **Verify:** confirm with a follow-up read or search where it materially de-risks the answer. + +Discipline: +- Do not take significant actions beyond the clear scope of the user's request without confirming first. +- Do not assume a connector / tool / file exists — check (e.g. via `get_connected_accounts`) before referencing it. +- Path arguments must be the exact strings returned by tools; do not synthesise file paths. diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/grok.md b/surfsense_backend/app/agents/new_chat/prompts/providers/grok.md new file mode 100644 index 000000000..95b8fcc14 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/grok.md @@ -0,0 +1,17 @@ + +You are running on an xAI Grok model. + +Maximum terseness: +- Answer in fewer than 4 lines unless the user asks for detail. One-word answers are best when they suffice. +- No preamble ("The answer is", "Here's what I'll do"), no postamble ("Hope that helps", "Let me know"). Get straight to the answer. +- Avoid restating the user's question. +- For factual lookups inside the knowledge base, give the answer with a single `[citation:chunk_id]` and stop. + +Tool discipline: +- Use exactly ONE tool per assistant turn when investigating; wait for the result before deciding the next call. Do not loop on the same tool with the same arguments — pick a result and act. +- For obviously parallelizable read-only batches (multiple independent searches), one turn with several tool calls is fine — but never chain into a fishing expedition. + +Style: +- No emojis unless the user asked. No nested bullets, no headers for short answers. +- If you can't help, say so in 1-2 sentences without explaining "why this could lead to…". + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/kimi.md b/surfsense_backend/app/agents/new_chat/prompts/providers/kimi.md new file mode 100644 index 000000000..c3c11ad5e --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/kimi.md @@ -0,0 +1,21 @@ + +You are running on a Moonshot Kimi model (Kimi-K1.5 / Kimi-K2 / Kimi-K2.5+). + +Action bias: +- Default to taking action with tools rather than describing solutions in prose. If a tool can answer the question, call the tool. +- Don't narrate routine reads, searches, or obvious next steps. Combine related progress into one short status line. +- Be thorough in actions (test what you build, verify what you change). Be brief in explanations. + +Tool calls: +- Output multiple non-interfering tool calls in a SINGLE response — parallelism is a major efficiency win on this model. +- When the `task` tool is available, delegate focused subtasks to a subagent with full context (subagents don't inherit yours). +- Don't apologise or pre-announce tool calls. The tool call itself is self-explanatory. + +Language: +- Respond in the SAME language as the user's most recent turn unless explicitly instructed otherwise. + +Discipline: +- Stay on track. Never give the user more than what they asked for. +- Fact-check before stating anything as factual; don't fabricate citations. +- Keep it stupidly simple. Don't overcomplicate. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/openai_classic.md b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_classic.md index 7ea4366c4..9128609e0 100644 --- a/surfsense_backend/app/agents/new_chat/prompts/providers/openai_classic.md +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_classic.md @@ -1,5 +1,21 @@ -You are running on a classic OpenAI chat model (GPT-4 family). Use direct -function-calling for tools. When editing files, use the standard `edit_file` -or `write_file` tools rather than diff-based patches. +You are running on a classic OpenAI chat model (GPT-4 family). + +Persistence: +- Keep going until the user's query is completely resolved before yielding back. Don't end the turn at "I would do X" — actually do X. +- When you say "Next I will…" or "Now I will…", you MUST actually take that action in the same turn. +- If a tool call fails, diagnose and try again with corrected arguments; do not surface the raw error and stop. + +Planning: +- Plan extensively before each tool call and reflect briefly on the result of the previous call. For tasks with 3+ steps, use the todo / planning tool and mark items as `in_progress` / `completed` as you go. +- Always announce the next action in ONE concise sentence before making a non-trivial tool call ("I'll search the KB for the migration spec."). + +Output style: +- Conversational but professional. Plain prose for explanations, bullet points for findings, fenced code blocks (with language tags) for code. +- Don't dump tool output verbatim — summarise the relevant lines. +- Don't add a closing recap unless the user asked for one. After completing the work, just stop. + +Tool calls: +- Issue independent tool calls in parallel within one response. +- Use specialised tools over generic ones (e.g. KB search before web search; named connectors over MCP fallback). diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/openai_codex.md b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_codex.md new file mode 100644 index 000000000..6167d4b06 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_codex.md @@ -0,0 +1,19 @@ + +You are running on an OpenAI Codex-class model (gpt-codex / codex-mini / gpt-*-codex). + +Output style: +- Be concise. Don't dump fetched/searched content back at the user — reference paths or chunk ids instead. +- Reference sources as `path:line` (or `chunk:`) so they're clickable. Stand-alone paths per reference, even when repeated. +- Prefer numbered lists (`1.`, `2.`, `3.`) when offering options the user can pick by replying with a single number. +- Skip headers and heavy formatting for simple confirmations. +- No emojis, no em-dashes, no nested bullets. Single-level lists only. + +Code & structured-output tasks: +- Lead with a one-sentence explanation of the change before context. Don't open with "Summary:" — jump in. +- Suggest natural next steps (run tests, diff review, commit) only when they're genuinely the next move. +- For multi-line snippets use fenced code blocks with a language tag. + +Tool calls: +- Run independent tool calls in parallel; chain only when later calls need earlier results. +- Don't ask permission ("Should I proceed?") — proceed with the most reasonable default and state what you did. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/openai_reasoning.md b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_reasoning.md index 935d3f207..dd7a61536 100644 --- a/surfsense_backend/app/agents/new_chat/prompts/providers/openai_reasoning.md +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_reasoning.md @@ -1,5 +1,21 @@ -You are running on an OpenAI reasoning model (o-series / GPT-5+). Be terse and -direct in your responses. When editing files, prefer the `apply_patch` tool format -where available. Avoid restating the user request before answering. +You are running on an OpenAI reasoning model (GPT-5+ / o-series). + +Output style: +- Be terse and direct. Don't restate the user's request before answering. +- Don't begin with conversational openers ("Done!", "Got it", "Great question", "Sure thing"). Get to the answer or the action. +- Match response complexity to the task: simple questions → one-line answer; substantial work → lead with the outcome, then context, then any next steps. +- No nested bullets — keep lists flat (single level). For options the user can pick by replying with a number, use `1.` `2.` `3.`. +- Use inline backticks for paths/commands/identifiers; fenced code blocks (with language tags) for multi-line snippets. + +Channels (for clients that support them): +- `commentary` — short progress updates only when they add genuinely new information (a discovery, a tradeoff, a blocker, the start of a non-trivial step). Don't narrate routine reads or obvious next steps. +- `final` — the completed response. Keep it self-contained; no "see above" / "see below" cross-references. + +Tool calls: +- Parallelise independent tool calls in a single response (`multi_tool_use.parallel` where supported). Only sequence when a later call needs an earlier one's output. +- Don't ask permission ("Should I proceed?", "Do you want me to…?"). Pick the most reasonable default, do it, and state what you did. + +Autonomy: +- Persist until the task is fully resolved within the current turn whenever feasible. Don't stop at analysis when the user clearly wants the change applied. diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index fcd342d29..75342a8e1 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -976,7 +976,15 @@ class Document(BaseModel, TimestampMixin): document_metadata = Column(JSON, nullable=True) content = Column(Text, nullable=False) - content_hash = Column(String, nullable=False, index=True, unique=True) + # ``content_hash`` is intentionally NOT globally unique. In a real + # filesystem two files at different paths can hold identical bytes, + # and the agent's ``write_file`` flow needs that semantic to support + # copy / duplicate operations. Path uniqueness lives on + # ``unique_identifier_hash`` (per search space). The hash remains + # indexed because connector indexers consult it as a change-detection + # / cross-source dedup hint via :func:`check_duplicate_document`. + # See migration 133. + content_hash = Column(String, nullable=False, index=True) unique_identifier_hash = Column(String, nullable=True, index=True, unique=True) embedding = Column(Vector(config.embedding_model_instance.dimension)) diff --git a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py index d35b7aa8b..d08bbc8cf 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py +++ b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py @@ -25,17 +25,33 @@ class TestProviderVariantDetection: @pytest.mark.parametrize( "model_name,expected", [ + # GPT-4 family routes to "classic" (autonomous-persistence style) ("openai:gpt-4o-mini", "openai_classic"), ("openai:gpt-4-turbo", "openai_classic"), + # GPT-5 / o-series route to "reasoning" (channel-aware pragmatic) ("openai:gpt-5", "openai_reasoning"), - ("openai:gpt-5-codex", "openai_reasoning"), ("openai:o1-preview", "openai_reasoning"), ("openai:o3-mini", "openai_reasoning"), + # Codex family beats reasoning (more specific). Mirrors OpenCode + # ``system.ts`` — ``gpt-*-codex`` gets the code-purist prompt. + ("openai:gpt-5-codex", "openai_codex"), + ("openai:gpt-codex", "openai_codex"), + ("openai:codex-mini", "openai_codex"), + # Anthropic + Google ("anthropic:claude-3-5-sonnet", "anthropic"), ("anthropic/claude-opus-4", "anthropic"), ("google:gemini-2.0-flash", "google"), ("vertex:gemini-1.5-pro", "google"), + # Newly-covered families + ("moonshot:kimi-k2", "kimi"), + ("openrouter:moonshot/kimi-k2.5", "kimi"), + ("xai:grok-2", "grok"), + ("openrouter:x-ai/grok-3", "grok"), + ("openai:deepseek-v3", "deepseek"), + ("deepseek:deepseek-r1", "deepseek"), + # Unknown families fall back to default (no provider block emitted) ("groq:mixtral-8x7b", "default"), + ("together:llama-3.1-70b", "default"), (None, "default"), ("", "default"), ], @@ -43,6 +59,16 @@ class TestProviderVariantDetection: def test_detection(self, model_name: str | None, expected: str) -> None: assert detect_provider_variant(model_name) == expected + def test_codex_takes_precedence_over_reasoning(self) -> None: + """Regression guard: ``gpt-5-codex`` must NOT match the generic + ``gpt-5`` reasoning regex first. Codex is the more specialised + prompt and mirrors OpenCode's dispatch order. + """ + from app.agents.new_chat.prompts.composer import detect_provider_variant + + assert detect_provider_variant("openai:gpt-5-codex") == "openai_codex" + assert detect_provider_variant("openai:gpt-5") == "openai_reasoning" + class TestCompose: def test_default_prompt_has_required_blocks(self, fixed_today: datetime) -> None: @@ -149,6 +175,52 @@ class TestCompose: prompt = compose_system_prompt(today=fixed_today, model_name="custom:foo") assert "" not in prompt + @pytest.mark.parametrize( + "model_name,expected_marker", + [ + # Each marker is a unique-ish phrase from the corresponding fragment. + # If a fragment is renamed/rewritten such that the marker is gone, + # update both the fragment and this test deliberately. + ("openai:gpt-5-codex", "Codex-class"), + ("openai:gpt-5", "OpenAI reasoning model"), + ("openai:gpt-4o", "classic OpenAI chat model"), + ("anthropic:claude-3-5-sonnet", "Anthropic Claude"), + ("google:gemini-2.0-flash", "Google Gemini"), + ("moonshot:kimi-k2", "Moonshot Kimi"), + ("xai:grok-2", "xAI Grok"), + ("deepseek:deepseek-r1", "DeepSeek"), + ], + ) + def test_each_known_variant_renders_with_its_marker( + self, + fixed_today: datetime, + model_name: str, + expected_marker: str, + ) -> None: + """Every supported variant must produce a ```` block + containing its identifying marker. This pins the dispatch + the + on-disk fragments together so a missing/renamed file is caught + immediately. + """ + prompt = compose_system_prompt(today=fixed_today, model_name=model_name) + assert "" in prompt, ( + f"variant for {model_name!r} did not emit a provider_hints block; " + "the corresponding providers/.md may be missing" + ) + assert expected_marker in prompt, ( + f"variant for {model_name!r} emitted hints but lacked the " + f"expected marker {expected_marker!r} — the fragment may have " + "drifted from the dispatch table" + ) + + def test_provider_blocks_are_byte_stable_across_calls( + self, fixed_today: datetime + ) -> None: + """Cache-stability guard: same model id → byte-identical prompt.""" + a = compose_system_prompt(today=fixed_today, model_name="moonshot:kimi-k2") + b = compose_system_prompt(today=fixed_today, model_name="moonshot:kimi-k2") + assert a == b + def test_custom_system_instructions_override_default( self, fixed_today: datetime ) -> None: diff --git a/surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py b/surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py new file mode 100644 index 000000000..8b464d48d --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py @@ -0,0 +1,168 @@ +"""Unit tests for kb_persistence filesystem-parity invariants. + +Specifically, these tests pin down that the agent-driven write_file flow +treats path uniqueness — not content uniqueness — as the only hard +invariant. This mirrors a real filesystem: ``cp a b`` produces two files +with identical bytes living at different paths, and that should round-trip +through :class:`KnowledgeBasePersistenceMiddleware` without losing the copy. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import numpy as np +import pytest + +from app.agents.new_chat.middleware import kb_persistence +from app.db import Document + + +class _FakeResult: + """Minimal stand-in for ``sqlalchemy.engine.Result``.""" + + def __init__(self, value: Any = None) -> None: + self._value = value + + def scalar_one_or_none(self) -> Any: + return self._value + + def scalar(self) -> Any: + return self._value + + +class _FakeSession: + """Minimal AsyncSession stand-in scoped to ``_create_document`` needs. + + Records every ``add`` so we can assert against the resulting Documents + and Chunks. ``execute`` always returns "no row" by default — i.e. no + folder hierarchy preexists and no path collision exists. Tests that + want a path collision can override that on a per-call basis. + """ + + def __init__(self) -> None: + self.added: list[Any] = [] + self.execute = AsyncMock(return_value=_FakeResult(None)) + self.flush = AsyncMock() + + # Simulate ``await session.flush()`` assigning an id to the doc; + # we increment a counter so each Document gets a unique id. + self._next_id = 1 + + async def _flush_assigning_ids() -> None: + for obj in self.added: + if getattr(obj, "id", None) is None: + obj.id = self._next_id + self._next_id += 1 + + self.flush.side_effect = _flush_assigning_ids + + def add(self, obj: Any) -> None: + self.added.append(obj) + + def add_all(self, objs: list[Any]) -> None: + self.added.extend(objs) + + +@pytest.fixture(autouse=True) +def _stub_embeddings_and_chunks(monkeypatch: pytest.MonkeyPatch) -> None: + """Avoid loading the embedding model in unit tests.""" + monkeypatch.setattr( + kb_persistence, + "embed_texts", + lambda texts: [np.zeros(8, dtype=np.float32) for _ in texts], + ) + monkeypatch.setattr(kb_persistence, "chunk_text", lambda content: [content]) + + +@pytest.mark.asyncio +async def test_create_document_allows_identical_content_at_different_paths() -> None: + """The core regression: ``cp /a/notes.md /b/notes-copy.md``. + + Both create calls must succeed even though the bytes are byte-for-byte + identical, because path is the only filesystem-style unique key. + """ + session = _FakeSession() + content = "# Same body\n\nIdentical content used by two different paths.\n" + + first = await kb_persistence._create_document( + session, # type: ignore[arg-type] + virtual_path="/documents/a/notes.md", + content=content, + search_space_id=42, + created_by_id="user-1", + ) + assert isinstance(first, Document) + assert first.title == "notes.md" + + # Second create with byte-identical content at a different path should + # not raise — that's the whole point of the filesystem-parity fix. + second = await kb_persistence._create_document( + session, # type: ignore[arg-type] + virtual_path="/documents/b/notes-copy.md", + content=content, + search_space_id=42, + created_by_id="user-1", + ) + assert isinstance(second, Document) + assert second.title == "notes-copy.md" + + # Both rows share the same content_hash but live at distinct paths + # (distinct ``unique_identifier_hash``). That's the desired contract. + assert first.content_hash == second.content_hash + assert first.unique_identifier_hash != second.unique_identifier_hash + + +@pytest.mark.asyncio +async def test_create_document_still_rejects_path_collision() -> None: + """Path uniqueness remains the hard invariant. + + If ``unique_identifier_hash`` already points at an existing row in + the same search space, the create call must raise ``ValueError`` + with a clear message — matching the behavior the commit loop relies + on to upsert via the existing-row code path. + """ + session = _FakeSession() + + # Path with no folder parts so ``_ensure_folder_hierarchy`` is a + # no-op and the only SELECT executed is the path-collision check. + # That SELECT returns an existing doc id, triggering the guard. + session.execute = AsyncMock(return_value=_FakeResult(value=99)) + + with pytest.raises(ValueError, match="already exists at path"): + await kb_persistence._create_document( + session, # type: ignore[arg-type] + virtual_path="/documents/notes.md", + content="anything", + search_space_id=42, + created_by_id="user-1", + ) + + +@pytest.mark.asyncio +async def test_create_document_does_not_query_for_content_hash_collision( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Regression guard: the legacy second SELECT (content_hash collision + pre-check) must be gone. Counting ``execute`` calls is a brittle but + effective way to lock that in. + + The current flow runs exactly one ``execute`` for the path-collision + SELECT (no folder parts in this path → ``_ensure_folder_hierarchy`` + short-circuits). If a future refactor reintroduces a content-hash + SELECT, this test will fail loud. + """ + session = _FakeSession() + await kb_persistence._create_document( + session, # type: ignore[arg-type] + virtual_path="/documents/notes.md", + content="hello", + search_space_id=42, + created_by_id="user-1", + ) + # Path-collision SELECT only. No content_hash SELECT. + assert session.execute.await_count == 1, ( + f"Unexpected execute count {session.execute.await_count}; " + "did the legacy content_hash collision pre-check get re-added?" + ) diff --git a/surfsense_web/app/globals.css b/surfsense_web/app/globals.css index a37ddb8f3..f54bc2197 100644 --- a/surfsense_web/app/globals.css +++ b/surfsense_web/app/globals.css @@ -210,6 +210,27 @@ button { } } +/* Citation-jump highlight — entrance pulse only. The `SearchHighlightLeaf` + (see components/ui/search-highlight-node.tsx) is otherwise statically + tinted; this animation runs once on mount to draw the eye to the cited + text after `scrollIntoView` lands. The highlight itself is permanent + until the user clicks inside the editor (or another dismissal trigger + fires in `EditorPanelContent`). */ +@keyframes citation-flash-in { + 0% { + background-color: transparent; + box-shadow: 0 0 0 0 transparent; + } + 40% { + background-color: color-mix(in oklab, var(--primary) 30%, transparent); + box-shadow: 0 0 0 3px color-mix(in oklab, var(--primary) 25%, transparent); + } + 100% { + background-color: color-mix(in oklab, var(--primary) 15%, transparent); + box-shadow: 0 0 0 1px color-mix(in oklab, var(--primary) 40%, transparent); + } +} + /* Human-in-the-loop approval card animations */ @keyframes pulse-subtle { 0%, diff --git a/surfsense_web/atoms/document-viewer/pending-chunk-highlight.atom.ts b/surfsense_web/atoms/document-viewer/pending-chunk-highlight.atom.ts new file mode 100644 index 000000000..a3f8357e8 --- /dev/null +++ b/surfsense_web/atoms/document-viewer/pending-chunk-highlight.atom.ts @@ -0,0 +1,19 @@ +import { atom } from "jotai"; + +/** + * Cross-component handoff for citation jumps. Set by `InlineCitation` when a + * numeric chunk badge is clicked (after the document has been resolved); read + * by `DocumentTabContent` once the matching document tab mounts so it can + * scroll to and softly highlight the cited chunk inside the rendered markdown. + * + * Cleared by `DocumentTabContent` only after a terminal state — exact / + * approximate / miss — has been reached, so that an escalation refetch (2MB + * preview → 16MB) keeps the pending intent alive across the re-render. + */ +export interface PendingChunkHighlight { + documentId: number; + chunkId: number; + chunkText: string; +} + +export const pendingChunkHighlightAtom = atom(null); diff --git a/surfsense_web/components/assistant-ui/inline-citation.tsx b/surfsense_web/components/assistant-ui/inline-citation.tsx index eb4bd9af8..ae8d434a8 100644 --- a/surfsense_web/components/assistant-ui/inline-citation.tsx +++ b/surfsense_web/components/assistant-ui/inline-citation.tsx @@ -1,26 +1,45 @@ "use client"; -import { FileText } from "lucide-react"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { useSetAtom } from "jotai"; +import { ExternalLink, FileText } from "lucide-react"; import type { FC } from "react"; -import { useState } from "react"; +import { useCallback, useEffect, useRef, useState } from "react"; +import { toast } from "sonner"; +import { pendingChunkHighlightAtom } from "@/atoms/document-viewer/pending-chunk-highlight.atom"; +import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { useCitationMetadata } from "@/components/assistant-ui/citation-metadata-context"; -import { SourceDetailPanel } from "@/components/new-chat/source-detail-panel"; +import { MarkdownViewer } from "@/components/markdown-viewer"; import { Citation } from "@/components/tool-ui/citation"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { Spinner } from "@/components/ui/spinner"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { documentsApiService } from "@/lib/apis/documents-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; interface InlineCitationProps { chunkId: number; isDocsChunk?: boolean; } +const POPOVER_HOVER_CLOSE_DELAY_MS = 150; + /** - * Inline citation for knowledge-base chunks (numeric chunk IDs). - * Renders a clickable badge showing the actual chunk ID that opens the SourceDetailPanel. - * Negative chunk IDs indicate anonymous/synthetic uploads and render as a static badge. + * Inline citation badge for knowledge-base chunks (numeric chunk IDs) and + * Surfsense documentation chunks (`isDocsChunk`). Negative chunk IDs render as + * a static "doc" pill (anonymous/synthetic uploads). + * + * Numeric KB chunks: clicking resolves the parent document via + * `getDocumentByChunk`, opens the document in the right side panel (alongside + * the chat — does not replace it), and stages the cited chunk text in + * `pendingChunkHighlightAtom` so `EditorPanelContent` can scroll to and softly + * highlight it inside the rendered markdown. + * + * Surfsense docs chunks: rendered as a hover-controlled shadcn Popover that + * lazily fetches and previews the cited chunk inline, since those docs aren't + * indexed into the user's search space and have no tab to open. */ export const InlineCitation: FC = ({ chunkId, isDocsChunk = false }) => { - const [isOpen, setIsOpen] = useState(false); - if (chunkId < 0) { return ( @@ -38,26 +57,185 @@ export const InlineCitation: FC = ({ chunkId, isDocsChunk = ); } + if (isDocsChunk) { + return ; + } + + return ; +}; + +const NumericChunkCitation: FC<{ chunkId: number }> = ({ chunkId }) => { + const queryClient = useQueryClient(); + const setPendingHighlight = useSetAtom(pendingChunkHighlightAtom); + const openEditorPanel = useSetAtom(openEditorPanelAtom); + const [resolving, setResolving] = useState(false); + + const handleClick = useCallback(async () => { + if (resolving) return; + setResolving(true); + console.log("[citation:click] start", { chunkId }); + try { + const data = await queryClient.fetchQuery({ + // Local key with explicit window. The shared `cacheKeys.documents.byChunk` + // is window-agnostic (latent footgun); namespace the call to avoid + // reusing a different-window cached result. + queryKey: ["documents", "by-chunk", chunkId, "w0"] as const, + queryFn: () => + documentsApiService.getDocumentByChunk({ chunk_id: chunkId, chunk_window: 0 }), + staleTime: 5 * 60 * 1000, + }); + const cited = data.chunks.find((c) => c.id === chunkId) ?? data.chunks[0]; + console.log("[citation:click] fetched doc-by-chunk", { + docId: data.id, + docTitle: data.title, + chunksReturned: data.chunks.length, + citedChunkId: cited?.id, + citedChunkContentLen: cited?.content?.length ?? 0, + citedChunkPreview: + cited?.content && cited.content.length > 120 + ? `${cited.content.slice(0, 120)}…(+${cited.content.length - 120})` + : (cited?.content ?? ""), + }); + // Stage the highlight BEFORE opening the panel so `EditorPanelContent` + // already sees the pending intent on its very first render — avoids a + // "fetch → render → no-pending → next-tick render with pending" race. + setPendingHighlight({ + documentId: data.id, + chunkId, + chunkText: cited?.content ?? "", + }); + openEditorPanel({ + documentId: data.id, + searchSpaceId: data.search_space_id, + title: data.title, + }); + console.log("[citation:click] staged highlight + opened editor panel", { + documentId: data.id, + }); + } catch (err) { + console.warn("[citation:click] failed", err); + toast.error(err instanceof Error ? err.message : "Couldn't open cited document"); + } finally { + setResolving(false); + } + }, [chunkId, openEditorPanel, queryClient, resolving, setPendingHighlight]); + return ( - - + ); +}; + +const SurfsenseDocCitation: FC<{ chunkId: number }> = ({ chunkId }) => { + const [open, setOpen] = useState(false); + const closeTimerRef = useRef | null>(null); + + const cancelClose = useCallback(() => { + if (closeTimerRef.current) { + clearTimeout(closeTimerRef.current); + closeTimerRef.current = null; + } + }, []); + + const scheduleClose = useCallback(() => { + cancelClose(); + closeTimerRef.current = setTimeout(() => { + setOpen(false); + closeTimerRef.current = null; + }, POPOVER_HOVER_CLOSE_DELAY_MS); + }, [cancelClose]); + + useEffect(() => () => cancelClose(), [cancelClose]); + + const { data, isLoading, error } = useQuery({ + queryKey: cacheKeys.documents.byChunk(`doc-${chunkId}`), + queryFn: () => documentsApiService.getSurfsenseDocByChunk(chunkId), + enabled: open, + staleTime: 5 * 60 * 1000, + }); + + const citedChunk = data?.chunks.find((c) => c.id === chunkId) ?? data?.chunks[0]; + + return ( + + + + + e.preventDefault()} > - {chunkId} - - +
+
+

+ {data?.title ?? "Surfsense documentation"} +

+

Chunk #{chunkId}

+
+ {data?.source && ( + + + Open + + )} +
+
+ {isLoading && ( +
+ + Loading… +
+ )} + {error && ( +

+ {error instanceof Error ? error.message : "Failed to load chunk"} +

+ )} + {!isLoading && !error && citedChunk?.content && ( + + )} + {!isLoading && !error && !citedChunk?.content && ( +

No content available.

+ )} +
+ + ); }; diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 3b69ae6e0..0c4e9485b 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -1,5 +1,6 @@ "use client"; +import { FindReplacePlugin } from "@platejs/find-replace"; import { useAtomValue, useSetAtom } from "jotai"; import { Check, @@ -14,17 +15,21 @@ import { import dynamic from "next/dynamic"; import { useCallback, useEffect, useRef, useState } from "react"; import { toast } from "sonner"; +import { pendingChunkHighlightAtom } from "@/atoms/document-viewer/pending-chunk-highlight.atom"; import { closeEditorPanelAtom, editorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { VersionHistoryButton } from "@/components/documents/version-history"; +import type { PlateEditorInstance } from "@/components/editor/plate-editor"; import { SourceCodeEditor } from "@/components/editor/source-code-editor"; import { MarkdownViewer } from "@/components/markdown-viewer"; import { Alert, AlertDescription } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer"; +import { CITATION_HIGHLIGHT_CLASS } from "@/components/ui/search-highlight-node"; import { Spinner } from "@/components/ui/spinner"; import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI } from "@/hooks/use-platform"; import { authenticatedFetch, getBearerToken, redirectToLogin } from "@/lib/auth-utils"; +import { buildCitationSearchCandidates } from "@/lib/citation-search"; import { inferMonacoLanguageFromPath } from "@/lib/editor-language"; const PlateEditor = dynamic( @@ -32,7 +37,10 @@ const PlateEditor = dynamic( { ssr: false, loading: () => } ); +type CitationHighlightStatus = "exact" | "miss"; + const LARGE_DOCUMENT_THRESHOLD = 2 * 1024 * 1024; // 2MB +const CITATION_MAX_LENGTH = 16 * 1024 * 1024; // 16MB on-demand cap for citation jumps interface EditorContent { document_id: number; @@ -136,6 +144,61 @@ export function EditorPanelContent({ const [displayTitle, setDisplayTitle] = useState(title || "Untitled"); const isLocalFileMode = kind === "local_file"; const editorRenderMode: EditorRenderMode = isLocalFileMode ? "source_code" : "rich_markdown"; + + // --- Citation-jump highlight wiring ---------------------------------- + // `EditorPanelContent` is the consumer of `pendingChunkHighlightAtom`: when + // a citation badge is clicked, the badge stages `{documentId, chunkId, + // chunkText}` and opens this panel. We drive Plate's `FindReplacePlugin` + // (registered in every preset) to highlight the cited text natively via + // Slate decorations — no DOM walking, no Range gymnastics. The state + // machine below escalates the document fetch from 2MB → 16MB once if no + // candidate snippet matched in the preview, and surfaces miss outcomes + // via an inline alert. + const pending = useAtomValue(pendingChunkHighlightAtom); + const setPendingHighlight = useSetAtom(pendingChunkHighlightAtom); + const [fetchKey, setFetchKey] = useState(0); + const [maxLengthOverride, setMaxLengthOverride] = useState(null); + const [highlightResult, setHighlightResult] = useState(null); + const editorRef = useRef(null); + const escalatedForRef = useRef(null); + const lastAppliedChunkIdRef = useRef(null); + // Tracks whether a citation highlight is currently decorated in the + // editor. We use a ref (not state) because the click-to-dismiss handler + // runs in a stable callback that would otherwise close over stale state. + const isHighlightActiveRef = useRef(false); + // Once a citation jump targets this doc we have to keep `PlateEditor` + // mounted for the *rest of the doc session* — even after the highlight + // effect clears `pendingChunkHighlightAtom` (which it does as soon as + // the decoration is applied, so a follow-up citation on the same chunk + // can re-trigger). Without this latch, non-editable docs would re-render + // back into `MarkdownViewer` the instant `pending` is released, tearing + // down the Plate decorations and dropping the highlight after a frame. + const [stickyPlateMode, setStickyPlateMode] = useState(false); + + const clearCitationSearch = useCallback(() => { + isHighlightActiveRef.current = false; + const editor = editorRef.current; + if (!editor) return; + try { + editor.setOption(FindReplacePlugin, "search", ""); + editor.api.redecorate(); + } catch (err) { + console.warn("[EditorPanelContent] clearCitationSearch failed:", err); + } + }, []); + + // Dismiss the highlight when the user interacts with the editor surface. + // `onPointerDown` fires before focus / selection changes so the click + // itself feels responsive — the highlight clears in the same event tick + // that places the cursor. No-op when nothing is highlighted, so we don't + // thrash `redecorate` on every click in normal editing. + const handleEditorPointerDown = useCallback(() => { + if (!isHighlightActiveRef.current) return; + clearCitationSearch(); + setHighlightResult(null); + }, [clearCitationSearch]); + + const isCitationTarget = !!pending && !isLocalFileMode && pending.documentId === documentId; const resolveLocalVirtualPath = useCallback( async (candidatePath: string): Promise => { if (!electronAPI?.getAgentFilesystemMounts) { @@ -155,6 +218,8 @@ export function EditorPanelContent({ const isLargeDocument = (editorDoc?.content_size_bytes ?? 0) > LARGE_DOCUMENT_THRESHOLD; + // `fetchKey` is an explicit re-fetch trigger (escalation bumps it to force + // a new request even when documentId/searchSpaceId haven't changed). useEffect(() => { const controller = new AbortController(); setIsLoading(true); @@ -166,6 +231,12 @@ export function EditorPanelContent({ setIsEditing(false); initialLoadDone.current = false; changeCountRef.current = 0; + // Clear any in-flight FindReplacePlugin search before the editor + // re-mounts on new content (a fresh editor key is generated below + // from documentId + isEditing, so the previous editor + its + // decorations are about to be discarded anyway, but we belt-and- + // brace here for the case where only `fetchKey` changed). + clearCitationSearch(); const doFetch = async () => { try { @@ -210,7 +281,11 @@ export function EditorPanelContent({ const url = new URL( `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content` ); - url.searchParams.set("max_length", String(LARGE_DOCUMENT_THRESHOLD)); + url.searchParams.set("max_length", String(maxLengthOverride ?? LARGE_DOCUMENT_THRESHOLD)); + // `fetchKey` participates here so biome's noUnusedVariables sees it + // as consumed; bumping it forces a fresh request even when the URL + // is otherwise identical. + if (fetchKey > 0) url.searchParams.set("_n", String(fetchKey)); const response = await authenticatedFetch(url.toString(), { method: "GET" }); @@ -256,8 +331,259 @@ export function EditorPanelContent({ resolveLocalVirtualPath, searchSpaceId, title, + fetchKey, + maxLengthOverride, + clearCitationSearch, ]); + // Reset citation-jump bookkeeping whenever the panel switches to a different + // document (or local file). Body only writes setters — the deps are the + // real triggers we want to react to. + // biome-ignore lint/correctness/useExhaustiveDependencies: documentId/localFilePath are intentional triggers. + useEffect(() => { + clearCitationSearch(); + escalatedForRef.current = null; + lastAppliedChunkIdRef.current = null; + setHighlightResult(null); + setMaxLengthOverride(null); + setFetchKey(0); + // Drop sticky Plate mode when the panel moves to a different doc + // — the next doc starts in its preferred render mode (Plate for + // editable, MarkdownViewer for everything else) until/unless a + // citation jump targets it. + setStickyPlateMode(false); + }, [documentId, localFilePath, clearCitationSearch]); + + // Latch sticky Plate mode the first time a citation jump targets this + // doc. We keep it sticky for the remainder of this doc session so the + // highlight effect's `setPendingHighlight(null)` doesn't unmount the + // editor mid-flight (see comment on `stickyPlateMode` declaration). + useEffect(() => { + if (isCitationTarget) setStickyPlateMode(true); + }, [isCitationTarget]); + + // `isEditorReady` is what `useEffect` actually depends on — `editorRef` + // is a ref so changes don't trigger re-runs. We flip this to `true` once + // `PlateEditor` calls back with its live editor instance (its + // `usePlateEditor` value-init runs synchronously, so by the time this + // flips true the markdown is already deserialized into the Slate tree). + const [isEditorReady, setIsEditorReady] = useState(false); + const handleEditorReady = useCallback((editor: PlateEditorInstance | null) => { + console.log("[citation:editor] handleEditorReady", { ready: !!editor }); + editorRef.current = editor; + setIsEditorReady(!!editor); + }, []); + + // --- Citation jump highlight effect ----------------------------------- + // Drives Plate's FindReplacePlugin to highlight the cited chunk: + // 1. Build candidate snippets from the chunk text (first sentence, + // first 8 words, full chunk if short). Plate's decorate runs per- + // block and won't cross block boundaries, so the shorter + // candidates exist to give us something that fits in one + // paragraph / heading. + // 2. For each candidate: setOption('search', ...) → redecorate → + // wait two animation frames for React to flush → query the editor + // DOM for `.${CITATION_HIGHLIGHT_CLASS}`. First hit wins. + // + // Why a className and not a `data-*` attribute? Plate's + // `PlateLeaf` runs its props through `useNodeAttributes`, which + // only forwards `attributes`, `className`, `ref`, and `style` — + // arbitrary `data-*` attributes are silently dropped. `className` + // is the only escape hatch guaranteed to survive into the DOM. + // 3. On hit: smooth-scroll the first match into view, mark the + // highlight active (so a click inside the editor can dismiss it), + // release the pending atom. + // 4. On terminal miss: if the doc was truncated and we haven't + // escalated yet, bump the fetch's `max_length` to the citation + // cap and re-fetch — the post-refetch render will re-run this + // effect against the larger preview. Otherwise, release the + // atom and show the miss alert. + useEffect(() => { + console.log("[citation:effect] fired", { + isCitationTarget, + pendingDocId: pending?.documentId, + pendingChunkId: pending?.chunkId, + pendingChunkTextLen: pending?.chunkText?.length, + documentId, + isLocalFileMode, + isEditing, + hasMarkdown: !!editorDoc?.source_markdown, + markdownLen: editorDoc?.source_markdown?.length, + truncated: editorDoc?.truncated, + isEditorReady, + editorRefSet: !!editorRef.current, + maxLengthOverride, + }); + if (!isCitationTarget || !pending) { + console.log("[citation:effect] guard ✗ no citation target / no pending"); + return; + } + if (isLocalFileMode || isEditing) { + console.log("[citation:effect] guard ✗ localFileMode/editing"); + return; + } + if (!editorDoc?.source_markdown) { + console.log("[citation:effect] guard ✗ source_markdown not ready"); + return; + } + if (!isEditorReady) { + console.log("[citation:effect] guard ✗ editor not ready yet"); + return; + } + const editor = editorRef.current; + if (!editor) { + console.log("[citation:effect] guard ✗ editorRef.current is null"); + return; + } + + if (lastAppliedChunkIdRef.current !== pending.chunkId) { + lastAppliedChunkIdRef.current = pending.chunkId; + } + + let cancelled = false; + + const finishMiss = () => { + console.log("[citation:effect] terminal miss — no candidate matched"); + try { + editor.setOption(FindReplacePlugin, "search", ""); + editor.api.redecorate(); + } catch (err) { + console.warn("[EditorPanelContent] reset search after miss failed:", err); + } + const canEscalate = + editorDoc.truncated === true && + (maxLengthOverride ?? LARGE_DOCUMENT_THRESHOLD) < CITATION_MAX_LENGTH && + escalatedForRef.current !== pending.chunkId; + console.log("[citation:effect] miss decision", { + truncated: editorDoc.truncated, + currentMaxLength: maxLengthOverride ?? LARGE_DOCUMENT_THRESHOLD, + canEscalate, + }); + if (canEscalate) { + escalatedForRef.current = pending.chunkId; + setMaxLengthOverride(CITATION_MAX_LENGTH); + setFetchKey((k) => k + 1); + // Keep the atom set so the post-refetch render re-runs. + return; + } + setHighlightResult("miss"); + setPendingHighlight(null); + }; + + const tryCandidates = async () => { + const candidates = buildCitationSearchCandidates(pending.chunkText); + console.log("[citation:effect] candidates built", { + count: candidates.length, + previews: candidates.map((c) => c.slice(0, 60)), + }); + if (candidates.length === 0) { + if (!cancelled) finishMiss(); + return; + } + // Resolve the editor's rendered DOM root via Slate's stable + // `[data-slate-editor="true"]` attribute (set by slate-react's + // ``). Scoping queries to this root prevents + // `` elements rendered elsewhere on the page (e.g. chat + // search-highlight leaves in another mounted PlateEditor) from + // being mistaken for citation hits. + const editorRoot = document.querySelector('[data-slate-editor="true"]'); + console.log("[citation:effect] editor root", { + hasRoot: !!editorRoot, + }); + const root: ParentNode = editorRoot ?? document; + + for (let i = 0; i < candidates.length; i++) { + const candidate = candidates[i]; + if (cancelled) return; + try { + editor.setOption(FindReplacePlugin, "search", candidate); + editor.api.redecorate(); + console.log(`[citation:effect] try #${i} setOption + redecorate`, { + len: candidate.length, + preview: candidate.slice(0, 80), + }); + } catch (err) { + console.warn("[EditorPanelContent] setOption/redecorate failed:", err); + continue; + } + // Two rAFs: first lets Slate flush its onChange, second lets + // React commit the decoration leaves into the DOM. + await new Promise((resolve) => + requestAnimationFrame(() => requestAnimationFrame(() => resolve())) + ); + if (cancelled) return; + // Primary probe: by our stable class on the rendered . + let el = root.querySelector(`.${CITATION_HIGHLIGHT_CLASS}`); + const classMarkCount = root.querySelectorAll(`.${CITATION_HIGHLIGHT_CLASS}`).length; + // Diagnostic fallback: any inside the editor root. + // If we ever see allMarks > 0 but classMarkCount === 0, + // the className was stripped again and we need to revisit + // `useNodeAttributes` filtering. + const allMarkCount = root.querySelectorAll("mark").length; + if (!el && allMarkCount > 0) { + el = root.querySelector("mark"); + } + console.log(`[citation:effect] try #${i} DOM probe`, { + foundEl: !!el, + classMarkCount, + allMarkCount, + usedFallback: !!el && classMarkCount === 0, + }); + if (el) { + try { + el.scrollIntoView({ block: "center", behavior: "smooth" }); + } catch { + el.scrollIntoView(); + } + isHighlightActiveRef.current = true; + setHighlightResult("exact"); + console.log(`[citation:effect] ✓ exact via candidate #${i} — atom released`); + // No auto-clear timer — the highlight is intentionally + // permanent until the user clicks inside the editor (see + // `handleEditorPointerDown`) or another dismissal trigger + // fires (doc switch, edit-mode toggle, panel unmount, + // next citation jump). Sticky Plate mode keeps the + // editor mounted after the atom clears. + setPendingHighlight(null); + return; + } + } + if (!cancelled) finishMiss(); + }; + + void tryCandidates(); + + return () => { + cancelled = true; + }; + }, [ + isCitationTarget, + pending, + documentId, + editorDoc?.source_markdown, + editorDoc?.truncated, + isLocalFileMode, + isEditing, + isEditorReady, + maxLengthOverride, + clearCitationSearch, + setPendingHighlight, + ]); + + // Cleanup any active highlight on unmount. + useEffect(() => { + return () => clearCitationSearch(); + }, [clearCitationSearch]); + + // Toggling into edit mode swaps Plate out of readOnly. Clear the citation + // search so stale leaves don't linger in the editing surface. + useEffect(() => { + if (isEditing) { + clearCitationSearch(); + setHighlightResult(null); + } + }, [isEditing, clearCitationSearch]); + useEffect(() => { return () => { if (copyResetTimeoutRef.current) { @@ -367,6 +693,15 @@ export function EditorPanelContent({ EDITABLE_DOCUMENT_TYPES.has(editorDoc.document_type ?? "")) && !isLargeDocument : false; + // Use PlateEditor for any of: + // - Editable doc types (FILE/NOTE) — existing editing UX. + // - Active citation jump in flight (`isCitationTarget`) — covers the + // mount in the very first render where the atom is set but the + // sticky effect hasn't fired yet. + // - Sticky Plate mode latched on a previous citation jump — keeps + // the editor mounted (with its decorations) after the highlight + // effect clears the atom. Resets when the doc changes. + const renderInPlateEditor = isEditableType || isCitationTarget || stickyPlateMode; const hasUnsavedChanges = editedMarkdown !== null; const showDesktopHeader = !!onClose; const showEditingActions = isEditableType && isEditing; @@ -381,6 +716,90 @@ export function EditorPanelContent({ setIsEditing(false); }, [editorDoc?.source_markdown]); + const handleDownloadMarkdown = useCallback(async () => { + if (!searchSpaceId || !documentId) return; + setDownloading(true); + try { + const response = await authenticatedFetch( + `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/download-markdown`, + { method: "GET" } + ); + if (!response.ok) throw new Error("Download failed"); + const blob = await response.blob(); + const url = URL.createObjectURL(blob); + const a = document.createElement("a"); + a.href = url; + const disposition = response.headers.get("content-disposition"); + const match = disposition?.match(/filename="(.+)"/); + a.download = match?.[1] ?? `${editorDoc?.title || "document"}.md`; + document.body.appendChild(a); + a.click(); + a.remove(); + URL.revokeObjectURL(url); + toast.success("Download started"); + } catch { + toast.error("Failed to download document"); + } finally { + setDownloading(false); + } + }, [documentId, editorDoc?.title, searchSpaceId]); + + // We no longer surface an "approximate" status — Plate's FindReplacePlugin + // either decorates an exact match or it doesn't, and the candidate snippet + // strategy (first sentence → first 8 words → full chunk) means we either + // land on the citation start or fall through to the miss alert. + const showMissAlert = isCitationTarget && highlightResult === "miss"; + + const citationAlerts = showMissAlert && ( + + + + Cited section couldn't be located in this view. + {editorDoc?.truncated && ( + + )} + + + ); + + const largeDocAlert = isLargeDocument && !isLocalFileMode && editorDoc && ( + + + + + This document is too large for the editor ( + {Math.round((editorDoc.content_size_bytes ?? 0) / 1024 / 1024)}MB,{" "} + {editorDoc.chunk_count ?? 0} chunks). Showing a preview below. + + + + + ); + return ( <> {showDesktopHeader ? ( @@ -565,61 +984,6 @@ export function EditorPanelContent({

- ) : isLargeDocument && !isLocalFileMode ? ( -
- - - - - This document is too large for the editor ( - {Math.round((editorDoc.content_size_bytes ?? 0) / 1024 / 1024)}MB,{" "} - {editorDoc.chunk_count ?? 0} chunks). Showing a preview below. - - - - - -
) : editorRenderMode === "source_code" ? (
- ) : isEditableType ? ( - + ) : isLargeDocument && !isLocalFileMode && !isCitationTarget ? ( + // Large doc, no active citation — fast Streamdown preview + // + download CTA. We only fall back to MarkdownViewer here + // because Plate is heavy on multi-MB docs and the user + // isn't waiting on a specific citation to render. +
+ {largeDocAlert} + +
+ ) : renderInPlateEditor ? ( + // Editable doc (FILE/NOTE) OR active citation jump (any + // doc type). The citation path uses Plate's + // FindReplacePlugin for native, decoration-based + // highlighting — see the citation-jump highlight effect + // above for how `editorRef` and `handleEditorReady` are + // wired. +
+ {(citationAlerts || (isLargeDocument && isCitationTarget && !isLocalFileMode)) && ( +
+ {isLargeDocument && isCitationTarget && largeDocAlert} + {citationAlerts} +
+ )} +
+ +
+
) : (
diff --git a/surfsense_web/components/editor/plate-editor.tsx b/surfsense_web/components/editor/plate-editor.tsx index 481a420fb..eef18ef6a 100644 --- a/surfsense_web/components/editor/plate-editor.tsx +++ b/surfsense_web/components/editor/plate-editor.tsx @@ -12,6 +12,12 @@ import { type EditorPreset, presetMap } from "@/components/editor/presets"; import { escapeMdxExpressions } from "@/components/editor/utils/escape-mdx"; import { Editor, EditorContainer } from "@/components/ui/editor"; +/** Live editor instance returned by `usePlateEditor`. Exposed via the + * `onEditorReady` prop so callers (e.g. `EditorPanelContent`) can drive + * plugin options imperatively — most notably setting + * `FindReplacePlugin`'s `search` option for citation-jump highlights. */ +export type PlateEditorInstance = ReturnType; + export interface PlateEditorProps { /** Markdown string to load as initial content */ markdown?: string; @@ -62,6 +68,15 @@ export interface PlateEditorProps { * without modifying the core editor component. */ extraPlugins?: AnyPluginConfig[]; + /** + * Called whenever the live editor instance (re)mounts, with `null` on + * unmount. Used by callers that need to drive plugin options imperatively + * — e.g. `EditorPanelContent` setting `FindReplacePlugin`'s `search` + * option for citation-jump highlights. The callback is invoked exactly + * once per editor lifetime (the parent's `key` prop forces a fresh + * editor when needed, e.g. on edit-mode toggle). + */ + onEditorReady?: (editor: PlateEditorInstance | null) => void; } function PlateEditorContent({ @@ -100,6 +115,7 @@ export function PlateEditor({ defaultEditing = false, preset = "full", extraPlugins = [], + onEditorReady, }: PlateEditorProps) { const lastMarkdownRef = useRef(markdown); const lastHtmlRef = useRef(html); @@ -156,6 +172,21 @@ export function PlateEditor({ : undefined, }); + // Expose the live editor instance to imperative callers (e.g. citation + // jump highlights). We deliberately don't depend on `onEditorReady` + // itself in the cleanup closure — callers commonly pass an arrow that + // closes over a stable ref setter, but if they pass a freshly-bound + // callback per render, the `onEditorReady?.(editor)` re-fires which is + // idempotent for ref-style setters. + const onEditorReadyRef = useRef(onEditorReady); + useEffect(() => { + onEditorReadyRef.current = onEditorReady; + }, [onEditorReady]); + useEffect(() => { + onEditorReadyRef.current?.(editor); + return () => onEditorReadyRef.current?.(null); + }, [editor]); + // Update editor content when html prop changes externally useEffect(() => { if (html !== undefined && html !== lastHtmlRef.current) { diff --git a/surfsense_web/components/editor/presets.ts b/surfsense_web/components/editor/presets.ts index c207b5e56..49f53ecf1 100644 --- a/surfsense_web/components/editor/presets.ts +++ b/surfsense_web/components/editor/presets.ts @@ -1,5 +1,6 @@ "use client"; +import { FindReplacePlugin } from "@platejs/find-replace"; import type { AnyPluginConfig } from "platejs"; import { TrailingBlockPlugin } from "platejs"; @@ -17,6 +18,30 @@ import { SelectionKit } from "@/components/editor/plugins/selection-kit"; import { SlashCommandKit } from "@/components/editor/plugins/slash-command-kit"; import { TableKit } from "@/components/editor/plugins/table-kit"; import { ToggleKit } from "@/components/editor/plugins/toggle-kit"; +import { SearchHighlightLeaf } from "@/components/ui/search-highlight-node"; + +/** + * Citation-jump highlighter. Re-uses Plate's built-in `FindReplacePlugin` + * (decorate-only, no editing surface) to drive the "scroll-to-cited-text" + * UX in `EditorPanelContent`. We register it in every preset because: + * - Decorate is a no-op when `search` is empty (single getOptions() check + * per block), so cost is effectively zero for non-citation viewers. + * - Keeping it preset-agnostic means citations work whether the doc is + * opened in editable (`full`) or pure-viewer (`readonly`) modes. + * + * The parent component drives `setOption(FindReplacePlugin, 'search', ...)` + * + `editor.api.redecorate()` to trigger highlights, then queries the + * editor DOM for `.citation-highlight-leaf` to scroll the first match + * into view. (We can't use a `data-*` attribute here — Plate's + * `PlateLeaf` runs props through `useNodeAttributes`, which only forwards + * `attributes`, `className`, `ref`, `style`; arbitrary `data-*` props are + * silently dropped.) See `components/ui/search-highlight-node.tsx` for + * the leaf component and `CITATION_HIGHLIGHT_CLASS` constant. + */ +const CitationFindReplacePlugin = FindReplacePlugin.configure({ + options: { search: "" }, + render: { node: SearchHighlightLeaf }, +}); /** * Full preset – every plugin kit enabled. @@ -38,6 +63,7 @@ export const fullPreset: AnyPluginConfig[] = [ ...AutoformatKit, ...DndKit, TrailingBlockPlugin, + CitationFindReplacePlugin, ]; /** @@ -52,6 +78,7 @@ export const minimalPreset: AnyPluginConfig[] = [ ...LinkKit, ...AutoformatKit, TrailingBlockPlugin, + CitationFindReplacePlugin, ]; /** @@ -68,6 +95,7 @@ export const readonlyPreset: AnyPluginConfig[] = [ ...CalloutKit, ...ToggleKit, ...MathKit, + CitationFindReplacePlugin, ]; /** All available preset names */ diff --git a/surfsense_web/components/new-chat/source-detail-panel.tsx b/surfsense_web/components/new-chat/source-detail-panel.tsx deleted file mode 100644 index aded206c7..000000000 --- a/surfsense_web/components/new-chat/source-detail-panel.tsx +++ /dev/null @@ -1,719 +0,0 @@ -"use client"; - -import { useQuery } from "@tanstack/react-query"; -import { - BookOpen, - ChevronDown, - ChevronUp, - ExternalLink, - FileQuestionMark, - FileText, - Hash, - Loader2, - Sparkles, - X, -} from "lucide-react"; -import { AnimatePresence, motion, useReducedMotion } from "motion/react"; -import { useTranslations } from "next-intl"; -import type React from "react"; -import { forwardRef, memo, type ReactNode, useCallback, useEffect, useRef, useState } from "react"; -import { createPortal } from "react-dom"; -import { MarkdownViewer } from "@/components/markdown-viewer"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { ScrollArea } from "@/components/ui/scroll-area"; -import { Spinner } from "@/components/ui/spinner"; -import type { - GetDocumentByChunkResponse, - GetSurfsenseDocsByChunkResponse, -} from "@/contracts/types/document.types"; -import { documentsApiService } from "@/lib/apis/documents-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { cn } from "@/lib/utils"; - -type DocumentData = GetDocumentByChunkResponse | GetSurfsenseDocsByChunkResponse; - -interface SourceDetailPanelProps { - open: boolean; - onOpenChange: (open: boolean) => void; - chunkId: number; - sourceType: string; - title: string; - description?: string; - url?: string; - children?: ReactNode; - isDocsChunk?: boolean; -} - -const formatDocumentType = (type: string) => { - if (!type) return ""; - return type - .split("_") - .map((word) => word.charAt(0) + word.slice(1).toLowerCase()) - .join(" "); -}; - -// Chunk card component -// For large documents (>30 chunks), we disable animation to prevent layout shifts -// which break auto-scroll functionality -interface ChunkCardProps { - chunk: { id: number; content: string }; - localIndex: number; - chunkNumber: number; - totalChunks: number; - isCited: boolean; - isActive: boolean; - disableLayoutAnimation?: boolean; -} - -const ChunkCard = memo( - forwardRef( - ({ chunk, localIndex, chunkNumber, totalChunks, isCited }, ref) => { - return ( -
- {isCited &&
} - -
-
-
- {chunkNumber} -
- - Chunk {chunkNumber} of {totalChunks} - -
- {isCited && ( - - - Cited Source - - )} -
- -
- -
-
- ); - } - ) -); -ChunkCard.displayName = "ChunkCard"; - -export function SourceDetailPanel({ - open, - onOpenChange, - chunkId, - sourceType, - title, - description, - url, - children, - isDocsChunk = false, -}: SourceDetailPanelProps) { - const t = useTranslations("dashboard"); - const scrollAreaRef = useRef(null); - const hasScrolledRef = useRef(false); // Use ref to avoid stale closures - const scrollTimersRef = useRef[]>([]); - const [activeChunkIndex, setActiveChunkIndex] = useState(null); - const [mounted, setMounted] = useState(false); - const shouldReduceMotion = useReducedMotion(); - - useEffect(() => { - setMounted(true); - }, []); - - const { - data: documentData, - isLoading: isDocumentByChunkFetching, - error: documentByChunkFetchingError, - } = useQuery({ - queryKey: isDocsChunk - ? cacheKeys.documents.byChunk(`doc-${chunkId}`) - : cacheKeys.documents.byChunk(chunkId.toString()), - queryFn: async () => { - if (isDocsChunk) { - return documentsApiService.getSurfsenseDocByChunk(chunkId); - } - return documentsApiService.getDocumentByChunk({ chunk_id: chunkId, chunk_window: 5 }); - }, - enabled: !!chunkId && open, - staleTime: 5 * 60 * 1000, - }); - - const totalChunks = - documentData && "total_chunks" in documentData - ? (documentData.total_chunks ?? documentData.chunks.length) - : (documentData?.chunks?.length ?? 0); - const [beforeChunks, setBeforeChunks] = useState< - Array<{ id: number; content: string; created_at: string }> - >([]); - const [afterChunks, setAfterChunks] = useState< - Array<{ id: number; content: string; created_at: string }> - >([]); - const [loadingBefore, setLoadingBefore] = useState(false); - const [loadingAfter, setLoadingAfter] = useState(false); - - useEffect(() => { - setBeforeChunks([]); - setAfterChunks([]); - }, [chunkId, open]); - - const chunkStartIndex = - documentData && "chunk_start_index" in documentData ? (documentData.chunk_start_index ?? 0) : 0; - const initialChunks = documentData?.chunks ?? []; - const allChunks = [...beforeChunks, ...initialChunks, ...afterChunks]; - const absoluteStart = chunkStartIndex - beforeChunks.length; - const absoluteEnd = chunkStartIndex + initialChunks.length + afterChunks.length; - const canLoadBefore = absoluteStart > 0; - const canLoadAfter = absoluteEnd < totalChunks; - - const EXPAND_SIZE = 10; - - const loadBefore = useCallback(async () => { - if (!documentData || !("search_space_id" in documentData) || !canLoadBefore) return; - setLoadingBefore(true); - try { - const count = Math.min(EXPAND_SIZE, absoluteStart); - const result = await documentsApiService.getDocumentChunks({ - document_id: documentData.id, - page: 0, - page_size: count, - start_offset: absoluteStart - count, - }); - const existingIds = new Set(allChunks.map((c) => c.id)); - const newChunks = result.items - .filter((c) => !existingIds.has(c.id)) - .map((c) => ({ id: c.id, content: c.content, created_at: c.created_at })); - setBeforeChunks((prev) => [...newChunks, ...prev]); - } catch (err) { - console.error("Failed to load earlier chunks:", err); - } finally { - setLoadingBefore(false); - } - }, [documentData, absoluteStart, canLoadBefore, allChunks]); - - const loadAfter = useCallback(async () => { - if (!documentData || !("search_space_id" in documentData) || !canLoadAfter) return; - setLoadingAfter(true); - try { - const result = await documentsApiService.getDocumentChunks({ - document_id: documentData.id, - page: 0, - page_size: EXPAND_SIZE, - start_offset: absoluteEnd, - }); - const existingIds = new Set(allChunks.map((c) => c.id)); - const newChunks = result.items - .filter((c) => !existingIds.has(c.id)) - .map((c) => ({ id: c.id, content: c.content, created_at: c.created_at })); - setAfterChunks((prev) => [...prev, ...newChunks]); - } catch (err) { - console.error("Failed to load later chunks:", err); - } finally { - setLoadingAfter(false); - } - }, [documentData, absoluteEnd, canLoadAfter, allChunks]); - - const isDirectRenderSource = - sourceType === "TAVILY_API" || - sourceType === "LINKUP_API" || - sourceType === "SEARXNG_API" || - sourceType === "BAIDU_SEARCH_API"; - - const citedChunkIndex = allChunks.findIndex((chunk) => chunk.id === chunkId); - - // Simple scroll function that scrolls to a chunk by index - const scrollToChunkByIndex = useCallback( - (chunkIndex: number, smooth = true) => { - const scrollContainer = scrollAreaRef.current; - if (!scrollContainer) return; - - const viewport = scrollContainer.querySelector( - "[data-radix-scroll-area-viewport]" - ) as HTMLElement | null; - if (!viewport) return; - - const chunkElement = scrollContainer.querySelector( - `[data-chunk-index="${chunkIndex}"]` - ) as HTMLElement | null; - if (!chunkElement) return; - - // Get positions using getBoundingClientRect for accuracy - const viewportRect = viewport.getBoundingClientRect(); - const chunkRect = chunkElement.getBoundingClientRect(); - - // Calculate where to scroll to center the chunk - const currentScrollTop = viewport.scrollTop; - const chunkTopRelativeToViewport = chunkRect.top - viewportRect.top + currentScrollTop; - const scrollTarget = - chunkTopRelativeToViewport - viewportRect.height / 2 + chunkRect.height / 2; - - viewport.scrollTo({ - top: Math.max(0, scrollTarget), - behavior: smooth && !shouldReduceMotion ? "smooth" : "auto", - }); - - setActiveChunkIndex(chunkIndex); - }, - [shouldReduceMotion] - ); - - // Callback ref for the cited chunk - scrolls when the element mounts - const citedChunkRefCallback = useCallback( - (node: HTMLDivElement | null) => { - if (node && !hasScrolledRef.current && open) { - hasScrolledRef.current = true; // Mark immediately to prevent duplicate scrolls - - // Store the node reference for the delayed scroll - const scrollToCitedChunk = () => { - const scrollContainer = scrollAreaRef.current; - if (!scrollContainer || !node.isConnected) return false; - - const viewport = scrollContainer.querySelector( - "[data-radix-scroll-area-viewport]" - ) as HTMLElement | null; - if (!viewport) return false; - - // Get positions - const viewportRect = viewport.getBoundingClientRect(); - const chunkRect = node.getBoundingClientRect(); - - // Calculate scroll position to center the chunk - const currentScrollTop = viewport.scrollTop; - const chunkTopRelativeToViewport = chunkRect.top - viewportRect.top + currentScrollTop; - const scrollTarget = - chunkTopRelativeToViewport - viewportRect.height / 2 + chunkRect.height / 2; - - viewport.scrollTo({ - top: Math.max(0, scrollTarget), - behavior: "auto", // Instant scroll for initial positioning - }); - - return true; - }; - - // Scroll multiple times with delays to handle progressive content rendering - // Each subsequent scroll will correct for any layout shifts - const scrollAttempts = [50, 150, 300, 600, 1000]; - - scrollAttempts.forEach((delay) => { - scrollTimersRef.current.push( - setTimeout(() => { - scrollToCitedChunk(); - }, delay) - ); - }); - - // After final attempt, mark the cited chunk as active - scrollTimersRef.current.push( - setTimeout( - () => { - setActiveChunkIndex(citedChunkIndex); - }, - scrollAttempts[scrollAttempts.length - 1] + 50 - ) - ); - } - }, - [open, citedChunkIndex] - ); - - // Reset scroll state when panel closes - useEffect(() => { - if (!open) { - scrollTimersRef.current.forEach(clearTimeout); - scrollTimersRef.current = []; - hasScrolledRef.current = false; - setActiveChunkIndex(null); - } - return () => { - scrollTimersRef.current.forEach(clearTimeout); - scrollTimersRef.current = []; - }; - }, [open]); - - // Handle escape key - useEffect(() => { - const handleEscape = (e: KeyboardEvent) => { - if (e.key === "Escape" && open) { - onOpenChange(false); - } - }; - window.addEventListener("keydown", handleEscape); - return () => window.removeEventListener("keydown", handleEscape); - }, [open, onOpenChange]); - - // Prevent body scroll when open - useEffect(() => { - if (open) { - document.body.style.overflow = "hidden"; - } else { - document.body.style.overflow = ""; - } - return () => { - document.body.style.overflow = ""; - }; - }, [open]); - - const handleUrlClick = (e: React.MouseEvent, clickUrl: string) => { - e.preventDefault(); - e.stopPropagation(); - window.open(clickUrl, "_blank", "noopener,noreferrer"); - }; - - const scrollToChunk = useCallback( - (index: number) => { - scrollToChunkByIndex(index, true); - }, - [scrollToChunkByIndex] - ); - - const panelContent = ( - - {open && ( - <> - {/* Backdrop */} - onOpenChange(false)} - /> - - {/* Panel */} - - {/* Header */} - -
-

- {documentData?.title || title || "Source Document"} -

-

- {documentData && "document_type" in documentData - ? formatDocumentType(documentData.document_type) - : sourceType && formatDocumentType(sourceType)} - {totalChunks > 0 && ( - - • {totalChunks} chunk{totalChunks !== 1 ? "s" : ""} - {allChunks.length < totalChunks && ` (showing ${allChunks.length})`} - - )} -

-
-
- {url && ( - - )} - -
-
- - {/* Loading State */} - {!isDirectRenderSource && isDocumentByChunkFetching && ( -
- - -

- {t("loading_document")} -

-
-
- )} - - {/* Error State */} - {!isDirectRenderSource && documentByChunkFetchingError && ( -
- -
- -
-
-

Document unavailable

-

- {documentByChunkFetchingError.message || - "An unexpected error occurred. Please try again."} -

-
- -
-
- )} - - {/* Direct render for web search providers */} - {isDirectRenderSource && ( - -
- {url && ( - - )} - -

- - Source Information -

-
- {title || "Untitled"} -
-
- {description || "No content available"} -
-
-
-
- )} - - {/* API-fetched document content */} - {!isDirectRenderSource && documentData && ( -
- {/* Chunk Navigation Sidebar */} - {allChunks.length > 1 && ( - - -
- {allChunks.map((chunk, idx) => { - const absNum = absoluteStart + idx + 1; - const isCited = chunk.id === chunkId; - const isActive = activeChunkIndex === idx; - return ( - scrollToChunk(idx)} - initial={{ opacity: 0, scale: 0.8 }} - animate={{ opacity: 1, scale: 1 }} - transition={{ delay: Math.min(idx * 0.02, 0.2) }} - className={cn( - "relative w-11 h-9 mx-auto rounded-lg text-xs font-semibold transition-all duration-200 flex items-center justify-center", - isCited - ? "bg-primary text-primary-foreground shadow-md" - : isActive - ? "bg-muted text-foreground" - : "bg-muted/50 text-muted-foreground hover:bg-muted hover:text-foreground" - )} - title={isCited ? `Chunk ${absNum} (Cited)` : `Chunk ${absNum}`} - > - {absNum} - {isCited && ( - - - - )} - - ); - })} -
-
-
- )} - - {/* Main Content */} - -
- {/* Document Metadata */} - {"document_metadata" in documentData && - documentData.document_metadata && - Object.keys(documentData.document_metadata).length > 0 && ( - -

- - Document Information -

-
- {Object.entries(documentData.document_metadata).map(([key, value]) => ( -
-
- {key.replace(/_/g, " ")} -
-
{String(value)}
-
- ))} -
-
- )} - - {/* Chunks Header */} -
-

- - Chunks {absoluteStart + 1}–{absoluteEnd} of {totalChunks} -

- {citedChunkIndex !== -1 && ( - - )} -
- - {/* Load Earlier */} - {canLoadBefore && ( -
- -
- )} - - {/* Chunks */} -
- {allChunks.map((chunk, idx) => { - const isCited = chunk.id === chunkId; - const chunkNumber = absoluteStart + idx + 1; - return ( - 30} - /> - ); - })} -
- - {/* Load Later */} - {canLoadAfter && ( -
- -
- )} -
-
-
- )} -
- - )} -
- ); - - if (!mounted) return <>{children}; - - return ( - <> - {children} - {createPortal(panelContent, globalThis.document.body)} - - ); -} diff --git a/surfsense_web/components/settings/user-settings-dialog.tsx b/surfsense_web/components/settings/user-settings-dialog.tsx index 7352a82ee..a04ce16dd 100644 --- a/surfsense_web/components/settings/user-settings-dialog.tsx +++ b/surfsense_web/components/settings/user-settings-dialog.tsx @@ -67,9 +67,6 @@ const DesktopShortcutsContent = dynamic( import( "@/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent" ).then((m) => ({ default: m.DesktopShortcutsContent })), - import( - "@/app/dashboard/[search_space_id]/user-settings/components/DesktopShortcutsContent" - ).then((m) => ({ default: m.DesktopShortcutsContent })), { ssr: false } ); const MemoryContent = dynamic( diff --git a/surfsense_web/components/ui/search-highlight-node.tsx b/surfsense_web/components/ui/search-highlight-node.tsx new file mode 100644 index 000000000..e3f316cce --- /dev/null +++ b/surfsense_web/components/ui/search-highlight-node.tsx @@ -0,0 +1,45 @@ +"use client"; + +import type { PlateLeafProps } from "platejs/react"; +import { PlateLeaf } from "platejs/react"; + +/** + * Stable class name used to identify Plate-rendered citation highlight + * leaves in the DOM. We can't use a `data-*` attribute here — Plate's + * `PlateLeaf` runs its props through `useNodeAttributes`, which only + * forwards `attributes`, `className`, `ref`, and `style` to the rendered + * element; arbitrary `data-*` props are silently dropped (verified + * against `@platejs/core/dist/react/index.js` v52). So `className` is + * the only escape hatch that's guaranteed to survive into the DOM. + */ +export const CITATION_HIGHLIGHT_CLASS = "citation-highlight-leaf"; + +/** + * Leaf rendered for ranges decorated by `@platejs/find-replace`'s + * `FindReplacePlugin`. We re-purpose that plugin to drive the citation-jump + * highlight: when a citation is staged, the parent sets the plugin's `search` + * option to a snippet of the chunk text and Plate decorates every match with + * `searchHighlight: true`. This component renders those decorations as a + * `` tagged with `CITATION_HIGHLIGHT_CLASS` so the parent can: + * 1. Query the first match in DOM order to scroll it into view. + * 2. Detect the active-highlight state without a separate React ref. + * + * The highlight is **persistent** — it does not auto-fade. The parent in + * `EditorPanelContent` clears it by setting the plugin's `search` option + * back to "" when one of: (a) the user clicks anywhere inside the editor, + * (b) the panel switches to a different document, (c) the user toggles + * into edit mode, (d) another citation jump is staged, (e) the panel + * unmounts. We use a brief entrance pulse (`citation-flash-in`, see + * `globals.css`) purely to draw the eye after `scrollIntoView` lands. + */ +export function SearchHighlightLeaf(props: PlateLeafProps) { + return ( + + {props.children} + + ); +} diff --git a/surfsense_web/lib/citation-search.ts b/surfsense_web/lib/citation-search.ts new file mode 100644 index 000000000..f80f13076 --- /dev/null +++ b/surfsense_web/lib/citation-search.ts @@ -0,0 +1,125 @@ +/** + * Snippet generation for the citation-jump highlight, driven by Plate's + * `FindReplacePlugin`. The plugin runs `decorate` per-block and only matches + * within blocks whose children are all `Text` nodes (so it crosses inline + * marks like bold/italic but **not** block boundaries, and a block that + * contains even one inline element such as a link is silently skipped). + * That means a full chunk that spans heading + paragraph won't match as a + * single string — we have to pick a shorter snippet that fits inside one + * rendered block. + * + * `buildCitationSearchCandidates` returns search strings ordered from + * "most-specific anchor" to "broadest fallback": + * 1. First sentence of the chunk (capped at `FIRST_SENTENCE_MAX`). + * 2. First `FIRST_PHRASE_WORDS` words. + * 3. Each non-trivial line of the chunk, in source order — gives us a + * separate attempt for each rendered block, so a heading line with + * an inline link doesn't doom the whole jump. + * 4. Full chunk (only if it's already short enough to plausibly fit + * inside one block). + * + * The caller tries each candidate in turn — set the plugin's `search` + * option, `editor.api.redecorate()`, then check the editor DOM for a + * `.citation-highlight-leaf` element. First candidate that produces one + * wins; subsequent candidates are skipped. + */ + +const FIRST_SENTENCE_MAX = 120; +const FIRST_PHRASE_WORDS = 8; +const MIN_SNIPPET_LENGTH = 6; +const FULL_CHUNK_MAX = FIRST_SENTENCE_MAX * 2; +const MAX_LINE_CANDIDATES = 6; +const LINE_CANDIDATE_MAX = FIRST_SENTENCE_MAX; + +function normalizeWhitespace(input: string): string { + return input.replace(/\s+/g, " ").trim(); +} + +/** + * Strip the markdown syntax that won't survive into the rendered editor's + * plain text, so the chunk text (which comes back from the indexer as raw + * source markdown) can be matched against the literal text values stored + * in Plate's Slate tree. + * + * Order matters: handle multi-char and "container" syntax before single- + * char emphasis, otherwise `**text**` collapses to `*text*` first. + * + * Heuristic only — we don't aim to be a full markdown parser, just to + * remove the common markers (`**bold**`, `[text](url)`, `# headings`, + * `- list`, etc.) that show up in connector-doc chunks and would break + * literal substring search. + */ +export function stripMarkdownForMatch(input: string): string { + let s = input; + s = s.replace(/```[a-z0-9_+-]*\n?([\s\S]*?)```/gi, (_, body: string) => body); + s = s.replace(//g, " "); + s = s.replace(/!\[([^\]]*)\]\([^)]*\)/g, "$1"); + s = s.replace(/!\[([^\]]*)\]\[[^\]]*\]/g, "$1"); + s = s.replace(/\[([^\]]+)\]\([^)]*\)/g, "$1"); + s = s.replace(/\[([^\]]+)\]\[[^\]]*\]/g, "$1"); + s = s.replace(/<((?:https?|mailto):[^>\s]+)>/g, "$1"); + s = s.replace(/`+([^`\n]+?)`+/g, "$1"); + s = s.replace(/(\*\*|__)([\s\S]+?)\1/g, "$2"); + s = s.replace(/(?+[ \t]?/gm, ""); + s = s.replace(/^[ \t]*[-*+][ \t]+/gm, ""); + s = s.replace(/^[ \t]*\d+\.[ \t]+/gm, ""); + s = s.replace(/^[ \t]{0,3}(?:[-*_])(?:[ \t]*[-*_]){2,}[ \t]*$/gm, ""); + s = s.replace(/^[ \t]*\|?(?:[ \t]*:?-+:?[ \t]*\|)+[ \t]*:?-+:?[ \t]*\|?[ \t]*$/gm, ""); + s = s.replace(/\\([\\`*_{}[\]()#+\-.!~>])/g, "$1"); + return s; +} + +export function buildCitationSearchCandidates(rawText: string): string[] { + if (!rawText) return []; + const stripped = stripMarkdownForMatch(rawText); + const normalized = normalizeWhitespace(stripped); + if (normalized.length < MIN_SNIPPET_LENGTH) return []; + + const out: string[] = []; + const seen = new Set(); + const push = (s: string) => { + const t = normalizeWhitespace(s); + if (t.length >= MIN_SNIPPET_LENGTH && !seen.has(t)) { + out.push(t); + seen.add(t); + } + }; + + const sentenceMatch = normalized.match(/^[^.!?]+[.!?]/); + if (sentenceMatch) { + const sentence = sentenceMatch[0]; + push(sentence.length > FIRST_SENTENCE_MAX ? sentence.slice(0, FIRST_SENTENCE_MAX) : sentence); + } else if (normalized.length > FIRST_SENTENCE_MAX) { + push(normalized.slice(0, FIRST_SENTENCE_MAX)); + } + + const words = normalized.split(" ").filter(Boolean); + if (words.length > FIRST_PHRASE_WORDS) { + push(words.slice(0, FIRST_PHRASE_WORDS).join(" ")); + } + + // Per-line candidates: each chunk line is roughly one block in the + // rendered editor. Trying them in order gives us a separate decorate + // attempt for each block, which matters when the first line is a + // heading containing a link (Plate's `FindReplacePlugin` will skip + // any block whose children aren't all text nodes). + const rawLines = stripped.split(/\r?\n/); + let lineCount = 0; + for (const line of rawLines) { + if (lineCount >= MAX_LINE_CANDIDATES) break; + const trimmed = normalizeWhitespace(line); + if (trimmed.length < MIN_SNIPPET_LENGTH) continue; + push(trimmed.length > LINE_CANDIDATE_MAX ? trimmed.slice(0, LINE_CANDIDATE_MAX) : trimmed); + lineCount++; + } + + if (normalized.length <= FULL_CHUNK_MAX) { + push(normalized); + } + + return out; +} diff --git a/surfsense_web/package.json b/surfsense_web/package.json index 41175daeb..665490e4f 100644 --- a/surfsense_web/package.json +++ b/surfsense_web/package.json @@ -36,6 +36,7 @@ "@platejs/code-block": "^52.0.11", "@platejs/combobox": "^52.0.15", "@platejs/dnd": "^52.0.11", + "@platejs/find-replace": "^52.3.10", "@platejs/floating": "^52.0.11", "@platejs/indent": "^52.0.11", "@platejs/link": "^52.0.11", diff --git a/surfsense_web/pnpm-lock.yaml b/surfsense_web/pnpm-lock.yaml index b1730e842..a1a7bea12 100644 --- a/surfsense_web/pnpm-lock.yaml +++ b/surfsense_web/pnpm-lock.yaml @@ -53,6 +53,9 @@ importers: '@platejs/dnd': specifier: ^52.0.11 version: 52.0.11(platejs@52.0.17(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(scheduler@0.27.0)(use-sync-external-store@1.6.0(react@19.2.4)))(react-dnd-html5-backend@16.0.1)(react-dnd@16.0.1(@types/node@20.19.33)(@types/react@19.2.14)(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + '@platejs/find-replace': + specifier: ^52.3.10 + version: 52.3.10(platejs@52.0.17(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(scheduler@0.27.0)(use-sync-external-store@1.6.0(react@19.2.4)))(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@platejs/floating': specifier: ^52.0.11 version: 52.0.11(platejs@52.0.17(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(scheduler@0.27.0)(use-sync-external-store@1.6.0(react@19.2.4)))(react-dom@19.2.4(react@19.2.4))(react@19.2.4) @@ -2827,6 +2830,13 @@ packages: react-dnd-html5-backend: '>=14.0.0' react-dom: '>=18.0.0' + '@platejs/find-replace@52.3.10': + resolution: {integrity: sha512-V/MOMMUYxHfEn/skd2+YO213xSATFDVsl8FzVzVRV/XaxwwVefH2EPD1lAVIvmYjennTVTTsHHtEI9K9iOsEaA==} + peerDependencies: + platejs: '>=52.0.11' + react: '>=18.0.0' + react-dom: '>=18.0.0' + '@platejs/floating@52.0.11': resolution: {integrity: sha512-ApNpw4KWml+kuK+XTTpji+f/7GxTR4nRzlnfJMvGBrJpLPQ4elS5MABm3oUi81DZn+aub5HvsyH7UqCw7F76IA==} peerDependencies: @@ -11105,6 +11115,13 @@ snapshots: react-dnd-html5-backend: 16.0.1 react-dom: 19.2.4(react@19.2.4) + '@platejs/find-replace@52.3.10(platejs@52.0.17(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(scheduler@0.27.0)(use-sync-external-store@1.6.0(react@19.2.4)))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': + dependencies: + platejs: 52.0.17(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(scheduler@0.27.0)(use-sync-external-store@1.6.0(react@19.2.4)) + react: 19.2.4 + react-compiler-runtime: 1.0.0(react@19.2.4) + react-dom: 19.2.4(react@19.2.4) + '@platejs/floating@52.0.11(platejs@52.0.17(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(scheduler@0.27.0)(use-sync-external-store@1.6.0(react@19.2.4)))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': dependencies: '@floating-ui/core': 1.7.4 From 4845b96209834badced8e201e45a4c89acc12ca4 Mon Sep 17 00:00:00 2001 From: yeranyang Date: Tue, 28 Apr 2026 12:16:27 +0800 Subject: [PATCH 219/299] perf(blog): derive search results with useMemo instead of useState+useEffect Fixes #1246 Replace the useState/useEffect pattern that synced fuzzy search results into local state on every search or searcher change with a single useMemo that derives results directly during render. Before: const [results, setResults] = useState(allBlogs); useEffect(() => { setResults(searcher.search(search)); }, [search, searcher]); After: const gridItems = useMemo(() => { const results = search.trim() ? searcher.search(search) : allBlogs; ... }, [search, searcher, allBlogs, featuredSlug]); This removes an extra re-render per keystroke and eliminates the stale intermediate state that occurred between the search input change and the effect firing. --- surfsense_web/app/(home)/blog/blog-magazine.tsx | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/surfsense_web/app/(home)/blog/blog-magazine.tsx b/surfsense_web/app/(home)/blog/blog-magazine.tsx index 96c7f6789..02e5045a9 100644 --- a/surfsense_web/app/(home)/blog/blog-magazine.tsx +++ b/surfsense_web/app/(home)/blog/blog-magazine.tsx @@ -3,7 +3,7 @@ import { format } from "date-fns"; import FuzzySearch from "fuzzy-search"; import Link from "next/link"; -import { useEffect, useMemo, useState } from "react"; +import { useMemo, useState } from "react"; import { Container } from "@/components/container"; import type { BlogEntry } from "./page"; @@ -127,17 +127,13 @@ function MagazineSearchGrid({ [allBlogs] ); - const [results, setResults] = useState(allBlogs); - useEffect(() => { - setResults(searcher.search(search)); - }, [search, searcher]); - const gridItems = useMemo(() => { + const results = search.trim() ? searcher.search(search) : allBlogs; if (search.trim()) { return results; } return results.filter((b) => b.slug !== featuredSlug); - }, [results, search, featuredSlug]); + }, [search, searcher, allBlogs, featuredSlug]); return (
From dcafa364ffad6337003108c992bf7253efda2cfa Mon Sep 17 00:00:00 2001 From: guangyang1206 Date: Wed, 29 Apr 2026 12:12:30 +0800 Subject: [PATCH 220/299] feat(perf): add loading.tsx skeletons for async marketing routes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #1243 Add sibling loading.tsx files for all 6 async route segments that were missing instant loading UI, causing blank screens during navigation on slow networks or cold caches. Routes covered: - /docs/[[...slug]] — awaits getDocPage + MDX body - /blog — awaits source.getPages() - /blog/[slug] — awaits params + MDX body - /changelog — awaits source.getPages() - /free — awaits getModels() fetch - /free/[model_slug] — awaits Promise.all([getModel, getAllModels]) Each loading.tsx is a Server Component returning an animate-pulse skeleton that matches its route's layout (header, content area, grid/table/timeline as appropriate). Uses the Skeleton component and Tailwind classes already present in the project. Follows the pattern established in: - app/dashboard/[search_space_id]/logs/loading.tsx - app/dashboard/[search_space_id]/new-chat/loading.tsx --- .../app/(home)/blog/[slug]/loading.tsx | 66 +++++++++++++++++++ surfsense_web/app/(home)/blog/loading.tsx | 50 ++++++++++++++ .../app/(home)/changelog/loading.tsx | 63 ++++++++++++++++++ .../app/(home)/free/[model_slug]/loading.tsx | 65 ++++++++++++++++++ surfsense_web/app/(home)/free/loading.tsx | 60 +++++++++++++++++ .../app/docs/[[...slug]]/loading.tsx | 55 ++++++++++++++++ 6 files changed, 359 insertions(+) create mode 100644 surfsense_web/app/(home)/blog/[slug]/loading.tsx create mode 100644 surfsense_web/app/(home)/blog/loading.tsx create mode 100644 surfsense_web/app/(home)/changelog/loading.tsx create mode 100644 surfsense_web/app/(home)/free/[model_slug]/loading.tsx create mode 100644 surfsense_web/app/(home)/free/loading.tsx create mode 100644 surfsense_web/app/docs/[[...slug]]/loading.tsx diff --git a/surfsense_web/app/(home)/blog/[slug]/loading.tsx b/surfsense_web/app/(home)/blog/[slug]/loading.tsx new file mode 100644 index 000000000..0cce7f80b --- /dev/null +++ b/surfsense_web/app/(home)/blog/[slug]/loading.tsx @@ -0,0 +1,66 @@ +import { Skeleton } from "@/components/ui/skeleton"; + +export default function BlogPostLoading() { + return ( +
+
+ {/* Breadcrumb */} +
+ + + + + +
+ + {/* Tags */} +
+ + +
+ + {/* Title */} +
+ + +
+ + {/* Description */} + + + + {/* Author + date */} +
+ +
+ + +
+
+ + {/* Cover image */} + + + {/* Article body paragraphs */} + {Array.from({ length: 5 }).map((_, i) => ( +
+ + + +
+ ))} + + {/* Sub-heading */} + + + {Array.from({ length: 3 }).map((_, i) => ( +
+ + + +
+ ))} +
+
+ ); +} diff --git a/surfsense_web/app/(home)/blog/loading.tsx b/surfsense_web/app/(home)/blog/loading.tsx new file mode 100644 index 000000000..ddaf345f6 --- /dev/null +++ b/surfsense_web/app/(home)/blog/loading.tsx @@ -0,0 +1,50 @@ +import { Skeleton } from "@/components/ui/skeleton"; + +export default function BlogIndexLoading() { + return ( +
+
+ {/* Header */} +
+ +
+ + {/* Featured post skeleton */} +
+ +
+ + + +
+ + + +
+
+
+ + {/* Search bar skeleton */} +
+ +
+ + {/* Grid of article cards */} +
+ {Array.from({ length: 6 }).map((_, i) => ( +
+ + + + +
+ + +
+
+ ))} +
+
+
+ ); +} diff --git a/surfsense_web/app/(home)/changelog/loading.tsx b/surfsense_web/app/(home)/changelog/loading.tsx new file mode 100644 index 000000000..648f5a5e6 --- /dev/null +++ b/surfsense_web/app/(home)/changelog/loading.tsx @@ -0,0 +1,63 @@ +import { Skeleton } from "@/components/ui/skeleton"; + +export default function ChangelogLoading() { + return ( +
+ {/* Header */} +
+
+
+
+ {/* Breadcrumb */} +
+ + + +
+ + +
+
+
+
+ + {/* Timeline */} +
+
+ {Array.from({ length: 3 }).map((_, i) => ( +
+ {/* Left: date + version */} +
+ + +
+ + {/* Right: content */} +
+
+ {/* Title */} + + {/* Tags */} +
+ + +
+ {/* Body paragraphs */} +
+ + + +
+
+ + +
+
+
+
+ ))} +
+
+
+ ); +} diff --git a/surfsense_web/app/(home)/free/[model_slug]/loading.tsx b/surfsense_web/app/(home)/free/[model_slug]/loading.tsx new file mode 100644 index 000000000..97660188d --- /dev/null +++ b/surfsense_web/app/(home)/free/[model_slug]/loading.tsx @@ -0,0 +1,65 @@ +import { Skeleton } from "@/components/ui/skeleton"; + +export default function FreeModelLoading() { + return ( + <> + {/* Chat area skeleton - fills viewport */} +
+ {/* Chat header */} +
+ + +
+ + {/* Chat messages area */} +
+
+ +
+
+ + + +
+
+ + {/* Input bar */} +
+ +
+
+ + {/* SEO section skeleton */} +
+
+ {/* Breadcrumb */} +
+ + + + + +
+ + + + + +
+ + {/* FAQ skeleton */} + +
+ {Array.from({ length: 4 }).map((_, i) => ( +
+ + + +
+ ))} +
+
+
+ + ); +} diff --git a/surfsense_web/app/(home)/free/loading.tsx b/surfsense_web/app/(home)/free/loading.tsx new file mode 100644 index 000000000..08a4ed6b6 --- /dev/null +++ b/surfsense_web/app/(home)/free/loading.tsx @@ -0,0 +1,60 @@ +import { Skeleton } from "@/components/ui/skeleton"; + +export default function FreeChatLoading() { + return ( +
+
+ {/* Breadcrumb */} +
+ + + +
+ + {/* Hero section */} +
+ + + + +
+ {Array.from({ length: 4 }).map((_, i) => ( + + ))} +
+
+ +
+ + {/* Model table */} +
+ + + +
+ {/* Table header */} +
+ + + + +
+ + {/* Table rows */} + {Array.from({ length: 8 }).map((_, i) => ( +
+
+ + +
+ + + +
+ ))} +
+
+
+
+ ); +} diff --git a/surfsense_web/app/docs/[[...slug]]/loading.tsx b/surfsense_web/app/docs/[[...slug]]/loading.tsx new file mode 100644 index 000000000..6bedcfc40 --- /dev/null +++ b/surfsense_web/app/docs/[[...slug]]/loading.tsx @@ -0,0 +1,55 @@ +import { Skeleton } from "@/components/ui/skeleton"; + +export default function DocsLoading() { + return ( +
+ {/* Title */} + + + {/* Description */} + + +
+ {/* Paragraph block 1 */} +
+ + + +
+ + {/* Sub-heading */} + + + {/* Paragraph block 2 */} +
+ + + + +
+ + {/* Code block placeholder */} + + + {/* Sub-heading */} + + + {/* List items */} +
+ {Array.from({ length: 4 }).map((_, i) => ( +
+ + +
+ ))} +
+ + {/* Paragraph block 3 */} +
+ + +
+
+
+ ); +} From 942077c7736c758bfcfca379e6eaf5902e1320c6 Mon Sep 17 00:00:00 2001 From: yeranyang Date: Tue, 28 Apr 2026 12:17:44 +0800 Subject: [PATCH 221/299] perf(docs): replace full lucide barrel import with explicit icon whitelist Fixes #1241 The docs bundle was importing `{ icons }` from lucide-react, which pulls the entire Lucide icon library (~1 400 SVGs, ~500 kB of JS) into the Next.js docs bundle even though only nine icons are used in docs frontmatter and meta.json files. Replace with a hand-maintained DOCS_ICONS whitelist that imports only the icons that are actually referenced (BookOpen, ClipboardCheck, Compass, Container, Download, FlaskConical, Heart, Unplug, Wrench). To add a new docs icon: import it from lucide-react and add it to the DOCS_ICONS record. The icon() callback remains the same for callers. --- surfsense_web/lib/source.ts | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/surfsense_web/lib/source.ts b/surfsense_web/lib/source.ts index 162cca57a..b94f990ab 100644 --- a/surfsense_web/lib/source.ts +++ b/surfsense_web/lib/source.ts @@ -1,12 +1,39 @@ import { loader } from "fumadocs-core/source"; -import { icons } from "lucide-react"; +import { + BookOpen, + ClipboardCheck, + Compass, + Container, + Download, + FlaskConical, + Heart, + Unplug, + Wrench, +} from "lucide-react"; import { createElement } from "react"; import { docs } from "@/.source/server"; +/** Explicit whitelist of Lucide icons used in docs frontmatter / meta.json. + * Importing the full `icons` barrel would pull every Lucide icon (~1 400 SVGs) + * into the docs bundle even though only a handful are referenced. Add new icons + * here as docs pages are added. + */ +const DOCS_ICONS: Record = { + BookOpen, + ClipboardCheck, + Compass, + Container, + Download, + FlaskConical, + Heart, + Unplug, + Wrench, +}; + export const source = loader({ baseUrl: "/docs", source: docs.toFumadocsSource(), icon(icon) { - if (icon && icon in icons) return createElement(icons[icon as keyof typeof icons]); + if (icon && icon in DOCS_ICONS) return createElement(DOCS_ICONS[icon]); }, }); From 345cb88224803a5fb0b18340882409a046e02ff9 Mon Sep 17 00:00:00 2001 From: guangyang1206 Date: Wed, 29 Apr 2026 12:14:08 +0800 Subject: [PATCH 222/299] refactor(settings): use key prop to reset LLM role manager form state MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #1018 Remove the sync useEffect that copied preferences into local state, along with the savingRef guard that prevented mid-save overwrites. Instead, pass key={searchSpaceId} on the LLMRoleManager component so React remounts the form with correct initial state whenever the search space changes — no extra re-render, no effect dependency array. Changes: - llm-role-manager.tsx: remove useEffect + useRef + savingRef pattern; drop useEffect and useRef from imports (now only useCallback, useState) - search-space-settings-dialog.tsx: add key={searchSpaceId} to so the component remounts on search-space change Before: useEffect synced preferences → assignments on each preference update, with savingRef to avoid overwriting an in-flight save. After: React remounts the component with correct initial state from the preferences selector; no mid-save race possible. --- .../components/settings/llm-role-manager.tsx | 21 +------------------ .../settings/search-space-settings-dialog.tsx | 2 +- 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx index 015027111..e21dc9028 100644 --- a/surfsense_web/components/settings/llm-role-manager.tsx +++ b/surfsense_web/components/settings/llm-role-manager.tsx @@ -11,7 +11,7 @@ import { RefreshCw, ScanEye, } from "lucide-react"; -import { useCallback, useEffect, useRef, useState } from "react"; +import { useCallback, useState } from "react"; import { toast } from "sonner"; import { globalImageGenConfigsAtom, @@ -143,23 +143,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { })); const [savingRole, setSavingRole] = useState(null); - const savingRef = useRef(false); - - useEffect(() => { - if (!savingRef.current) { - setAssignments({ - agent_llm_id: preferences.agent_llm_id ?? "", - document_summary_llm_id: preferences.document_summary_llm_id ?? "", - image_generation_config_id: preferences.image_generation_config_id ?? "", - vision_llm_config_id: preferences.vision_llm_config_id ?? "", - }); - } - }, [ - preferences?.agent_llm_id, - preferences?.document_summary_llm_id, - preferences?.image_generation_config_id, - preferences?.vision_llm_config_id, - ]); const handleRoleAssignment = useCallback( async (prefKey: string, configId: string) => { @@ -167,7 +150,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { setAssignments((prev) => ({ ...prev, [prefKey]: value })); setSavingRole(prefKey); - savingRef.current = true; try { await updatePreferences({ @@ -177,7 +159,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { toast.success("Role assignment updated"); } finally { setSavingRole(null); - savingRef.current = false; } }, [updatePreferences, searchSpaceId] diff --git a/surfsense_web/components/settings/search-space-settings-dialog.tsx b/surfsense_web/components/settings/search-space-settings-dialog.tsx index aefe1efd2..2a7ba82b6 100644 --- a/surfsense_web/components/settings/search-space-settings-dialog.tsx +++ b/surfsense_web/components/settings/search-space-settings-dialog.tsx @@ -116,7 +116,7 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings const content: Record = { general: , models: , - roles: , + roles: , "image-models": , "vision-models": , "team-roles": , From ca9bbee06dbd2e9e50be27f54a3967e20dfc0e7d Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Tue, 28 Apr 2026 21:37:51 -0700 Subject: [PATCH 223/299] chore: linting --- .../versions/130_add_agent_action_log.py | 4 +- .../133_drop_documents_content_hash_unique.py | 4 +- .../app/agents/new_chat/chat_deepagent.py | 19 ++++--- .../app/agents/new_chat/feature_flags.py | 20 +++++-- .../agents/new_chat/middleware/action_log.py | 4 +- .../agents/new_chat/middleware/compaction.py | 4 +- .../new_chat/middleware/context_editing.py | 7 ++- .../agents/new_chat/middleware/doom_loop.py | 21 ++++--- .../new_chat/middleware/knowledge_search.py | 11 ++-- .../new_chat/middleware/noop_injection.py | 12 ++-- .../agents/new_chat/middleware/otel_span.py | 14 ++--- .../agents/new_chat/middleware/permission.py | 41 ++++++++------ .../agents/new_chat/middleware/retry_after.py | 12 +++- .../new_chat/middleware/skills_backends.py | 17 ++++-- .../new_chat/middleware/tool_call_repair.py | 10 ++-- .../new_chat/plugins/year_substituter.py | 29 +++++----- .../app/agents/new_chat/prompts/composer.py | 8 +-- .../app/agents/new_chat/subagents/config.py | 4 +- .../app/agents/new_chat/tools/registry.py | 2 + surfsense_backend/app/observability/otel.py | 9 +-- .../app/routes/agent_flags_route.py | 2 +- .../app/routes/agent_permissions_route.py | 8 +-- .../app/routes/agent_revert_route.py | 6 +- .../app/routes/new_chat_routes.py | 4 +- .../app/services/revert_service.py | 4 +- .../app/utils/user_message_multimodal.py | 4 +- .../agents/new_chat/prompts/test_composer.py | 17 +++--- .../unit/agents/new_chat/test_action_log.py | 56 ++++++++++--------- .../unit/agents/new_chat/test_compaction.py | 20 +++++-- .../agents/new_chat/test_context_editing.py | 3 +- .../agents/new_chat/test_dedup_tool_calls.py | 18 +++++- .../test_default_permissions_layering.py | 8 +-- .../unit/agents/new_chat/test_doom_loop.py | 15 ++--- .../agents/new_chat/test_noop_injection.py | 8 ++- .../new_chat/test_permission_middleware.py | 4 +- .../agents/new_chat/test_plugin_loader.py | 12 ++-- .../unit/agents/new_chat/test_retry_after.py | 10 ++-- .../new_chat/test_specialized_subagents.py | 31 +++++----- .../agents/new_chat/test_tool_call_repair.py | 54 ++++++++++++------ .../test_kb_persistence_filesystem_parity.py | 2 +- .../unit/services/test_revert_service.py | 20 ++----- 41 files changed, 314 insertions(+), 244 deletions(-) diff --git a/surfsense_backend/alembic/versions/130_add_agent_action_log.py b/surfsense_backend/alembic/versions/130_add_agent_action_log.py index 5793988cb..2f06b8ddd 100644 --- a/surfsense_backend/alembic/versions/130_add_agent_action_log.py +++ b/surfsense_backend/alembic/versions/130_add_agent_action_log.py @@ -88,7 +88,5 @@ def upgrade() -> None: def downgrade() -> None: - op.drop_index( - "ix_agent_action_log_thread_created", table_name="agent_action_log" - ) + op.drop_index("ix_agent_action_log_thread_created", table_name="agent_action_log") op.drop_table("agent_action_log") diff --git a/surfsense_backend/alembic/versions/133_drop_documents_content_hash_unique.py b/surfsense_backend/alembic/versions/133_drop_documents_content_hash_unique.py index 88c3e203f..eec53ecb6 100644 --- a/surfsense_backend/alembic/versions/133_drop_documents_content_hash_unique.py +++ b/surfsense_backend/alembic/versions/133_drop_documents_content_hash_unique.py @@ -51,9 +51,7 @@ def upgrade() -> None: # implicit-unique-index variant SQLAlchemy may emit need draining. constraints = _existing_constraint_names(bind, "documents") if "uq_documents_content_hash" in constraints: - op.drop_constraint( - "uq_documents_content_hash", "documents", type_="unique" - ) + op.drop_constraint("uq_documents_content_hash", "documents", type_="unique") indexes = _existing_index_names(bind, "documents") # Some Postgres versions surface the unique constraint via a unique diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 672570696..3ca44dd4f 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -416,10 +416,10 @@ async def create_surfsense_deep_agent( # cheap to build. ``SubAgentMiddleware.__init__`` calls ``create_agent`` # synchronously to compile the general-purpose subagent's full state graph # (every tool + every middleware → pydantic schemas + langgraph compile). - # On gpt-5.x agents that's roughly 1.5–2s of pure CPU work. If we run it + # On gpt-5.x agents that's roughly 1.5-2s of pure CPU work. If we run it # directly here it blocks the asyncio event loop for the whole streaming # task (and any other coroutine sharing this loop), which is why - # "agent creation" wall-clock time used to stretch to ~3–4s. Move the + # "agent creation" wall-clock time used to stretch to ~3-4s. Move the # entire middleware build + main-graph compile into a single # ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the # event loop stays responsive. @@ -587,10 +587,7 @@ def _build_compiled_agent_blocking( # by name. Off by default until the flag flips so existing deployments # don't see new agent types in the task tool description. specialized_subagents: list[SubAgent] = [] - if ( - flags.enable_specialized_subagents - and not flags.disable_new_agent_stack - ): + if flags.enable_specialized_subagents and not flags.disable_new_agent_stack: try: # Specialized subagents share the parent's filesystem + # todo view so their system prompts (which promise @@ -696,7 +693,9 @@ def _build_compiled_agent_blocking( else None ) tool_call_limit_mw = ( - ToolCallLimitMiddleware(thread_limit=300, run_limit=80, exit_behavior="continue") + ToolCallLimitMiddleware( + thread_limit=300, run_limit=80, exit_behavior="continue" + ) if flags.enable_tool_call_limit and not flags.disable_new_agent_stack else None ) @@ -879,7 +878,11 @@ def _build_compiled_agent_blocking( max_tools=12, always_include=[ name - for name in ("update_memory", "get_connected_accounts", "scrape_webpage") + for name in ( + "update_memory", + "get_connected_accounts", + "scrape_webpage", + ) if name in {t.name for t in tools} ], ) diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py index ce0a3b3fa..89c4fb14f 100644 --- a/surfsense_backend/app/agents/new_chat/feature_flags.py +++ b/surfsense_backend/app/agents/new_chat/feature_flags.py @@ -65,7 +65,9 @@ class AgentFeatureFlags: enable_model_call_limit: bool = False enable_tool_call_limit: bool = False enable_tool_call_repair: bool = False - enable_doom_loop: bool = False # Default OFF until UI handles permission='doom_loop' + enable_doom_loop: bool = ( + False # Default OFF until UI handles permission='doom_loop' + ) # Tier 2 — Safety enable_permission: bool = False # Default OFF for first deploy @@ -79,7 +81,9 @@ class AgentFeatureFlags: # Tier 5 — Snapshot / revert enable_action_log: bool = False - enable_revert_route: bool = False # Backend ships before UI; route returns 503 until this flips + enable_revert_route: bool = ( + False # Backend ships before UI; route returns 503 until this flips + ) # Tier 6 — Plugins enable_plugin_loader: bool = False @@ -109,14 +113,20 @@ class AgentFeatureFlags: enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False), enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False), enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False), - enable_model_call_limit=_env_bool("SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False), + enable_model_call_limit=_env_bool( + "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False + ), enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False), - enable_tool_call_repair=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False), + enable_tool_call_repair=_env_bool( + "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False + ), enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False), # Tier 2 enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False), enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False), - enable_llm_tool_selector=_env_bool("SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False), + enable_llm_tool_selector=_env_bool( + "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False + ), # Tier 4 enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False), enable_specialized_subagents=_env_bool( diff --git a/surfsense_backend/app/agents/new_chat/middleware/action_log.py b/surfsense_backend/app/agents/new_chat/middleware/action_log.py index cf0b57fd4..3675064e8 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/action_log.py +++ b/surfsense_backend/app/agents/new_chat/middleware/action_log.py @@ -101,9 +101,7 @@ class ActionLogMiddleware(AgentMiddleware): async def awrap_tool_call( self, request: ToolCallRequest, - handler: Callable[ - [ToolCallRequest], Awaitable[ToolMessage | Command[Any]] - ], + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], ) -> ToolMessage | Command[Any]: if not self._enabled(): return await handler(request) diff --git a/surfsense_backend/app/agents/new_chat/middleware/compaction.py b/surfsense_backend/app/agents/new_chat/middleware/compaction.py index 8b02089c9..b0a1a7ec5 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/compaction.py +++ b/surfsense_backend/app/agents/new_chat/middleware/compaction.py @@ -177,8 +177,8 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware): messages_in=len(conversation_messages), extra={"compaction.cutoff_index": int(cutoff_index)}, ): - messages_to_summarize, preserved_messages = ( - super()._partition_messages(conversation_messages, cutoff_index) + messages_to_summarize, preserved_messages = super()._partition_messages( + conversation_messages, cutoff_index ) protected: list[AnyMessage] = [] diff --git a/surfsense_backend/app/agents/new_chat/middleware/context_editing.py b/surfsense_backend/app/agents/new_chat/middleware/context_editing.py index 93ceab8ee..360e3e28f 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/context_editing.py +++ b/surfsense_backend/app/agents/new_chat/middleware/context_editing.py @@ -58,8 +58,7 @@ DEFAULT_SPILL_PREFIX = "/tool_outputs" def _build_spill_placeholder(spill_path: str) -> str: """Build the user-facing placeholder text shown to the model.""" return ( - f"[cleared — full output at {spill_path}; " - f"ask the explore subagent to read it]" + f"[cleared — full output at {spill_path}; ask the explore subagent to read it]" ) @@ -131,7 +130,9 @@ class SpillToBackendEdit(ContextEdit): return candidates = [ - (idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage) + (idx, msg) + for idx, msg in enumerate(messages) + if isinstance(msg, ToolMessage) ] if self.keep >= len(candidates): return diff --git a/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py b/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py index 49ac7dfa8..1dde87752 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py +++ b/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py @@ -137,16 +137,21 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon triggered_call: dict[str, Any] | None = None for call in message.tool_calls: - name = call.get("name") if isinstance(call, dict) else getattr(call, "name", None) - args = call.get("args") if isinstance(call, dict) else getattr(call, "args", {}) + name = ( + call.get("name") + if isinstance(call, dict) + else getattr(call, "name", None) + ) + args = ( + call.get("args") + if isinstance(call, dict) + else getattr(call, "args", {}) + ) if not isinstance(name, str): continue sig = _signature(name, args) window.append(sig) - if ( - len(window) >= self._threshold - and len(set(window)) == 1 - ): + if len(window) >= self._threshold and len(set(window)) == 1: triggered_call = {"name": name, "params": args or {}} break @@ -209,7 +214,9 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon # tool call proceeds. The frontend's exact reply names may differ — # we tolerate any shape that contains a string with "reject"/"cancel". if isinstance(decision, dict): - kind = str(decision.get("decision_type") or decision.get("type") or "").lower() + kind = str( + decision.get("decision_type") or decision.get("type") or "" + ).lower() if "reject" in kind or "cancel" in kind: return {"jump_to": "end"} return None diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py index f39870df6..08ca8e18b 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py @@ -552,7 +552,7 @@ def _render_priority_message(priority: list[dict[str, Any]]) -> SystemMessage: for entry in priority: score = entry.get("score") mentioned = entry.get("mentioned") - score_str = f"{score:.3f}" if isinstance(score, (int, float)) else "n/a" + score_str = f"{score:.3f}" if isinstance(score, int | float) else "n/a" mark = " [USER-MENTIONED]" if mentioned else "" lines.append(f"- {entry.get('path', '')} (score={score_str}){mark}") body = "\n".join(lines) @@ -593,7 +593,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] self.top_k = top_k self.mentioned_document_ids = mentioned_document_ids or [] # Tier 4.2: build the kb-planner private Runnable ONCE here so we - # don't pay the create_agent compile cost (50–200ms) on every turn. + # don't pay the create_agent compile cost (50-200ms) on every turn. # Disabled by default behind ``enable_kb_planner_runnable``; when off # the planner falls back to the legacy ``self.llm.ainvoke`` path. self._planner: Runnable | None = None @@ -617,10 +617,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] if self.llm is None: return None flags = get_flags() - if ( - not flags.enable_kb_planner_runnable - or flags.disable_new_agent_stack - ): + if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack: return None from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware @@ -920,7 +917,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] chunk_ids = doc.get("matched_chunk_ids") or [] if chunk_ids: matched_chunk_ids[doc_id] = [ - int(cid) for cid in chunk_ids if isinstance(cid, (int, str)) + int(cid) for cid in chunk_ids if isinstance(cid, int | str) ] return priority, matched_chunk_ids diff --git a/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py b/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py index f16084892..8628479c7 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py +++ b/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py @@ -35,9 +35,7 @@ from langchain_core.tools import tool logger = logging.getLogger(__name__) NOOP_TOOL_NAME = "_noop" -NOOP_TOOL_DESCRIPTION = ( - "Do not call this tool. It exists only for API compatibility." -) +NOOP_TOOL_DESCRIPTION = "Do not call this tool. It exists only for API compatibility." @tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION) @@ -78,7 +76,9 @@ def _last_ai_has_tool_calls(messages: list[Any]) -> bool: return False -class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): +class NoopInjectionMiddleware( + AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT] +): """Inject the ``_noop`` tool only when the provider would otherwise 400. The check fires per model call, not at agent build time, because the @@ -116,7 +116,9 @@ class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R async def awrap_model_call( # type: ignore[override] self, request: ModelRequest[ContextT], - handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], + handler: Callable[ + [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]] + ], ) -> Any: if self._should_inject(request): logger.debug("Injecting _noop tool for provider compatibility") diff --git a/surfsense_backend/app/agents/new_chat/middleware/otel_span.py b/surfsense_backend/app/agents/new_chat/middleware/otel_span.py index 5585cf7a2..f51d2f7bb 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/otel_span.py +++ b/surfsense_backend/app/agents/new_chat/middleware/otel_span.py @@ -56,9 +56,7 @@ class OtelSpanMiddleware(AgentMiddleware): async def awrap_model_call( self, request: ModelRequest, - handler: Callable[ - [ModelRequest], Awaitable[ModelResponse | AIMessage | Any] - ], + handler: Callable[[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]], ) -> ModelResponse | AIMessage | Any: if not ot.is_enabled(): return await handler(request) @@ -81,9 +79,7 @@ class OtelSpanMiddleware(AgentMiddleware): async def awrap_tool_call( self, request: ToolCallRequest, - handler: Callable[ - [ToolCallRequest], Awaitable[ToolMessage | Command[Any]] - ], + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], ) -> ToolMessage | Command[Any]: if not ot.is_enabled(): return await handler(request) @@ -187,7 +183,11 @@ def _annotate_model_response(span: Any, result: Any) -> None: def _annotate_tool_result(span: Any, result: Any) -> None: try: if isinstance(result, ToolMessage): - content = result.content if isinstance(result.content, str) else repr(result.content) + content = ( + result.content + if isinstance(result.content, str) + else repr(result.content) + ) span.set_attribute("tool.output.size", len(content)) status = getattr(result, "status", None) if isinstance(status, str): diff --git a/surfsense_backend/app/agents/new_chat/middleware/permission.py b/surfsense_backend/app/agents/new_chat/middleware/permission.py index f59e70bc0..6e1f42baf 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/permission.py +++ b/surfsense_backend/app/agents/new_chat/middleware/permission.py @@ -145,7 +145,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] try: patterns = resolver(args or {}) except Exception: - logger.exception("Pattern resolver for %s raised; using bare name", tool_name) + logger.exception( + "Pattern resolver for %s raised; using bare name", tool_name + ) patterns = [tool_name] if not patterns: patterns = [tool_name] @@ -198,11 +200,14 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] # Tier 3b: permission.asked + interrupt.raised spans (no-op when # OTel is disabled). Both fire here so dashboards can correlate # "we asked X" with "interrupt was actually delivered". - with ot.permission_asked_span( - permission=tool_name, - pattern=patterns[0] if patterns else None, - extra={"permission.patterns": list(patterns)}, - ), ot.interrupt_span(interrupt_type="permission_ask"): + with ( + ot.permission_asked_span( + permission=tool_name, + pattern=patterns[0] if patterns else None, + extra={"permission.patterns": list(patterns)}, + ), + ot.interrupt_span(interrupt_type="permission_ask"), + ): decision = interrupt(payload) if isinstance(decision, dict): return decision @@ -211,9 +216,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] return {"decision_type": decision} return {"decision_type": "reject"} - def _persist_always( - self, tool_name: str, patterns: list[str] - ) -> None: + def _persist_always(self, tool_name: str, patterns: list[str]) -> None: """Promote ``always`` reply into runtime allow rules. Persistence to ``agent_permission_rules`` is done by the @@ -276,12 +279,16 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] any_change = False for raw in last.tool_calls: - call = dict(raw) if isinstance(raw, dict) else { - "name": getattr(raw, "name", None), - "args": getattr(raw, "args", {}), - "id": getattr(raw, "id", None), - "type": "tool_call", - } + call = ( + dict(raw) + if isinstance(raw, dict) + else { + "name": getattr(raw, "name", None), + "args": getattr(raw, "args", {}), + "id": getattr(raw, "id", None), + "type": "tool_call", + } + ) name = call.get("name") or "" args = call.get("args") or {} action, patterns, rules = self._evaluate(name, args) @@ -307,7 +314,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] feedback = decision.get("feedback") if isinstance(feedback, str) and feedback.strip(): raise CorrectedError(feedback, tool=name) - raise RejectedError(tool=name, pattern=patterns[0] if patterns else None) + raise RejectedError( + tool=name, pattern=patterns[0] if patterns else None + ) else: logger.warning( "Unknown permission decision %r; treating as reject", kind diff --git a/surfsense_backend/app/agents/new_chat/middleware/retry_after.py b/surfsense_backend/app/agents/new_chat/middleware/retry_after.py index 82da6a97c..394bb0371 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/retry_after.py +++ b/surfsense_backend/app/agents/new_chat/middleware/retry_after.py @@ -113,7 +113,9 @@ def _exponential_delay( jitter: bool, ) -> float: """Compute an exponential-backoff delay with optional ±25% jitter.""" - delay = initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay + delay = ( + initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay + ) delay = min(delay, max_delay) if jitter and delay > 0: delay *= 1 + random.uniform(-0.25, 0.25) @@ -201,7 +203,9 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp }, ) except Exception: - logger.debug("dispatch_custom_event failed; suppressed", exc_info=True) + logger.debug( + "dispatch_custom_event failed; suppressed", exc_info=True + ) if delay > 0: time.sleep(delay) # Unreachable @@ -210,7 +214,9 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp async def awrap_model_call( # type: ignore[override] self, request: ModelRequest[ContextT], - handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], + handler: Callable[ + [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]] + ], ) -> ModelResponse[ResponseT] | AIMessage: for attempt in range(self.max_retries + 1): try: diff --git a/surfsense_backend/app/agents/new_chat/middleware/skills_backends.py b/surfsense_backend/app/agents/new_chat/middleware/skills_backends.py index 4c3791c87..072d73401 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/skills_backends.py +++ b/surfsense_backend/app/agents/new_chat/middleware/skills_backends.py @@ -29,6 +29,7 @@ gives a clean failure mode if anything tries. from __future__ import annotations +import contextlib import logging from collections.abc import Callable from dataclasses import replace @@ -114,8 +115,10 @@ class BuiltinSkillsBackend(BackendProtocol): infos: list[FileInfo] = [] # Build virtual paths anchored at "/" because CompositeBackend already # stripped the route prefix before calling us. - target_virtual = "/" if target == self.root else ( - "/" + str(target.relative_to(self.root)).replace("\\", "/") + target_virtual = ( + "/" + if target == self.root + else ("/" + str(target.relative_to(self.root)).replace("\\", "/")) ) for child in sorted(target.iterdir()): child_virtual = ( @@ -128,10 +131,8 @@ class BuiltinSkillsBackend(BackendProtocol): "is_dir": child.is_dir(), } if child.is_file(): - try: + with contextlib.suppress(OSError): # pragma: no cover - defensive info["size"] = child.stat().st_size - except OSError: # pragma: no cover - defensive - pass infos.append(info) return infos @@ -163,7 +164,9 @@ class BuiltinSkillsBackend(BackendProtocol): else: content = target.read_bytes() except PermissionError: - responses.append(FileDownloadResponse(path=p, error="permission_denied")) + responses.append( + FileDownloadResponse(path=p, error="permission_denied") + ) continue except OSError as exc: # pragma: no cover - defensive logger.warning("Builtin skill read failed %s: %s", target, exc) @@ -286,6 +289,7 @@ def build_skills_backend_factory( builtin = BuiltinSkillsBackend(builtin_root) if search_space_id is None: + def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol: # Default StateBackend is intentionally inert: any path outside the # ``/skills/builtin/`` route resolves to an empty per-runtime state @@ -294,6 +298,7 @@ def build_skills_backend_factory( default=StateBackend(runtime), routes={SKILLS_BUILTIN_PREFIX: builtin}, ) + return _factory_builtin_only def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol: diff --git a/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py b/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py index 6c3bc674d..54df0cc60 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py +++ b/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py @@ -51,13 +51,15 @@ def _coerce_existing_tool_call(call: Any) -> dict[str, Any]: } -class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): +class ToolCallNameRepairMiddleware( + AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT] +): """Two-stage tool-name repair on the most recent ``AIMessage``. Args: registered_tool_names: Set of canonically-registered tool names. ``invalid`` should be in this set so the fallback dispatches. - fuzzy_match_threshold: Optional ``difflib`` ratio (0–1) for the + fuzzy_match_threshold: Optional ``difflib`` ratio (0-1) for the fuzzy-match step that runs *between* lowercase and invalid. Set to ``None`` to disable fuzzy matching (opencode parity). """ @@ -77,9 +79,9 @@ class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], Contex def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]: """Allow runtime overrides to expand the set (e.g. dynamic MCP tools).""" ctx_tools = getattr(runtime.context, "registered_tool_names", None) - if isinstance(ctx_tools, (set, frozenset)): + if isinstance(ctx_tools, set | frozenset): return self._registered | set(ctx_tools) - if isinstance(ctx_tools, (list, tuple)): + if isinstance(ctx_tools, list | tuple): return self._registered | set(ctx_tools) return self._registered diff --git a/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py b/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py index 927d533d5..3e2e631d2 100644 --- a/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py +++ b/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py @@ -52,25 +52,26 @@ class _YearSubstituterMiddleware(AgentMiddleware): async def awrap_tool_call( self, request: ToolCallRequest, - handler: Callable[ - [ToolCallRequest], Awaitable[ToolMessage | Command[Any]] - ], + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], ) -> ToolMessage | Command[Any]: result = await handler(request) try: from langchain_core.messages import ToolMessage - if isinstance(result, ToolMessage) and isinstance(result.content, str): - if "{{year}}" in result.content: - new_text = result.content.replace("{{year}}", self._year) - result = ToolMessage( - content=new_text, - tool_call_id=result.tool_call_id, - id=result.id, - name=result.name, - status=result.status, - artifact=result.artifact, - ) + if ( + isinstance(result, ToolMessage) + and isinstance(result.content, str) + and "{{year}}" in result.content + ): + new_text = result.content.replace("{{year}}", self._year) + result = ToolMessage( + content=new_text, + tool_call_id=result.tool_call_id, + id=result.id, + name=result.name, + status=result.status, + artifact=result.artifact, + ) except Exception: # pragma: no cover - defensive logger.exception("year_substituter plugin failed; passing original result") return result diff --git a/surfsense_backend/app/agents/new_chat/prompts/composer.py b/surfsense_backend/app/agents/new_chat/prompts/composer.py index bad033490..77b86aeef 100644 --- a/surfsense_backend/app/agents/new_chat/prompts/composer.py +++ b/surfsense_backend/app/agents/new_chat/prompts/composer.py @@ -62,7 +62,9 @@ ProviderVariant = str # More specific patterns must come first (e.g. ``codex`` before # ``openai_reasoning`` because codex model ids contain ``gpt``). -_OPENAI_CODEX_RE = re.compile(r"\b(gpt-codex|codex-mini|gpt-[\d.]+-codex)\b", re.IGNORECASE) +_OPENAI_CODEX_RE = re.compile( + r"\b(gpt-codex|codex-mini|gpt-[\d.]+-codex)\b", re.IGNORECASE +) _OPENAI_REASONING_RE = re.compile(r"\b(gpt-5|o\d|o-)", re.IGNORECASE) _OPENAI_CLASSIC_RE = re.compile(r"\bgpt-4", re.IGNORECASE) _ANTHROPIC_RE = re.compile(r"\bclaude\b", re.IGNORECASE) @@ -257,9 +259,7 @@ def _build_tools_section( ) if known_disabled: disabled_list = ", ".join( - _format_tool_label(n) - for n in ALL_TOOL_NAMES_ORDERED - if n in known_disabled + _format_tool_label(n) for n in ALL_TOOL_NAMES_ORDERED if n in known_disabled ) parts.append( "\n" diff --git a/surfsense_backend/app/agents/new_chat/subagents/config.py b/surfsense_backend/app/agents/new_chat/subagents/config.py index e20bc06bf..b36d35fa0 100644 --- a/surfsense_backend/app/agents/new_chat/subagents/config.py +++ b/surfsense_backend/app/agents/new_chat/subagents/config.py @@ -279,9 +279,7 @@ def build_explore_subagent( selected_tools = _filter_tools(tools, EXPLORE_READ_TOOLS) deny_rules = _read_only_deny_rules() - permission_mw = _build_permission_middleware( - deny_rules, origin="subagent_explore" - ) + permission_mw = _build_permission_middleware(deny_rules, origin="subagent_explore") spec: dict = { "name": "explore", diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index f5ee1a61d..fce1bf872 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -111,6 +111,8 @@ from .update_memory import create_update_memory_tool, create_update_team_memory_ from .video_presentation import create_generate_video_presentation_tool from .web_search import create_web_search_tool +logger = logging.getLogger(__name__) + # ============================================================================= # Tool Definition # ============================================================================= diff --git a/surfsense_backend/app/observability/otel.py b/surfsense_backend/app/observability/otel.py index 0229524f2..4f2257ab7 100644 --- a/surfsense_backend/app/observability/otel.py +++ b/surfsense_backend/app/observability/otel.py @@ -22,6 +22,7 @@ Goals from __future__ import annotations +import contextlib import logging import os from collections.abc import Iterator @@ -154,18 +155,14 @@ def span( with tracer.start_as_current_span(name) as sp: if attributes: - try: + with contextlib.suppress(Exception): # pragma: no cover — defensive sp.set_attributes(attributes) - except Exception: # pragma: no cover — defensive - pass try: yield sp except BaseException as exc: - try: + with contextlib.suppress(Exception): # pragma: no cover — defensive sp.record_exception(exc) sp.set_status(_OtStatus(_OtStatusCode.ERROR, str(exc))) - except Exception: # pragma: no cover — defensive - pass raise diff --git a/surfsense_backend/app/routes/agent_flags_route.py b/surfsense_backend/app/routes/agent_flags_route.py index d3c90a58d..5732a8dfb 100644 --- a/surfsense_backend/app/routes/agent_flags_route.py +++ b/surfsense_backend/app/routes/agent_flags_route.py @@ -59,7 +59,7 @@ class AgentFeatureFlagsRead(BaseModel): enable_otel: bool @classmethod - def from_flags(cls, flags: AgentFeatureFlags) -> "AgentFeatureFlagsRead": + def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead: # asdict() avoids missing-field bugs when AgentFeatureFlags grows. return cls(**asdict(flags)) diff --git a/surfsense_backend/app/routes/agent_permissions_route.py b/surfsense_backend/app/routes/agent_permissions_route.py index e87af29c7..1c76e00e6 100644 --- a/surfsense_backend/app/routes/agent_permissions_route.py +++ b/surfsense_backend/app/routes/agent_permissions_route.py @@ -210,7 +210,7 @@ async def create_rule( session.add(row) try: await session.commit() - except IntegrityError: + except IntegrityError as err: await session.rollback() raise HTTPException( status_code=409, @@ -218,7 +218,7 @@ async def create_rule( "An identical rule already exists for this scope. Update the " "existing rule instead." ), - ) + ) from err await session.refresh(row) return _to_read(row) @@ -248,12 +248,12 @@ async def update_rule( try: await session.commit() - except IntegrityError: + except IntegrityError as err: await session.rollback() raise HTTPException( status_code=409, detail="Update would create a duplicate rule for this scope.", - ) + ) from err await session.refresh(row) return _to_read(row) diff --git a/surfsense_backend/app/routes/agent_revert_route.py b/surfsense_backend/app/routes/agent_revert_route.py index 2f6fe6a32..cbe4e7417 100644 --- a/surfsense_backend/app/routes/agent_revert_route.py +++ b/surfsense_backend/app/routes/agent_revert_route.py @@ -97,10 +97,12 @@ async def revert_agent_action( action=action, requester_user_id=str(user.id) if user is not None else None, ) - except Exception: + except Exception as err: logger.exception("Revert dispatch raised for action_id=%s", action_id) await session.rollback() - raise HTTPException(status_code=500, detail="Internal error during revert.") + raise HTTPException( + status_code=500, detail="Internal error during revert." + ) from err if outcome.status == "ok": await session.commit() diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index cbc660222..b5560d90d 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1242,7 +1242,9 @@ async def handle_new_chat( await session.close() image_urls = ( - [p.as_data_url() for p in request.user_images] if request.user_images else None + [p.as_data_url() for p in request.user_images] + if request.user_images + else None ) return StreamingResponse( diff --git a/surfsense_backend/app/services/revert_service.py b/surfsense_backend/app/services/revert_service.py index e072f90c6..f3630e0b4 100644 --- a/surfsense_backend/app/services/revert_service.py +++ b/surfsense_backend/app/services/revert_service.py @@ -79,9 +79,7 @@ async def load_action( return result.scalars().first() -async def load_thread( - session: AsyncSession, *, thread_id: int -) -> NewChatThread | None: +async def load_thread(session: AsyncSession, *, thread_id: int) -> NewChatThread | None: stmt = select(NewChatThread).where(NewChatThread.id == thread_id) result = await session.execute(stmt) return result.scalars().first() diff --git a/surfsense_backend/app/utils/user_message_multimodal.py b/surfsense_backend/app/utils/user_message_multimodal.py index 1d0691697..dc9a6fe76 100644 --- a/surfsense_backend/app/utils/user_message_multimodal.py +++ b/surfsense_backend/app/utils/user_message_multimodal.py @@ -7,7 +7,9 @@ import binascii from typing import Any -def build_human_message_content(final_query: str, image_data_urls: list[str]) -> str | list[dict[str, Any]]: +def build_human_message_content( + final_query: str, image_data_urls: list[str] +) -> str | list[dict[str, Any]]: if not image_data_urls: return final_query parts: list[dict[str, Any]] = [{"type": "text", "text": final_query}] diff --git a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py index d08bbc8cf..aa0c215b9 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py +++ b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py @@ -90,9 +90,7 @@ class TestCompose: assert "" in prompt assert "[citation:chunk_id]" in prompt - def test_team_visibility_uses_team_variants( - self, fixed_today: datetime - ) -> None: + def test_team_visibility_uses_team_variants(self, fixed_today: datetime) -> None: prompt = compose_system_prompt( today=fixed_today, thread_visibility=ChatVisibility.SEARCH_SPACE, @@ -145,9 +143,7 @@ class TestCompose: assert "Generate Image" in prompt assert "Generate Podcast" in prompt - def test_mcp_routing_block_emits_when_provided( - self, fixed_today: datetime - ) -> None: + def test_mcp_routing_block_emits_when_provided(self, fixed_today: datetime) -> None: prompt = compose_system_prompt( today=fixed_today, mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]}, @@ -162,9 +158,7 @@ class TestCompose: prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={}) assert "" not in prompt - def test_provider_block_renders_when_anthropic( - self, fixed_today: datetime - ) -> None: + def test_provider_block_renders_when_anthropic(self, fixed_today: datetime) -> None: prompt = compose_system_prompt( today=fixed_today, model_name="anthropic:claude-3-5-sonnet" ) @@ -267,7 +261,10 @@ class TestStableOrderingForCacheStability: ) b = compose_system_prompt( today=fixed_today, - enabled_tool_names={"scrape_webpage", "web_search"}, # set order shouldn't matter + enabled_tool_names={ + "scrape_webpage", + "web_search", + }, # set order shouldn't matter mcp_connector_tools={"X": ["x_a", "x_b"]}, ) assert a == b diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py b/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py index 6834b5be7..aad1524c9 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py @@ -83,7 +83,11 @@ class TestActionLogMiddlewareDisabled: async def test_no_op_when_flag_off(self, patch_get_flags) -> None: mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) request = _FakeRequest( - tool_call={"name": "make_widget", "args": {"color": "red", "size": 1}, "id": "tc1"} + tool_call={ + "name": "make_widget", + "args": {"color": "red", "size": 1}, + "id": "tc1", + } ) handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1")) with patch_get_flags(_disabled_flags()): @@ -117,13 +121,12 @@ class TestActionLogMiddlewarePersistence: "id": "tc-abc", }, ) - result_msg = ToolMessage( - content="ok", tool_call_id="tc-abc", id="msg-1" - ) + result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1") handler = AsyncMock(return_value=result_msg) - with patch_get_flags(_enabled_flags()), patch( - "app.db.shielded_async_session", side_effect=lambda: factory() + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), ): result = await mw.awrap_tool_call(request, handler) @@ -151,9 +154,11 @@ class TestActionLogMiddlewarePersistence: ) handler = AsyncMock(side_effect=ValueError("boom")) - with patch_get_flags(_enabled_flags()), patch( - "app.db.shielded_async_session", side_effect=lambda: factory() - ), pytest.raises(ValueError, match="boom"): + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), + pytest.raises(ValueError, match="boom"), + ): await mw.awrap_tool_call(request, handler) assert len(captured["rows"]) == 1 @@ -177,8 +182,9 @@ class TestActionLogMiddlewarePersistence: def _exploding_session(): raise RuntimeError("DB is down") - with patch_get_flags(_enabled_flags()), patch( - "app.db.shielded_async_session", side_effect=_exploding_session + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=_exploding_session), ): result = await mw.awrap_tool_call(request, handler) assert result is result_msg @@ -218,8 +224,9 @@ class TestReverseDescriptor: ) handler = AsyncMock(return_value=result_msg) - with patch_get_flags(_enabled_flags()), patch( - "app.db.shielded_async_session", side_effect=lambda: factory() + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), ): await mw.awrap_tool_call(request, handler) @@ -257,8 +264,9 @@ class TestReverseDescriptor: result_msg = ToolMessage(content="ok", tool_call_id="tc1") handler = AsyncMock(return_value=result_msg) - with patch_get_flags(_enabled_flags()), patch( - "app.db.shielded_async_session", side_effect=lambda: factory() + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), ): await mw.awrap_tool_call(request, handler) @@ -275,11 +283,10 @@ class TestReverseDescriptor: request = _FakeRequest( tool_call={"name": "unknown_tool", "args": {}, "id": "tc1"} ) - handler = AsyncMock( - return_value=ToolMessage(content="ok", tool_call_id="tc1") - ) - with patch_get_flags(_enabled_flags()), patch( - "app.db.shielded_async_session", side_effect=lambda: factory() + handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1")) + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), ): await mw.awrap_tool_call(request, handler) row = captured["rows"][0] @@ -298,11 +305,10 @@ class TestArgsTruncation: request = _FakeRequest( tool_call={"name": "make_widget", "args": {"blob": huge}, "id": "tc1"}, ) - handler = AsyncMock( - return_value=ToolMessage(content="ok", tool_call_id="tc1") - ) - with patch_get_flags(_enabled_flags()), patch( - "app.db.shielded_async_session", side_effect=lambda: factory() + handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1")) + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), ): await mw.awrap_tool_call(request, handler) row = captured["rows"][0] diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_compaction.py b/surfsense_backend/tests/unit/agents/new_chat/test_compaction.py index 4d8d6805c..c6d4cc452 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_compaction.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_compaction.py @@ -26,10 +26,16 @@ class TestIsProtectedSystemMessage: assert _is_protected_system_message(msg) is True def test_unprotected_system_message(self) -> None: - assert _is_protected_system_message(SystemMessage(content="random instructions")) is False + assert ( + _is_protected_system_message(SystemMessage(content="random instructions")) + is False + ) def test_human_message_never_protected(self) -> None: - assert _is_protected_system_message(HumanMessage(content="...")) is False + assert ( + _is_protected_system_message(HumanMessage(content="...")) + is False + ) def test_tolerates_leading_whitespace(self) -> None: msg = SystemMessage(content=" \n\n...") @@ -97,11 +103,17 @@ class TestPartitionMessages: assert protected not in to_summary assert protected in preserved # The non-protected old messages remain in to_summary - assert any(isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary) + assert any( + isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary + ) def test_unprotected_messages_unaffected(self) -> None: partitioner = self._build_partitioner() - msgs = [HumanMessage(content="a"), HumanMessage(content="b"), HumanMessage(content="c")] + msgs = [ + HumanMessage(content="a"), + HumanMessage(content="b"), + HumanMessage(content="c"), + ] to_summary, preserved = partitioner._partition_messages(msgs, 2) assert [m.content for m in to_summary] == ["a", "b"] assert [m.content for m in preserved] == ["c"] diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py b/surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py index 3c31155d4..ba2246413 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py @@ -70,7 +70,8 @@ class TestSpillEdit: # Earlier ToolMessages should now contain the placeholder text cleared = [ - m for m in tool_messages + m + for m in tool_messages if isinstance(m.content, str) and m.content.startswith("[cleared") ] assert len(cleared) >= 1 diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py b/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py index 95017d744..e04f50815 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py @@ -46,9 +46,21 @@ def test_callable_dedup_key_takes_priority() -> None: state = { "messages": [ _msg( - {"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "1"}, - {"name": "create_doc", "args": {"parent_id": "x", "title": "y"}, "id": "2"}, - {"name": "create_doc", "args": {"parent_id": "x", "title": "z"}, "id": "3"}, + { + "name": "create_doc", + "args": {"parent_id": "x", "title": "y"}, + "id": "1", + }, + { + "name": "create_doc", + "args": {"parent_id": "x", "title": "y"}, + "id": "2", + }, + { + "name": "create_doc", + "args": {"parent_id": "x", "title": "z"}, + "id": "3", + }, ) ] } diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py b/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py index d49edbfec..ac6b5d95c 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py @@ -84,9 +84,7 @@ class TestConnectorDenyOverridesDefaultAllow: Rule(permission="linear_create_issue", pattern="*", action="deny") ] ) - rules = evaluate_many( - "linear_create_issue", ["linear_create_issue"], *rulesets - ) + rules = evaluate_many("linear_create_issue", ["linear_create_issue"], *rulesets) assert aggregate_action(rules) == "deny" def test_default_allow_still_applies_to_other_tools(self) -> None: @@ -124,5 +122,7 @@ class TestUserRuleOverridesDefault: rules=[Rule(permission="send_*", pattern="*", action="deny")], origin="user", ) - rules = evaluate_many("send_gmail_email", ["send_gmail_email"], defaults, user_ruleset) + rules = evaluate_many( + "send_gmail_email", ["send_gmail_email"], defaults, user_ruleset + ) assert aggregate_action(rules) == "deny" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py b/surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py index c54163dc3..802129bf6 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py @@ -64,22 +64,17 @@ def test_threshold_triggers_after_n_identical_calls() -> None: runtime, ) name = type(excinfo.value).__name__.lower() - assert ( - "interrupt" in name - or "runtimeerror" in name - ), f"Expected an interrupt-style exception, got {name}" + assert "interrupt" in name or "runtimeerror" in name, ( + f"Expected an interrupt-style exception, got {name}" + ) def test_does_not_trigger_when_args_differ() -> None: mw = DoomLoopMiddleware(threshold=2) runtime = _FakeRuntime() - out = mw.after_model( - {"messages": [_msg_calling("repeat", {"x": 1}, "1")]}, runtime - ) + out = mw.after_model({"messages": [_msg_calling("repeat", {"x": 1}, "1")]}, runtime) assert out is None - out = mw.after_model( - {"messages": [_msg_calling("repeat", {"x": 2}, "2")]}, runtime - ) + out = mw.after_model({"messages": [_msg_calling("repeat", {"x": 2}, "2")]}, runtime) assert out is None diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py b/surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py index 8555eea76..346271f4b 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py @@ -91,7 +91,9 @@ class TestShouldInject: mw = NoopInjectionMiddleware() req = _FakeRequest( tools=[object()], - messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])], + messages=[ + AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]) + ], model=_LiteLLMModel(), ) assert mw._should_inject(req) is False @@ -109,7 +111,9 @@ class TestShouldInject: mw = NoopInjectionMiddleware() req = _FakeRequest( tools=[], - messages=[AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}])], + messages=[ + AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]) + ], model=_OpenAIModel(), ) assert mw._should_inject(req) is False diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py b/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py index 194a6eb27..a997c8d61 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py @@ -111,6 +111,4 @@ class TestAsk: assert out is None # call kept # Runtime ruleset got the always-allow rule new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"] - assert any( - r.permission == "send_email" for r in new_rules - ) + assert any(r.permission == "send_email" for r in new_rules) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py b/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py index 8d98e1328..c2118c697 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py @@ -69,7 +69,9 @@ class TestPluginLoaderBasics: "app.agents.new_chat.plugin_loader.entry_points", return_value=[ep], ): - result = load_plugin_middlewares(_ctx(), allowed_plugin_names=["allowed_only"]) + result = load_plugin_middlewares( + _ctx(), allowed_plugin_names=["allowed_only"] + ) assert result == [] assert not called @@ -135,9 +137,7 @@ class TestPluginLoaderIsolation: _FakeEntryPoint("crashing", crashing_factory), _FakeEntryPoint("ok", year_substituter_factory), ] - with patch( - "app.agents.new_chat.plugin_loader.entry_points", return_value=eps - ): + with patch("app.agents.new_chat.plugin_loader.entry_points", return_value=eps): result = load_plugin_middlewares( _ctx(), allowed_plugin_names={"crashing", "ok"} ) @@ -151,9 +151,7 @@ class TestAllowlistEnv: assert load_allowed_plugin_names_from_env() == set() def test_parses_comma_separated_value(self, monkeypatch) -> None: - monkeypatch.setenv( - "SURFSENSE_ALLOWED_PLUGINS", " year_substituter , noisy , " - ) + monkeypatch.setenv("SURFSENSE_ALLOWED_PLUGINS", " year_substituter , noisy , ") assert load_allowed_plugin_names_from_env() == { "year_substituter", "noisy", diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py b/surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py index 39dd9bf00..d23fd693b 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py @@ -18,7 +18,7 @@ class _FakeResponse: self.headers = headers -class _FakeRateLimit(Exception): +class _FakeRateLimitError(Exception): def __init__(self, msg: str, headers: dict[str, str] | None = None) -> None: super().__init__(msg) if headers is not None: @@ -27,15 +27,15 @@ class _FakeRateLimit(Exception): class TestExtractRetryAfter: def test_seconds_header(self) -> None: - exc = _FakeRateLimit("rate", {"Retry-After": "30"}) + exc = _FakeRateLimitError("rate", {"Retry-After": "30"}) assert _extract_retry_after_seconds(exc) == 30.0 def test_milliseconds_header_overrides_seconds(self) -> None: - exc = _FakeRateLimit("rate", {"retry-after-ms": "1500"}) + exc = _FakeRateLimitError("rate", {"retry-after-ms": "1500"}) assert _extract_retry_after_seconds(exc) == 1.5 def test_case_insensitive(self) -> None: - exc = _FakeRateLimit("rate", {"RETRY-AFTER": "12"}) + exc = _FakeRateLimitError("rate", {"RETRY-AFTER": "12"}) assert _extract_retry_after_seconds(exc) == 12.0 def test_falls_back_to_message_regex(self) -> None: @@ -67,7 +67,7 @@ class TestIsNonRetryable: class TestDelayCalculation: def test_takes_max_of_backoff_and_header(self) -> None: mw = RetryAfterMiddleware(max_retries=3, initial_delay=1.0, jitter=False) - exc = _FakeRateLimit("rl", {"retry-after": "10"}) + exc = _FakeRateLimitError("rl", {"retry-after": "10"}) delay = mw._delay_for_attempt(0, exc) assert delay == pytest.approx(10.0) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py b/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py index 3819b4605..0adb578ce 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py @@ -122,7 +122,9 @@ class TestExploreSubagent: def test_includes_permission_middleware_with_deny_rules(self) -> None: spec = build_explore_subagent(tools=ALL_TOOLS) permission_mws = [ - m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index] + m + for m in spec["middleware"] + if isinstance(m, PermissionMiddleware) # type: ignore[index] ] assert len(permission_mws) == 1 ruleset = permission_mws[0]._static_rulesets[0] @@ -164,7 +166,9 @@ class TestReportWriterSubagent: def test_deny_rules_block_writes_but_allow_generate_report(self) -> None: spec = build_report_writer_subagent(tools=ALL_TOOLS) permission_mws = [ - m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index] + m + for m in spec["middleware"] + if isinstance(m, PermissionMiddleware) # type: ignore[index] ] ruleset = permission_mws[0]._static_rulesets[0] deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"} @@ -194,17 +198,15 @@ class TestConnectorNegotiatorSubagent: def test_deny_ruleset_blocks_mutating_connector_tools(self) -> None: spec = build_connector_negotiator_subagent(tools=ALL_TOOLS) permission_mws = [ - m for m in spec["middleware"] if isinstance(m, PermissionMiddleware) # type: ignore[index] + m + for m in spec["middleware"] + if isinstance(m, PermissionMiddleware) # type: ignore[index] ] ruleset = permission_mws[0]._static_rulesets[0] deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"} # `linear_create_issue` matches the `*_create` deny pattern. - assert any( - _wildcard_matches(p, "linear_create_issue") for p in deny_patterns - ) - assert any( - _wildcard_matches(p, "slack_send_message") for p in deny_patterns - ) + assert any(_wildcard_matches(p, "linear_create_issue") for p in deny_patterns) + assert any(_wildcard_matches(p, "slack_send_message") for p in deny_patterns) class TestBuildSpecializedSubagents: @@ -242,8 +244,7 @@ class TestBuildSpecializedSubagents: # order: extra → custom → patch → dedup. sentinel_idx = mws.index(sentinel) perm_idx = next( - (i for i, m in enumerate(mws) - if isinstance(m, PermissionMiddleware)), + (i for i, m in enumerate(mws) if isinstance(m, PermissionMiddleware)), None, ) assert perm_idx is not None @@ -259,7 +260,9 @@ class TestFilterToolsWarningSuppression: from app.agents.new_chat.subagents.config import _filter_tools - with caplog.at_level(logging.INFO, logger="app.agents.new_chat.subagents.config"): + with caplog.at_level( + logging.INFO, logger="app.agents.new_chat.subagents.config" + ): # Allowed set asks for two registry tools (one present, one # not) plus a bunch of middleware-provided names. _filter_tools( @@ -275,9 +278,7 @@ class TestFilterToolsWarningSuppression: }, ) - warnings = [ - r.message for r in caplog.records if r.levelno >= logging.INFO - ] + warnings = [r.message for r in caplog.records if r.levelno >= logging.INFO] # Exactly one warning, and it should mention scrape_webpage but not # any middleware-provided name. Inspect the rendered "missing" # list (between the brackets) so we don't false-match substrings diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py b/surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py index f792aef60..e02a04774 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py @@ -27,9 +27,12 @@ class TestRepair: mw = ToolCallNameRepairMiddleware( registered_tool_names={"echo"}, fuzzy_match_threshold=None ) - msg = AIMessage(content="", tool_calls=[ - {"name": "echo", "args": {}, "id": "1"}, - ]) + msg = AIMessage( + content="", + tool_calls=[ + {"name": "echo", "args": {}, "id": "1"}, + ], + ) out = mw.after_model(_make_state(msg), _FakeRuntime()) assert out is None # no change @@ -37,9 +40,12 @@ class TestRepair: mw = ToolCallNameRepairMiddleware( registered_tool_names={"echo"}, fuzzy_match_threshold=None ) - msg = AIMessage(content="", tool_calls=[ - {"name": "Echo", "args": {"x": 1}, "id": "1"}, - ]) + msg = AIMessage( + content="", + tool_calls=[ + {"name": "Echo", "args": {"x": 1}, "id": "1"}, + ], + ) out = mw.after_model(_make_state(msg), _FakeRuntime()) assert out is not None repaired = out["messages"][0] @@ -50,9 +56,12 @@ class TestRepair: registered_tool_names={"echo", INVALID_TOOL_NAME}, fuzzy_match_threshold=None, ) - msg = AIMessage(content="", tool_calls=[ - {"name": "totally_different_name", "args": {"k": "v"}, "id": "1"}, - ]) + msg = AIMessage( + content="", + tool_calls=[ + {"name": "totally_different_name", "args": {"k": "v"}, "id": "1"}, + ], + ) out = mw.after_model(_make_state(msg), _FakeRuntime()) assert out is not None repaired_call = out["messages"][0].tool_calls[0] @@ -64,9 +73,12 @@ class TestRepair: mw = ToolCallNameRepairMiddleware( registered_tool_names={"echo"}, fuzzy_match_threshold=None ) - msg = AIMessage(content="", tool_calls=[ - {"name": "unknown", "args": {}, "id": "1"}, - ]) + msg = AIMessage( + content="", + tool_calls=[ + {"name": "unknown", "args": {}, "id": "1"}, + ], + ) out = mw.after_model(_make_state(msg), _FakeRuntime()) # No repair available; original returned unchanged (no update) assert out is None @@ -76,9 +88,12 @@ class TestRepair: registered_tool_names={"search_documents"}, fuzzy_match_threshold=0.7, ) - msg = AIMessage(content="", tool_calls=[ - {"name": "search_docments", "args": {}, "id": "1"}, - ]) + msg = AIMessage( + content="", + tool_calls=[ + {"name": "search_docments", "args": {}, "id": "1"}, + ], + ) out = mw.after_model(_make_state(msg), _FakeRuntime()) assert out is not None assert out["messages"][0].tool_calls[0]["name"] == "search_documents" @@ -94,9 +109,12 @@ class TestRepair: mw = ToolCallNameRepairMiddleware( registered_tool_names={"echo"}, fuzzy_match_threshold=None ) - msg = AIMessage(content="", tool_calls=[ - {"name": "DynamicTool", "args": {}, "id": "1"}, - ]) + msg = AIMessage( + content="", + tool_calls=[ + {"name": "DynamicTool", "args": {}, "id": "1"}, + ], + ) runtime = _FakeRuntime(SimpleNamespace(registered_tool_names=["dynamictool"])) out = mw.after_model(_make_state(msg), runtime) assert out is not None diff --git a/surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py b/surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py index 8b464d48d..ef95434bf 100644 --- a/surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py +++ b/surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py @@ -10,7 +10,7 @@ through :class:`KnowledgeBasePersistenceMiddleware` without losing the copy. from __future__ import annotations from typing import Any -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock import numpy as np import pytest diff --git a/surfsense_backend/tests/unit/services/test_revert_service.py b/surfsense_backend/tests/unit/services/test_revert_service.py index cb8443291..e2cbe383a 100644 --- a/surfsense_backend/tests/unit/services/test_revert_service.py +++ b/surfsense_backend/tests/unit/services/test_revert_service.py @@ -16,9 +16,7 @@ class _FakeAction: class TestCanRevert: def test_owner_can_revert_their_own_action(self) -> None: action = _FakeAction(user_id="user-123") - assert can_revert( - requester_user_id="user-123", action=action, is_admin=False - ) + assert can_revert(requester_user_id="user-123", action=action, is_admin=False) def test_other_user_cannot_revert(self) -> None: action = _FakeAction(user_id="user-123") @@ -28,21 +26,15 @@ class TestCanRevert: def test_admin_always_allowed(self) -> None: action = _FakeAction(user_id="user-123") - assert can_revert( - requester_user_id="anybody", action=action, is_admin=True - ) + assert can_revert(requester_user_id="anybody", action=action, is_admin=True) def test_admin_can_revert_anonymous_action(self) -> None: action = _FakeAction(user_id=None) - assert can_revert( - requester_user_id="admin", action=action, is_admin=True - ) + assert can_revert(requester_user_id="admin", action=action, is_admin=True) def test_anonymous_action_blocks_non_admin(self) -> None: action = _FakeAction(user_id=None) - assert not can_revert( - requester_user_id="user-1", action=action, is_admin=False - ) + assert not can_revert(requester_user_id="user-1", action=action, is_admin=False) def test_uuid_string_normalization(self) -> None: """``user_id`` may be a UUID object; comparison should still work.""" @@ -51,6 +43,4 @@ class TestCanRevert: u = uuid.uuid4() action = _FakeAction(user_id=u) # Same UUID, passed as string from the requesting side. - assert can_revert( - requester_user_id=str(u), action=action, is_admin=False - ) + assert can_revert(requester_user_id=str(u), action=action, is_admin=False) From f23be16b351da11130c01d6157cc0fc817cb51b9 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Tue, 28 Apr 2026 23:25:26 -0700 Subject: [PATCH 224/299] refactor: citation viewer --- surfsense_web/app/globals.css | 21 - .../atoms/citation/citation-panel.atom.ts | 40 ++ .../pending-chunk-highlight.atom.ts | 19 - .../atoms/layout/right-panel.atom.ts | 2 +- .../assistant-ui/inline-citation.tsx | 80 +--- .../citation-panel/citation-panel.tsx | 230 ++++++++++ .../components/editor-panel/editor-panel.tsx | 407 +----------------- .../components/editor/plate-editor.tsx | 30 +- surfsense_web/components/editor/presets.ts | 28 -- .../layout/ui/right-panel/RightPanel.tsx | 72 +++- .../components/ui/search-highlight-node.tsx | 45 -- surfsense_web/lib/citation-search.ts | 125 ------ surfsense_web/package.json | 1 - surfsense_web/pnpm-lock.yaml | 17 - 14 files changed, 362 insertions(+), 755 deletions(-) create mode 100644 surfsense_web/atoms/citation/citation-panel.atom.ts delete mode 100644 surfsense_web/atoms/document-viewer/pending-chunk-highlight.atom.ts create mode 100644 surfsense_web/components/citation-panel/citation-panel.tsx delete mode 100644 surfsense_web/components/ui/search-highlight-node.tsx delete mode 100644 surfsense_web/lib/citation-search.ts diff --git a/surfsense_web/app/globals.css b/surfsense_web/app/globals.css index f54bc2197..a37ddb8f3 100644 --- a/surfsense_web/app/globals.css +++ b/surfsense_web/app/globals.css @@ -210,27 +210,6 @@ button { } } -/* Citation-jump highlight — entrance pulse only. The `SearchHighlightLeaf` - (see components/ui/search-highlight-node.tsx) is otherwise statically - tinted; this animation runs once on mount to draw the eye to the cited - text after `scrollIntoView` lands. The highlight itself is permanent - until the user clicks inside the editor (or another dismissal trigger - fires in `EditorPanelContent`). */ -@keyframes citation-flash-in { - 0% { - background-color: transparent; - box-shadow: 0 0 0 0 transparent; - } - 40% { - background-color: color-mix(in oklab, var(--primary) 30%, transparent); - box-shadow: 0 0 0 3px color-mix(in oklab, var(--primary) 25%, transparent); - } - 100% { - background-color: color-mix(in oklab, var(--primary) 15%, transparent); - box-shadow: 0 0 0 1px color-mix(in oklab, var(--primary) 40%, transparent); - } -} - /* Human-in-the-loop approval card animations */ @keyframes pulse-subtle { 0%, diff --git a/surfsense_web/atoms/citation/citation-panel.atom.ts b/surfsense_web/atoms/citation/citation-panel.atom.ts new file mode 100644 index 000000000..ca7312857 --- /dev/null +++ b/surfsense_web/atoms/citation/citation-panel.atom.ts @@ -0,0 +1,40 @@ +import { atom } from "jotai"; +import { rightPanelCollapsedAtom, rightPanelTabAtom } from "@/atoms/layout/right-panel.atom"; + +interface CitationPanelState { + isOpen: boolean; + chunkId: number | null; +} + +const initialState: CitationPanelState = { + isOpen: false, + chunkId: null, +}; + +export const citationPanelAtom = atom(initialState); + +export const citationPanelOpenAtom = atom((get) => get(citationPanelAtom).isOpen); + +const preCitationCollapsedAtom = atom(null); + +export const openCitationPanelAtom = atom(null, (get, set, payload: { chunkId: number }) => { + if (!get(citationPanelAtom).isOpen) { + set(preCitationCollapsedAtom, get(rightPanelCollapsedAtom)); + } + set(citationPanelAtom, { + isOpen: true, + chunkId: payload.chunkId, + }); + set(rightPanelTabAtom, "citation"); + set(rightPanelCollapsedAtom, false); +}); + +export const closeCitationPanelAtom = atom(null, (get, set) => { + set(citationPanelAtom, initialState); + set(rightPanelTabAtom, "sources"); + const prev = get(preCitationCollapsedAtom); + if (prev !== null) { + set(rightPanelCollapsedAtom, prev); + set(preCitationCollapsedAtom, null); + } +}); diff --git a/surfsense_web/atoms/document-viewer/pending-chunk-highlight.atom.ts b/surfsense_web/atoms/document-viewer/pending-chunk-highlight.atom.ts deleted file mode 100644 index a3f8357e8..000000000 --- a/surfsense_web/atoms/document-viewer/pending-chunk-highlight.atom.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { atom } from "jotai"; - -/** - * Cross-component handoff for citation jumps. Set by `InlineCitation` when a - * numeric chunk badge is clicked (after the document has been resolved); read - * by `DocumentTabContent` once the matching document tab mounts so it can - * scroll to and softly highlight the cited chunk inside the rendered markdown. - * - * Cleared by `DocumentTabContent` only after a terminal state — exact / - * approximate / miss — has been reached, so that an escalation refetch (2MB - * preview → 16MB) keeps the pending intent alive across the re-render. - */ -export interface PendingChunkHighlight { - documentId: number; - chunkId: number; - chunkText: string; -} - -export const pendingChunkHighlightAtom = atom(null); diff --git a/surfsense_web/atoms/layout/right-panel.atom.ts b/surfsense_web/atoms/layout/right-panel.atom.ts index e06500113..d296587ed 100644 --- a/surfsense_web/atoms/layout/right-panel.atom.ts +++ b/surfsense_web/atoms/layout/right-panel.atom.ts @@ -1,6 +1,6 @@ import { atom } from "jotai"; -export type RightPanelTab = "sources" | "report" | "editor" | "hitl-edit"; +export type RightPanelTab = "sources" | "report" | "editor" | "hitl-edit" | "citation"; export const rightPanelTabAtom = atom("sources"); diff --git a/surfsense_web/components/assistant-ui/inline-citation.tsx b/surfsense_web/components/assistant-ui/inline-citation.tsx index ae8d434a8..2aeba89ca 100644 --- a/surfsense_web/components/assistant-ui/inline-citation.tsx +++ b/surfsense_web/components/assistant-ui/inline-citation.tsx @@ -1,13 +1,11 @@ "use client"; -import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { useQuery } from "@tanstack/react-query"; import { useSetAtom } from "jotai"; import { ExternalLink, FileText } from "lucide-react"; import type { FC } from "react"; import { useCallback, useEffect, useRef, useState } from "react"; -import { toast } from "sonner"; -import { pendingChunkHighlightAtom } from "@/atoms/document-viewer/pending-chunk-highlight.atom"; -import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; +import { openCitationPanelAtom } from "@/atoms/citation/citation-panel.atom"; import { useCitationMetadata } from "@/components/assistant-ui/citation-metadata-context"; import { MarkdownViewer } from "@/components/markdown-viewer"; import { Citation } from "@/components/tool-ui/citation"; @@ -29,11 +27,11 @@ const POPOVER_HOVER_CLOSE_DELAY_MS = 150; * Surfsense documentation chunks (`isDocsChunk`). Negative chunk IDs render as * a static "doc" pill (anonymous/synthetic uploads). * - * Numeric KB chunks: clicking resolves the parent document via - * `getDocumentByChunk`, opens the document in the right side panel (alongside - * the chat — does not replace it), and stages the cited chunk text in - * `pendingChunkHighlightAtom` so `EditorPanelContent` can scroll to and softly - * highlight it inside the rendered markdown. + * Numeric KB chunks: clicking opens the citation panel in the right + * sidebar (alongside the chat — does not replace it). The panel shows + * the cited chunk surrounded by adjacent chunks (via the API's + * `chunk_window`), with the cited one highlighted and an option to + * expand the window or jump into the full document via the editor panel. * * Surfsense docs chunks: rendered as a hover-controlled shadcn Popover that * lazily fetches and previews the cited chunk inline, since those docs aren't @@ -65,71 +63,17 @@ export const InlineCitation: FC = ({ chunkId, isDocsChunk = }; const NumericChunkCitation: FC<{ chunkId: number }> = ({ chunkId }) => { - const queryClient = useQueryClient(); - const setPendingHighlight = useSetAtom(pendingChunkHighlightAtom); - const openEditorPanel = useSetAtom(openEditorPanelAtom); - const [resolving, setResolving] = useState(false); - - const handleClick = useCallback(async () => { - if (resolving) return; - setResolving(true); - console.log("[citation:click] start", { chunkId }); - try { - const data = await queryClient.fetchQuery({ - // Local key with explicit window. The shared `cacheKeys.documents.byChunk` - // is window-agnostic (latent footgun); namespace the call to avoid - // reusing a different-window cached result. - queryKey: ["documents", "by-chunk", chunkId, "w0"] as const, - queryFn: () => - documentsApiService.getDocumentByChunk({ chunk_id: chunkId, chunk_window: 0 }), - staleTime: 5 * 60 * 1000, - }); - const cited = data.chunks.find((c) => c.id === chunkId) ?? data.chunks[0]; - console.log("[citation:click] fetched doc-by-chunk", { - docId: data.id, - docTitle: data.title, - chunksReturned: data.chunks.length, - citedChunkId: cited?.id, - citedChunkContentLen: cited?.content?.length ?? 0, - citedChunkPreview: - cited?.content && cited.content.length > 120 - ? `${cited.content.slice(0, 120)}…(+${cited.content.length - 120})` - : (cited?.content ?? ""), - }); - // Stage the highlight BEFORE opening the panel so `EditorPanelContent` - // already sees the pending intent on its very first render — avoids a - // "fetch → render → no-pending → next-tick render with pending" race. - setPendingHighlight({ - documentId: data.id, - chunkId, - chunkText: cited?.content ?? "", - }); - openEditorPanel({ - documentId: data.id, - searchSpaceId: data.search_space_id, - title: data.title, - }); - console.log("[citation:click] staged highlight + opened editor panel", { - documentId: data.id, - }); - } catch (err) { - console.warn("[citation:click] failed", err); - toast.error(err instanceof Error ? err.message : "Couldn't open cited document"); - } finally { - setResolving(false); - } - }, [chunkId, openEditorPanel, queryClient, resolving, setPendingHighlight]); + const openCitationPanel = useSetAtom(openCitationPanelAtom); return ( ); }; diff --git a/surfsense_web/components/citation-panel/citation-panel.tsx b/surfsense_web/components/citation-panel/citation-panel.tsx new file mode 100644 index 000000000..cec07b9cf --- /dev/null +++ b/surfsense_web/components/citation-panel/citation-panel.tsx @@ -0,0 +1,230 @@ +"use client"; + +import { useQuery } from "@tanstack/react-query"; +import { useSetAtom } from "jotai"; +import { ChevronDown, ChevronUp, ExternalLink, XIcon } from "lucide-react"; +import type { FC } from "react"; +import { useEffect, useMemo, useRef, useState } from "react"; +import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; +import { MarkdownViewer } from "@/components/markdown-viewer"; +import { Button } from "@/components/ui/button"; +import { Spinner } from "@/components/ui/spinner"; +import { documentsApiService } from "@/lib/apis/documents-api.service"; + +const DEFAULT_CHUNK_WINDOW = 5; +const EXPANDED_CHUNK_WINDOW = 50; + +interface CitationPanelContentProps { + chunkId: number; + onClose?: () => void; +} + +/** + * Right-panel citation viewer. Shows the cited chunk surrounded by + * adjacent chunks (±N chunks via the API's `chunk_window` parameter), + * with the cited one visually highlighted and auto-scrolled into view. + * The window can be expanded to a wider range, or the user can jump to + * the full document via the editor panel. + */ +export const CitationPanelContent: FC = ({ chunkId, onClose }) => { + const openEditorPanel = useSetAtom(openEditorPanelAtom); + const [expanded, setExpanded] = useState(false); + + useEffect(() => { + setExpanded(false); + }, []); + + const chunkWindow = expanded ? EXPANDED_CHUNK_WINDOW : DEFAULT_CHUNK_WINDOW; + + const { data, isLoading, error } = useQuery({ + queryKey: ["citation-panel", chunkId, chunkWindow] as const, + queryFn: () => + documentsApiService.getDocumentByChunk({ + chunk_id: chunkId, + chunk_window: chunkWindow, + }), + staleTime: 5 * 60 * 1000, + }); + + const cited = useMemo(() => data?.chunks.find((c) => c.id === chunkId) ?? null, [data, chunkId]); + + const totalChunks = data?.total_chunks ?? data?.chunks.length ?? 0; + const startIndex = data?.chunk_start_index ?? 0; + const citedIndexInWindow = data + ? Math.max( + 0, + data.chunks.findIndex((c) => c.id === chunkId) + ) + : 0; + const shownAbove = citedIndexInWindow; + const shownBelow = data ? Math.max(0, data.chunks.length - 1 - citedIndexInWindow) : 0; + const hasMoreAbove = startIndex > 0; + const hasMoreBelow = data ? startIndex + data.chunks.length < totalChunks : false; + + // Scroll the cited chunk into view inside the panel's scroll container + // (not the page). We anchor the scroll to the panel's scroll element + // so opening the citation doesn't yank the chat scroll on the left. + const scrollContainerRef = useRef(null); + const citedRef = useRef(null); + useEffect(() => { + if (!cited) return; + const id = requestAnimationFrame(() => { + const container = scrollContainerRef.current; + const target = citedRef.current; + if (!container || !target) return; + const containerRect = container.getBoundingClientRect(); + const targetRect = target.getBoundingClientRect(); + const offset = targetRect.top - containerRect.top + container.scrollTop; + container.scrollTo({ + top: Math.max(0, offset - 16), + behavior: "smooth", + }); + }); + return () => cancelAnimationFrame(id); + }, [cited]); + + const handleOpenFullDocument = () => { + if (!data) return; + openEditorPanel({ + documentId: data.id, + searchSpaceId: data.search_space_id, + title: data.title, + }); + }; + + return ( + <> +
+
+

Citation

+
+ {onClose && ( + + )} +
+
+
+
+

+ {data?.title ?? (isLoading ? "Loading…" : `Chunk #${chunkId}`)} +

+
+
+ Chunk #{chunkId} + {totalChunks > 0 && · {totalChunks} chunks} +
+
+
+ +
+ {isLoading && ( +
+ + Loading citation… +
+ )} + + {error && ( +

+ {error instanceof Error ? error.message : "Failed to load citation"} +

+ )} + + {!isLoading && !error && data && ( + <> + {hasMoreAbove && ( +

+ … {startIndex} earlier chunk{startIndex === 1 ? "" : "s"} not shown +

+ )} +
+ {data.chunks.map((chunk) => { + const isCited = chunk.id === chunkId; + return ( +
+
+ + {isCited ? "Cited chunk" : `Chunk #${chunk.id}`} + + {isCited && ( + #{chunk.id} + )} +
+
+ +
+
+ ); + })} +
+ {hasMoreBelow && ( +

+ … {totalChunks - (startIndex + data.chunks.length)} later chunk + {totalChunks - (startIndex + data.chunks.length) === 1 ? "" : "s"} not shown +

+ )} + + )} +
+ + {!isLoading && !error && data && ( +
+
+ Showing {shownAbove} above · cited · {shownBelow} below +
+
+ {(hasMoreAbove || hasMoreBelow) && !expanded && ( + + )} + {expanded && ( + + )} + +
+
+ )} + + ); +}; diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 0c4e9485b..df138e97e 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -1,6 +1,5 @@ "use client"; -import { FindReplacePlugin } from "@platejs/find-replace"; import { useAtomValue, useSetAtom } from "jotai"; import { Check, @@ -15,21 +14,17 @@ import { import dynamic from "next/dynamic"; import { useCallback, useEffect, useRef, useState } from "react"; import { toast } from "sonner"; -import { pendingChunkHighlightAtom } from "@/atoms/document-viewer/pending-chunk-highlight.atom"; import { closeEditorPanelAtom, editorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { VersionHistoryButton } from "@/components/documents/version-history"; -import type { PlateEditorInstance } from "@/components/editor/plate-editor"; import { SourceCodeEditor } from "@/components/editor/source-code-editor"; import { MarkdownViewer } from "@/components/markdown-viewer"; import { Alert, AlertDescription } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer"; -import { CITATION_HIGHLIGHT_CLASS } from "@/components/ui/search-highlight-node"; import { Spinner } from "@/components/ui/spinner"; import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI } from "@/hooks/use-platform"; import { authenticatedFetch, getBearerToken, redirectToLogin } from "@/lib/auth-utils"; -import { buildCitationSearchCandidates } from "@/lib/citation-search"; import { inferMonacoLanguageFromPath } from "@/lib/editor-language"; const PlateEditor = dynamic( @@ -37,10 +32,7 @@ const PlateEditor = dynamic( { ssr: false, loading: () => } ); -type CitationHighlightStatus = "exact" | "miss"; - const LARGE_DOCUMENT_THRESHOLD = 2 * 1024 * 1024; // 2MB -const CITATION_MAX_LENGTH = 16 * 1024 * 1024; // 16MB on-demand cap for citation jumps interface EditorContent { document_id: number; @@ -145,60 +137,6 @@ export function EditorPanelContent({ const isLocalFileMode = kind === "local_file"; const editorRenderMode: EditorRenderMode = isLocalFileMode ? "source_code" : "rich_markdown"; - // --- Citation-jump highlight wiring ---------------------------------- - // `EditorPanelContent` is the consumer of `pendingChunkHighlightAtom`: when - // a citation badge is clicked, the badge stages `{documentId, chunkId, - // chunkText}` and opens this panel. We drive Plate's `FindReplacePlugin` - // (registered in every preset) to highlight the cited text natively via - // Slate decorations — no DOM walking, no Range gymnastics. The state - // machine below escalates the document fetch from 2MB → 16MB once if no - // candidate snippet matched in the preview, and surfaces miss outcomes - // via an inline alert. - const pending = useAtomValue(pendingChunkHighlightAtom); - const setPendingHighlight = useSetAtom(pendingChunkHighlightAtom); - const [fetchKey, setFetchKey] = useState(0); - const [maxLengthOverride, setMaxLengthOverride] = useState(null); - const [highlightResult, setHighlightResult] = useState(null); - const editorRef = useRef(null); - const escalatedForRef = useRef(null); - const lastAppliedChunkIdRef = useRef(null); - // Tracks whether a citation highlight is currently decorated in the - // editor. We use a ref (not state) because the click-to-dismiss handler - // runs in a stable callback that would otherwise close over stale state. - const isHighlightActiveRef = useRef(false); - // Once a citation jump targets this doc we have to keep `PlateEditor` - // mounted for the *rest of the doc session* — even after the highlight - // effect clears `pendingChunkHighlightAtom` (which it does as soon as - // the decoration is applied, so a follow-up citation on the same chunk - // can re-trigger). Without this latch, non-editable docs would re-render - // back into `MarkdownViewer` the instant `pending` is released, tearing - // down the Plate decorations and dropping the highlight after a frame. - const [stickyPlateMode, setStickyPlateMode] = useState(false); - - const clearCitationSearch = useCallback(() => { - isHighlightActiveRef.current = false; - const editor = editorRef.current; - if (!editor) return; - try { - editor.setOption(FindReplacePlugin, "search", ""); - editor.api.redecorate(); - } catch (err) { - console.warn("[EditorPanelContent] clearCitationSearch failed:", err); - } - }, []); - - // Dismiss the highlight when the user interacts with the editor surface. - // `onPointerDown` fires before focus / selection changes so the click - // itself feels responsive — the highlight clears in the same event tick - // that places the cursor. No-op when nothing is highlighted, so we don't - // thrash `redecorate` on every click in normal editing. - const handleEditorPointerDown = useCallback(() => { - if (!isHighlightActiveRef.current) return; - clearCitationSearch(); - setHighlightResult(null); - }, [clearCitationSearch]); - - const isCitationTarget = !!pending && !isLocalFileMode && pending.documentId === documentId; const resolveLocalVirtualPath = useCallback( async (candidatePath: string): Promise => { if (!electronAPI?.getAgentFilesystemMounts) { @@ -218,8 +156,6 @@ export function EditorPanelContent({ const isLargeDocument = (editorDoc?.content_size_bytes ?? 0) > LARGE_DOCUMENT_THRESHOLD; - // `fetchKey` is an explicit re-fetch trigger (escalation bumps it to force - // a new request even when documentId/searchSpaceId haven't changed). useEffect(() => { const controller = new AbortController(); setIsLoading(true); @@ -231,12 +167,6 @@ export function EditorPanelContent({ setIsEditing(false); initialLoadDone.current = false; changeCountRef.current = 0; - // Clear any in-flight FindReplacePlugin search before the editor - // re-mounts on new content (a fresh editor key is generated below - // from documentId + isEditing, so the previous editor + its - // decorations are about to be discarded anyway, but we belt-and- - // brace here for the case where only `fetchKey` changed). - clearCitationSearch(); const doFetch = async () => { try { @@ -281,11 +211,7 @@ export function EditorPanelContent({ const url = new URL( `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content` ); - url.searchParams.set("max_length", String(maxLengthOverride ?? LARGE_DOCUMENT_THRESHOLD)); - // `fetchKey` participates here so biome's noUnusedVariables sees it - // as consumed; bumping it forces a fresh request even when the URL - // is otherwise identical. - if (fetchKey > 0) url.searchParams.set("_n", String(fetchKey)); + url.searchParams.set("max_length", String(LARGE_DOCUMENT_THRESHOLD)); const response = await authenticatedFetch(url.toString(), { method: "GET" }); @@ -331,259 +257,8 @@ export function EditorPanelContent({ resolveLocalVirtualPath, searchSpaceId, title, - fetchKey, - maxLengthOverride, - clearCitationSearch, ]); - // Reset citation-jump bookkeeping whenever the panel switches to a different - // document (or local file). Body only writes setters — the deps are the - // real triggers we want to react to. - // biome-ignore lint/correctness/useExhaustiveDependencies: documentId/localFilePath are intentional triggers. - useEffect(() => { - clearCitationSearch(); - escalatedForRef.current = null; - lastAppliedChunkIdRef.current = null; - setHighlightResult(null); - setMaxLengthOverride(null); - setFetchKey(0); - // Drop sticky Plate mode when the panel moves to a different doc - // — the next doc starts in its preferred render mode (Plate for - // editable, MarkdownViewer for everything else) until/unless a - // citation jump targets it. - setStickyPlateMode(false); - }, [documentId, localFilePath, clearCitationSearch]); - - // Latch sticky Plate mode the first time a citation jump targets this - // doc. We keep it sticky for the remainder of this doc session so the - // highlight effect's `setPendingHighlight(null)` doesn't unmount the - // editor mid-flight (see comment on `stickyPlateMode` declaration). - useEffect(() => { - if (isCitationTarget) setStickyPlateMode(true); - }, [isCitationTarget]); - - // `isEditorReady` is what `useEffect` actually depends on — `editorRef` - // is a ref so changes don't trigger re-runs. We flip this to `true` once - // `PlateEditor` calls back with its live editor instance (its - // `usePlateEditor` value-init runs synchronously, so by the time this - // flips true the markdown is already deserialized into the Slate tree). - const [isEditorReady, setIsEditorReady] = useState(false); - const handleEditorReady = useCallback((editor: PlateEditorInstance | null) => { - console.log("[citation:editor] handleEditorReady", { ready: !!editor }); - editorRef.current = editor; - setIsEditorReady(!!editor); - }, []); - - // --- Citation jump highlight effect ----------------------------------- - // Drives Plate's FindReplacePlugin to highlight the cited chunk: - // 1. Build candidate snippets from the chunk text (first sentence, - // first 8 words, full chunk if short). Plate's decorate runs per- - // block and won't cross block boundaries, so the shorter - // candidates exist to give us something that fits in one - // paragraph / heading. - // 2. For each candidate: setOption('search', ...) → redecorate → - // wait two animation frames for React to flush → query the editor - // DOM for `.${CITATION_HIGHLIGHT_CLASS}`. First hit wins. - // - // Why a className and not a `data-*` attribute? Plate's - // `PlateLeaf` runs its props through `useNodeAttributes`, which - // only forwards `attributes`, `className`, `ref`, and `style` — - // arbitrary `data-*` attributes are silently dropped. `className` - // is the only escape hatch guaranteed to survive into the DOM. - // 3. On hit: smooth-scroll the first match into view, mark the - // highlight active (so a click inside the editor can dismiss it), - // release the pending atom. - // 4. On terminal miss: if the doc was truncated and we haven't - // escalated yet, bump the fetch's `max_length` to the citation - // cap and re-fetch — the post-refetch render will re-run this - // effect against the larger preview. Otherwise, release the - // atom and show the miss alert. - useEffect(() => { - console.log("[citation:effect] fired", { - isCitationTarget, - pendingDocId: pending?.documentId, - pendingChunkId: pending?.chunkId, - pendingChunkTextLen: pending?.chunkText?.length, - documentId, - isLocalFileMode, - isEditing, - hasMarkdown: !!editorDoc?.source_markdown, - markdownLen: editorDoc?.source_markdown?.length, - truncated: editorDoc?.truncated, - isEditorReady, - editorRefSet: !!editorRef.current, - maxLengthOverride, - }); - if (!isCitationTarget || !pending) { - console.log("[citation:effect] guard ✗ no citation target / no pending"); - return; - } - if (isLocalFileMode || isEditing) { - console.log("[citation:effect] guard ✗ localFileMode/editing"); - return; - } - if (!editorDoc?.source_markdown) { - console.log("[citation:effect] guard ✗ source_markdown not ready"); - return; - } - if (!isEditorReady) { - console.log("[citation:effect] guard ✗ editor not ready yet"); - return; - } - const editor = editorRef.current; - if (!editor) { - console.log("[citation:effect] guard ✗ editorRef.current is null"); - return; - } - - if (lastAppliedChunkIdRef.current !== pending.chunkId) { - lastAppliedChunkIdRef.current = pending.chunkId; - } - - let cancelled = false; - - const finishMiss = () => { - console.log("[citation:effect] terminal miss — no candidate matched"); - try { - editor.setOption(FindReplacePlugin, "search", ""); - editor.api.redecorate(); - } catch (err) { - console.warn("[EditorPanelContent] reset search after miss failed:", err); - } - const canEscalate = - editorDoc.truncated === true && - (maxLengthOverride ?? LARGE_DOCUMENT_THRESHOLD) < CITATION_MAX_LENGTH && - escalatedForRef.current !== pending.chunkId; - console.log("[citation:effect] miss decision", { - truncated: editorDoc.truncated, - currentMaxLength: maxLengthOverride ?? LARGE_DOCUMENT_THRESHOLD, - canEscalate, - }); - if (canEscalate) { - escalatedForRef.current = pending.chunkId; - setMaxLengthOverride(CITATION_MAX_LENGTH); - setFetchKey((k) => k + 1); - // Keep the atom set so the post-refetch render re-runs. - return; - } - setHighlightResult("miss"); - setPendingHighlight(null); - }; - - const tryCandidates = async () => { - const candidates = buildCitationSearchCandidates(pending.chunkText); - console.log("[citation:effect] candidates built", { - count: candidates.length, - previews: candidates.map((c) => c.slice(0, 60)), - }); - if (candidates.length === 0) { - if (!cancelled) finishMiss(); - return; - } - // Resolve the editor's rendered DOM root via Slate's stable - // `[data-slate-editor="true"]` attribute (set by slate-react's - // ``). Scoping queries to this root prevents - // `` elements rendered elsewhere on the page (e.g. chat - // search-highlight leaves in another mounted PlateEditor) from - // being mistaken for citation hits. - const editorRoot = document.querySelector('[data-slate-editor="true"]'); - console.log("[citation:effect] editor root", { - hasRoot: !!editorRoot, - }); - const root: ParentNode = editorRoot ?? document; - - for (let i = 0; i < candidates.length; i++) { - const candidate = candidates[i]; - if (cancelled) return; - try { - editor.setOption(FindReplacePlugin, "search", candidate); - editor.api.redecorate(); - console.log(`[citation:effect] try #${i} setOption + redecorate`, { - len: candidate.length, - preview: candidate.slice(0, 80), - }); - } catch (err) { - console.warn("[EditorPanelContent] setOption/redecorate failed:", err); - continue; - } - // Two rAFs: first lets Slate flush its onChange, second lets - // React commit the decoration leaves into the DOM. - await new Promise((resolve) => - requestAnimationFrame(() => requestAnimationFrame(() => resolve())) - ); - if (cancelled) return; - // Primary probe: by our stable class on the rendered . - let el = root.querySelector(`.${CITATION_HIGHLIGHT_CLASS}`); - const classMarkCount = root.querySelectorAll(`.${CITATION_HIGHLIGHT_CLASS}`).length; - // Diagnostic fallback: any inside the editor root. - // If we ever see allMarks > 0 but classMarkCount === 0, - // the className was stripped again and we need to revisit - // `useNodeAttributes` filtering. - const allMarkCount = root.querySelectorAll("mark").length; - if (!el && allMarkCount > 0) { - el = root.querySelector("mark"); - } - console.log(`[citation:effect] try #${i} DOM probe`, { - foundEl: !!el, - classMarkCount, - allMarkCount, - usedFallback: !!el && classMarkCount === 0, - }); - if (el) { - try { - el.scrollIntoView({ block: "center", behavior: "smooth" }); - } catch { - el.scrollIntoView(); - } - isHighlightActiveRef.current = true; - setHighlightResult("exact"); - console.log(`[citation:effect] ✓ exact via candidate #${i} — atom released`); - // No auto-clear timer — the highlight is intentionally - // permanent until the user clicks inside the editor (see - // `handleEditorPointerDown`) or another dismissal trigger - // fires (doc switch, edit-mode toggle, panel unmount, - // next citation jump). Sticky Plate mode keeps the - // editor mounted after the atom clears. - setPendingHighlight(null); - return; - } - } - if (!cancelled) finishMiss(); - }; - - void tryCandidates(); - - return () => { - cancelled = true; - }; - }, [ - isCitationTarget, - pending, - documentId, - editorDoc?.source_markdown, - editorDoc?.truncated, - isLocalFileMode, - isEditing, - isEditorReady, - maxLengthOverride, - clearCitationSearch, - setPendingHighlight, - ]); - - // Cleanup any active highlight on unmount. - useEffect(() => { - return () => clearCitationSearch(); - }, [clearCitationSearch]); - - // Toggling into edit mode swaps Plate out of readOnly. Clear the citation - // search so stale leaves don't linger in the editing surface. - useEffect(() => { - if (isEditing) { - clearCitationSearch(); - setHighlightResult(null); - } - }, [isEditing, clearCitationSearch]); - useEffect(() => { return () => { if (copyResetTimeoutRef.current) { @@ -617,7 +292,7 @@ export function EditorPanelContent({ }, [editorDoc?.source_markdown]); const handleSave = useCallback( - async (_options?: { silent?: boolean }) => { + async (options?: { silent?: boolean }) => { setSaving(true); try { if (isLocalFileMode) { @@ -668,11 +343,15 @@ export function EditorPanelContent({ setEditorDoc((prev) => (prev ? { ...prev, source_markdown: markdownRef.current } : prev)); setEditedMarkdown(null); - toast.success("Document saved! Reindexing in background..."); + if (!options?.silent) { + toast.success("Document saved! Reindexing in background..."); + } return true; } catch (err) { console.error("Error saving document:", err); - toast.error(err instanceof Error ? err.message : "Failed to save document"); + if (!options?.silent) { + toast.error(err instanceof Error ? err.message : "Failed to save document"); + } return false; } finally { setSaving(false); @@ -693,15 +372,11 @@ export function EditorPanelContent({ EDITABLE_DOCUMENT_TYPES.has(editorDoc.document_type ?? "")) && !isLargeDocument : false; - // Use PlateEditor for any of: - // - Editable doc types (FILE/NOTE) — existing editing UX. - // - Active citation jump in flight (`isCitationTarget`) — covers the - // mount in the very first render where the atom is set but the - // sticky effect hasn't fired yet. - // - Sticky Plate mode latched on a previous citation jump — keeps - // the editor mounted (with its decorations) after the highlight - // effect clears the atom. Resets when the doc changes. - const renderInPlateEditor = isEditableType || isCitationTarget || stickyPlateMode; + // Render through PlateEditor for editable doc types (FILE/NOTE). + // Everything else (large docs, non-editable types) falls back to the + // lightweight `MarkdownViewer` — Plate is heavy on multi-MB docs and + // non-editable types don't benefit from its editing UX. + const renderInPlateEditor = isEditableType; const hasUnsavedChanges = editedMarkdown !== null; const showDesktopHeader = !!onClose; const showEditingActions = isEditableType && isEditing; @@ -744,36 +419,6 @@ export function EditorPanelContent({ } }, [documentId, editorDoc?.title, searchSpaceId]); - // We no longer surface an "approximate" status — Plate's FindReplacePlugin - // either decorates an exact match or it doesn't, and the candidate snippet - // strategy (first sentence → first 8 words → full chunk) means we either - // land on the citation start or fall through to the miss alert. - const showMissAlert = isCitationTarget && highlightResult === "miss"; - - const citationAlerts = showMissAlert && ( - - - - Cited section couldn't be located in this view. - {editorDoc?.truncated && ( - - )} - - - ); - const largeDocAlert = isLargeDocument && !isLocalFileMode && editorDoc && ( @@ -1002,30 +647,17 @@ export function EditorPanelContent({ }} />
- ) : isLargeDocument && !isLocalFileMode && !isCitationTarget ? ( - // Large doc, no active citation — fast Streamdown preview - // + download CTA. We only fall back to MarkdownViewer here - // because Plate is heavy on multi-MB docs and the user - // isn't waiting on a specific citation to render. + ) : isLargeDocument && !isLocalFileMode ? ( + // Large doc — fast Streamdown preview + download CTA. + // Plate is heavy on multi-MB docs.
{largeDocAlert}
) : renderInPlateEditor ? ( - // Editable doc (FILE/NOTE) OR active citation jump (any - // doc type). The citation path uses Plate's - // FindReplacePlugin for native, decoration-based - // highlighting — see the citation-jump highlight effect - // above for how `editorRef` and `handleEditorReady` are - // wired. + // Editable doc (FILE/NOTE) — Plate editing UX.
- {(citationAlerts || (isLargeDocument && isCitationTarget && !isLocalFileMode)) && ( -
- {isLargeDocument && isCitationTarget && largeDocAlert} - {citationAlerts} -
- )} -
+
diff --git a/surfsense_web/components/editor/plate-editor.tsx b/surfsense_web/components/editor/plate-editor.tsx index eef18ef6a..7f12d3cae 100644 --- a/surfsense_web/components/editor/plate-editor.tsx +++ b/surfsense_web/components/editor/plate-editor.tsx @@ -12,10 +12,7 @@ import { type EditorPreset, presetMap } from "@/components/editor/presets"; import { escapeMdxExpressions } from "@/components/editor/utils/escape-mdx"; import { Editor, EditorContainer } from "@/components/ui/editor"; -/** Live editor instance returned by `usePlateEditor`. Exposed via the - * `onEditorReady` prop so callers (e.g. `EditorPanelContent`) can drive - * plugin options imperatively — most notably setting - * `FindReplacePlugin`'s `search` option for citation-jump highlights. */ +/** Live editor instance returned by `usePlateEditor`. */ export type PlateEditorInstance = ReturnType; export interface PlateEditorProps { @@ -68,15 +65,6 @@ export interface PlateEditorProps { * without modifying the core editor component. */ extraPlugins?: AnyPluginConfig[]; - /** - * Called whenever the live editor instance (re)mounts, with `null` on - * unmount. Used by callers that need to drive plugin options imperatively - * — e.g. `EditorPanelContent` setting `FindReplacePlugin`'s `search` - * option for citation-jump highlights. The callback is invoked exactly - * once per editor lifetime (the parent's `key` prop forces a fresh - * editor when needed, e.g. on edit-mode toggle). - */ - onEditorReady?: (editor: PlateEditorInstance | null) => void; } function PlateEditorContent({ @@ -115,7 +103,6 @@ export function PlateEditor({ defaultEditing = false, preset = "full", extraPlugins = [], - onEditorReady, }: PlateEditorProps) { const lastMarkdownRef = useRef(markdown); const lastHtmlRef = useRef(html); @@ -172,21 +159,6 @@ export function PlateEditor({ : undefined, }); - // Expose the live editor instance to imperative callers (e.g. citation - // jump highlights). We deliberately don't depend on `onEditorReady` - // itself in the cleanup closure — callers commonly pass an arrow that - // closes over a stable ref setter, but if they pass a freshly-bound - // callback per render, the `onEditorReady?.(editor)` re-fires which is - // idempotent for ref-style setters. - const onEditorReadyRef = useRef(onEditorReady); - useEffect(() => { - onEditorReadyRef.current = onEditorReady; - }, [onEditorReady]); - useEffect(() => { - onEditorReadyRef.current?.(editor); - return () => onEditorReadyRef.current?.(null); - }, [editor]); - // Update editor content when html prop changes externally useEffect(() => { if (html !== undefined && html !== lastHtmlRef.current) { diff --git a/surfsense_web/components/editor/presets.ts b/surfsense_web/components/editor/presets.ts index 49f53ecf1..c207b5e56 100644 --- a/surfsense_web/components/editor/presets.ts +++ b/surfsense_web/components/editor/presets.ts @@ -1,6 +1,5 @@ "use client"; -import { FindReplacePlugin } from "@platejs/find-replace"; import type { AnyPluginConfig } from "platejs"; import { TrailingBlockPlugin } from "platejs"; @@ -18,30 +17,6 @@ import { SelectionKit } from "@/components/editor/plugins/selection-kit"; import { SlashCommandKit } from "@/components/editor/plugins/slash-command-kit"; import { TableKit } from "@/components/editor/plugins/table-kit"; import { ToggleKit } from "@/components/editor/plugins/toggle-kit"; -import { SearchHighlightLeaf } from "@/components/ui/search-highlight-node"; - -/** - * Citation-jump highlighter. Re-uses Plate's built-in `FindReplacePlugin` - * (decorate-only, no editing surface) to drive the "scroll-to-cited-text" - * UX in `EditorPanelContent`. We register it in every preset because: - * - Decorate is a no-op when `search` is empty (single getOptions() check - * per block), so cost is effectively zero for non-citation viewers. - * - Keeping it preset-agnostic means citations work whether the doc is - * opened in editable (`full`) or pure-viewer (`readonly`) modes. - * - * The parent component drives `setOption(FindReplacePlugin, 'search', ...)` - * + `editor.api.redecorate()` to trigger highlights, then queries the - * editor DOM for `.citation-highlight-leaf` to scroll the first match - * into view. (We can't use a `data-*` attribute here — Plate's - * `PlateLeaf` runs props through `useNodeAttributes`, which only forwards - * `attributes`, `className`, `ref`, `style`; arbitrary `data-*` props are - * silently dropped.) See `components/ui/search-highlight-node.tsx` for - * the leaf component and `CITATION_HIGHLIGHT_CLASS` constant. - */ -const CitationFindReplacePlugin = FindReplacePlugin.configure({ - options: { search: "" }, - render: { node: SearchHighlightLeaf }, -}); /** * Full preset – every plugin kit enabled. @@ -63,7 +38,6 @@ export const fullPreset: AnyPluginConfig[] = [ ...AutoformatKit, ...DndKit, TrailingBlockPlugin, - CitationFindReplacePlugin, ]; /** @@ -78,7 +52,6 @@ export const minimalPreset: AnyPluginConfig[] = [ ...LinkKit, ...AutoformatKit, TrailingBlockPlugin, - CitationFindReplacePlugin, ]; /** @@ -95,7 +68,6 @@ export const readonlyPreset: AnyPluginConfig[] = [ ...CalloutKit, ...ToggleKit, ...MathKit, - CitationFindReplacePlugin, ]; /** All available preset names */ diff --git a/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx b/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx index 04bae010c..3481eec28 100644 --- a/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx +++ b/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx @@ -6,6 +6,7 @@ import dynamic from "next/dynamic"; import { startTransition, useEffect } from "react"; import { closeHitlEditPanelAtom, hitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; import { closeReportPanelAtom, reportPanelAtom } from "@/atoms/chat/report-panel.atom"; +import { citationPanelAtom, closeCitationPanelAtom } from "@/atoms/citation/citation-panel.atom"; import { documentsSidebarOpenAtom } from "@/atoms/documents/ui.atoms"; import { closeEditorPanelAtom, editorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { rightPanelCollapsedAtom, rightPanelTabAtom } from "@/atoms/layout/right-panel.atom"; @@ -21,6 +22,14 @@ const EditorPanelContent = dynamic( { ssr: false, loading: () => null } ); +const CitationPanelContent = dynamic( + () => + import("@/components/citation-panel/citation-panel").then((m) => ({ + default: m.CitationPanelContent, + })), + { ssr: false, loading: () => null } +); + const HitlEditPanelContent = dynamic( () => import("@/components/hitl-edit-panel/hitl-edit-panel").then((m) => ({ @@ -69,12 +78,14 @@ export function RightPanelExpandButton() { const reportState = useAtomValue(reportPanelAtom); const editorState = useAtomValue(editorPanelAtom); const hitlEditState = useAtomValue(hitlEditPanelAtom); + const citationState = useAtomValue(citationPanelAtom); const reportOpen = reportState.isOpen && !!reportState.reportId; const editorOpen = editorState.isOpen && (editorState.kind === "document" ? !!editorState.documentId : !!editorState.localFilePath); const hitlEditOpen = hitlEditState.isOpen && !!hitlEditState.onSave; - const hasContent = documentsOpen || reportOpen || editorOpen || hitlEditOpen; + const citationOpen = citationState.isOpen && citationState.chunkId != null; + const hasContent = documentsOpen || reportOpen || editorOpen || hitlEditOpen || citationOpen; if (!collapsed || !hasContent) return null; @@ -98,7 +109,13 @@ export function RightPanelExpandButton() { ); } -const PANEL_WIDTHS = { sources: 420, report: 640, editor: 640, "hitl-edit": 640 } as const; +const PANEL_WIDTHS = { + sources: 420, + report: 640, + editor: 640, + "hitl-edit": 640, + citation: 560, +} as const; export function RightPanel({ documentsPanel }: RightPanelProps) { const [activeTab] = useAtom(rightPanelTabAtom); @@ -108,6 +125,8 @@ export function RightPanel({ documentsPanel }: RightPanelProps) { const closeEditor = useSetAtom(closeEditorPanelAtom); const hitlEditState = useAtomValue(hitlEditPanelAtom); const closeHitlEdit = useSetAtom(closeHitlEditPanelAtom); + const citationState = useAtomValue(citationPanelAtom); + const closeCitation = useSetAtom(closeCitationPanelAtom); const [collapsed, setCollapsed] = useAtom(rightPanelCollapsedAtom); const documentsOpen = documentsPanel?.open ?? false; @@ -116,37 +135,59 @@ export function RightPanel({ documentsPanel }: RightPanelProps) { editorState.isOpen && (editorState.kind === "document" ? !!editorState.documentId : !!editorState.localFilePath); const hitlEditOpen = hitlEditState.isOpen && !!hitlEditState.onSave; + const citationOpen = citationState.isOpen && citationState.chunkId != null; useEffect(() => { - if (!reportOpen && !editorOpen && !hitlEditOpen) return; + if (!reportOpen && !editorOpen && !hitlEditOpen && !citationOpen) return; const handleKeyDown = (e: KeyboardEvent) => { if (e.key === "Escape") { if (hitlEditOpen) closeHitlEdit(); + else if (citationOpen) closeCitation(); else if (editorOpen) closeEditor(); else if (reportOpen) closeReport(); } }; document.addEventListener("keydown", handleKeyDown); return () => document.removeEventListener("keydown", handleKeyDown); - }, [reportOpen, editorOpen, hitlEditOpen, closeReport, closeEditor, closeHitlEdit]); + }, [ + reportOpen, + editorOpen, + hitlEditOpen, + citationOpen, + closeReport, + closeEditor, + closeHitlEdit, + closeCitation, + ]); - const isVisible = (documentsOpen || reportOpen || editorOpen || hitlEditOpen) && !collapsed; + const isVisible = + (documentsOpen || reportOpen || editorOpen || hitlEditOpen || citationOpen) && !collapsed; let effectiveTab = activeTab; if (effectiveTab === "hitl-edit" && !hitlEditOpen) { - effectiveTab = editorOpen ? "editor" : reportOpen ? "report" : "sources"; - } else if (effectiveTab === "editor" && !editorOpen) { - effectiveTab = reportOpen ? "report" : "sources"; - } else if (effectiveTab === "report" && !reportOpen) { - effectiveTab = editorOpen ? "editor" : "sources"; - } else if (effectiveTab === "sources" && !documentsOpen) { - effectiveTab = hitlEditOpen - ? "hitl-edit" + effectiveTab = citationOpen + ? "citation" : editorOpen ? "editor" : reportOpen ? "report" : "sources"; + } else if (effectiveTab === "citation" && !citationOpen) { + effectiveTab = editorOpen ? "editor" : reportOpen ? "report" : "sources"; + } else if (effectiveTab === "editor" && !editorOpen) { + effectiveTab = citationOpen ? "citation" : reportOpen ? "report" : "sources"; + } else if (effectiveTab === "report" && !reportOpen) { + effectiveTab = citationOpen ? "citation" : editorOpen ? "editor" : "sources"; + } else if (effectiveTab === "sources" && !documentsOpen) { + effectiveTab = hitlEditOpen + ? "hitl-edit" + : citationOpen + ? "citation" + : editorOpen + ? "editor" + : reportOpen + ? "report" + : "sources"; } const targetWidth = PANEL_WIDTHS[effectiveTab]; @@ -205,6 +246,11 @@ export function RightPanel({ documentsPanel }: RightPanelProps) { />
)} + {effectiveTab === "citation" && citationOpen && citationState.chunkId != null && ( +
+ +
+ )}
); diff --git a/surfsense_web/components/ui/search-highlight-node.tsx b/surfsense_web/components/ui/search-highlight-node.tsx deleted file mode 100644 index e3f316cce..000000000 --- a/surfsense_web/components/ui/search-highlight-node.tsx +++ /dev/null @@ -1,45 +0,0 @@ -"use client"; - -import type { PlateLeafProps } from "platejs/react"; -import { PlateLeaf } from "platejs/react"; - -/** - * Stable class name used to identify Plate-rendered citation highlight - * leaves in the DOM. We can't use a `data-*` attribute here — Plate's - * `PlateLeaf` runs its props through `useNodeAttributes`, which only - * forwards `attributes`, `className`, `ref`, and `style` to the rendered - * element; arbitrary `data-*` props are silently dropped (verified - * against `@platejs/core/dist/react/index.js` v52). So `className` is - * the only escape hatch that's guaranteed to survive into the DOM. - */ -export const CITATION_HIGHLIGHT_CLASS = "citation-highlight-leaf"; - -/** - * Leaf rendered for ranges decorated by `@platejs/find-replace`'s - * `FindReplacePlugin`. We re-purpose that plugin to drive the citation-jump - * highlight: when a citation is staged, the parent sets the plugin's `search` - * option to a snippet of the chunk text and Plate decorates every match with - * `searchHighlight: true`. This component renders those decorations as a - * `` tagged with `CITATION_HIGHLIGHT_CLASS` so the parent can: - * 1. Query the first match in DOM order to scroll it into view. - * 2. Detect the active-highlight state without a separate React ref. - * - * The highlight is **persistent** — it does not auto-fade. The parent in - * `EditorPanelContent` clears it by setting the plugin's `search` option - * back to "" when one of: (a) the user clicks anywhere inside the editor, - * (b) the panel switches to a different document, (c) the user toggles - * into edit mode, (d) another citation jump is staged, (e) the panel - * unmounts. We use a brief entrance pulse (`citation-flash-in`, see - * `globals.css`) purely to draw the eye after `scrollIntoView` lands. - */ -export function SearchHighlightLeaf(props: PlateLeafProps) { - return ( - - {props.children} - - ); -} diff --git a/surfsense_web/lib/citation-search.ts b/surfsense_web/lib/citation-search.ts deleted file mode 100644 index f80f13076..000000000 --- a/surfsense_web/lib/citation-search.ts +++ /dev/null @@ -1,125 +0,0 @@ -/** - * Snippet generation for the citation-jump highlight, driven by Plate's - * `FindReplacePlugin`. The plugin runs `decorate` per-block and only matches - * within blocks whose children are all `Text` nodes (so it crosses inline - * marks like bold/italic but **not** block boundaries, and a block that - * contains even one inline element such as a link is silently skipped). - * That means a full chunk that spans heading + paragraph won't match as a - * single string — we have to pick a shorter snippet that fits inside one - * rendered block. - * - * `buildCitationSearchCandidates` returns search strings ordered from - * "most-specific anchor" to "broadest fallback": - * 1. First sentence of the chunk (capped at `FIRST_SENTENCE_MAX`). - * 2. First `FIRST_PHRASE_WORDS` words. - * 3. Each non-trivial line of the chunk, in source order — gives us a - * separate attempt for each rendered block, so a heading line with - * an inline link doesn't doom the whole jump. - * 4. Full chunk (only if it's already short enough to plausibly fit - * inside one block). - * - * The caller tries each candidate in turn — set the plugin's `search` - * option, `editor.api.redecorate()`, then check the editor DOM for a - * `.citation-highlight-leaf` element. First candidate that produces one - * wins; subsequent candidates are skipped. - */ - -const FIRST_SENTENCE_MAX = 120; -const FIRST_PHRASE_WORDS = 8; -const MIN_SNIPPET_LENGTH = 6; -const FULL_CHUNK_MAX = FIRST_SENTENCE_MAX * 2; -const MAX_LINE_CANDIDATES = 6; -const LINE_CANDIDATE_MAX = FIRST_SENTENCE_MAX; - -function normalizeWhitespace(input: string): string { - return input.replace(/\s+/g, " ").trim(); -} - -/** - * Strip the markdown syntax that won't survive into the rendered editor's - * plain text, so the chunk text (which comes back from the indexer as raw - * source markdown) can be matched against the literal text values stored - * in Plate's Slate tree. - * - * Order matters: handle multi-char and "container" syntax before single- - * char emphasis, otherwise `**text**` collapses to `*text*` first. - * - * Heuristic only — we don't aim to be a full markdown parser, just to - * remove the common markers (`**bold**`, `[text](url)`, `# headings`, - * `- list`, etc.) that show up in connector-doc chunks and would break - * literal substring search. - */ -export function stripMarkdownForMatch(input: string): string { - let s = input; - s = s.replace(/```[a-z0-9_+-]*\n?([\s\S]*?)```/gi, (_, body: string) => body); - s = s.replace(//g, " "); - s = s.replace(/!\[([^\]]*)\]\([^)]*\)/g, "$1"); - s = s.replace(/!\[([^\]]*)\]\[[^\]]*\]/g, "$1"); - s = s.replace(/\[([^\]]+)\]\([^)]*\)/g, "$1"); - s = s.replace(/\[([^\]]+)\]\[[^\]]*\]/g, "$1"); - s = s.replace(/<((?:https?|mailto):[^>\s]+)>/g, "$1"); - s = s.replace(/`+([^`\n]+?)`+/g, "$1"); - s = s.replace(/(\*\*|__)([\s\S]+?)\1/g, "$2"); - s = s.replace(/(?+[ \t]?/gm, ""); - s = s.replace(/^[ \t]*[-*+][ \t]+/gm, ""); - s = s.replace(/^[ \t]*\d+\.[ \t]+/gm, ""); - s = s.replace(/^[ \t]{0,3}(?:[-*_])(?:[ \t]*[-*_]){2,}[ \t]*$/gm, ""); - s = s.replace(/^[ \t]*\|?(?:[ \t]*:?-+:?[ \t]*\|)+[ \t]*:?-+:?[ \t]*\|?[ \t]*$/gm, ""); - s = s.replace(/\\([\\`*_{}[\]()#+\-.!~>])/g, "$1"); - return s; -} - -export function buildCitationSearchCandidates(rawText: string): string[] { - if (!rawText) return []; - const stripped = stripMarkdownForMatch(rawText); - const normalized = normalizeWhitespace(stripped); - if (normalized.length < MIN_SNIPPET_LENGTH) return []; - - const out: string[] = []; - const seen = new Set(); - const push = (s: string) => { - const t = normalizeWhitespace(s); - if (t.length >= MIN_SNIPPET_LENGTH && !seen.has(t)) { - out.push(t); - seen.add(t); - } - }; - - const sentenceMatch = normalized.match(/^[^.!?]+[.!?]/); - if (sentenceMatch) { - const sentence = sentenceMatch[0]; - push(sentence.length > FIRST_SENTENCE_MAX ? sentence.slice(0, FIRST_SENTENCE_MAX) : sentence); - } else if (normalized.length > FIRST_SENTENCE_MAX) { - push(normalized.slice(0, FIRST_SENTENCE_MAX)); - } - - const words = normalized.split(" ").filter(Boolean); - if (words.length > FIRST_PHRASE_WORDS) { - push(words.slice(0, FIRST_PHRASE_WORDS).join(" ")); - } - - // Per-line candidates: each chunk line is roughly one block in the - // rendered editor. Trying them in order gives us a separate decorate - // attempt for each block, which matters when the first line is a - // heading containing a link (Plate's `FindReplacePlugin` will skip - // any block whose children aren't all text nodes). - const rawLines = stripped.split(/\r?\n/); - let lineCount = 0; - for (const line of rawLines) { - if (lineCount >= MAX_LINE_CANDIDATES) break; - const trimmed = normalizeWhitespace(line); - if (trimmed.length < MIN_SNIPPET_LENGTH) continue; - push(trimmed.length > LINE_CANDIDATE_MAX ? trimmed.slice(0, LINE_CANDIDATE_MAX) : trimmed); - lineCount++; - } - - if (normalized.length <= FULL_CHUNK_MAX) { - push(normalized); - } - - return out; -} diff --git a/surfsense_web/package.json b/surfsense_web/package.json index 665490e4f..41175daeb 100644 --- a/surfsense_web/package.json +++ b/surfsense_web/package.json @@ -36,7 +36,6 @@ "@platejs/code-block": "^52.0.11", "@platejs/combobox": "^52.0.15", "@platejs/dnd": "^52.0.11", - "@platejs/find-replace": "^52.3.10", "@platejs/floating": "^52.0.11", "@platejs/indent": "^52.0.11", "@platejs/link": "^52.0.11", diff --git a/surfsense_web/pnpm-lock.yaml b/surfsense_web/pnpm-lock.yaml index a1a7bea12..b1730e842 100644 --- a/surfsense_web/pnpm-lock.yaml +++ b/surfsense_web/pnpm-lock.yaml @@ -53,9 +53,6 @@ importers: '@platejs/dnd': specifier: ^52.0.11 version: 52.0.11(platejs@52.0.17(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(scheduler@0.27.0)(use-sync-external-store@1.6.0(react@19.2.4)))(react-dnd-html5-backend@16.0.1)(react-dnd@16.0.1(@types/node@20.19.33)(@types/react@19.2.14)(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - '@platejs/find-replace': - specifier: ^52.3.10 - version: 52.3.10(platejs@52.0.17(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(scheduler@0.27.0)(use-sync-external-store@1.6.0(react@19.2.4)))(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@platejs/floating': specifier: ^52.0.11 version: 52.0.11(platejs@52.0.17(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(scheduler@0.27.0)(use-sync-external-store@1.6.0(react@19.2.4)))(react-dom@19.2.4(react@19.2.4))(react@19.2.4) @@ -2830,13 +2827,6 @@ packages: react-dnd-html5-backend: '>=14.0.0' react-dom: '>=18.0.0' - '@platejs/find-replace@52.3.10': - resolution: {integrity: sha512-V/MOMMUYxHfEn/skd2+YO213xSATFDVsl8FzVzVRV/XaxwwVefH2EPD1lAVIvmYjennTVTTsHHtEI9K9iOsEaA==} - peerDependencies: - platejs: '>=52.0.11' - react: '>=18.0.0' - react-dom: '>=18.0.0' - '@platejs/floating@52.0.11': resolution: {integrity: sha512-ApNpw4KWml+kuK+XTTpji+f/7GxTR4nRzlnfJMvGBrJpLPQ4elS5MABm3oUi81DZn+aub5HvsyH7UqCw7F76IA==} peerDependencies: @@ -11115,13 +11105,6 @@ snapshots: react-dnd-html5-backend: 16.0.1 react-dom: 19.2.4(react@19.2.4) - '@platejs/find-replace@52.3.10(platejs@52.0.17(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(scheduler@0.27.0)(use-sync-external-store@1.6.0(react@19.2.4)))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': - dependencies: - platejs: 52.0.17(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(scheduler@0.27.0)(use-sync-external-store@1.6.0(react@19.2.4)) - react: 19.2.4 - react-compiler-runtime: 1.0.0(react@19.2.4) - react-dom: 19.2.4(react@19.2.4) - '@platejs/floating@52.0.11(platejs@52.0.17(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(scheduler@0.27.0)(use-sync-external-store@1.6.0(react@19.2.4)))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': dependencies: '@floating-ui/core': 1.7.4 From f9b5367754c5e07a586b5a318ac06245b3d10846 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Tue, 28 Apr 2026 23:52:37 -0700 Subject: [PATCH 225/299] chore: cleaned comments slop --- surfsense_backend/.env.example | 14 +- .../versions/130_add_agent_action_log.py | 6 +- .../versions/131_add_document_revisions.py | 2 +- .../132_add_agent_permission_rules.py | 9 +- .../app/agents/new_chat/chat_deepagent.py | 125 +++++++++--------- .../app/agents/new_chat/errors.py | 8 +- .../app/agents/new_chat/feature_flags.py | 31 ++--- .../agents/new_chat/middleware/busy_mutex.py | 15 ++- .../agents/new_chat/middleware/compaction.py | 19 +-- .../new_chat/middleware/context_editing.py | 18 +-- .../new_chat/middleware/dedup_tool_calls.py | 10 +- .../agents/new_chat/middleware/doom_loop.py | 22 +-- .../new_chat/middleware/knowledge_search.py | 21 +-- .../new_chat/middleware/noop_injection.py | 28 ++-- .../agents/new_chat/middleware/otel_span.py | 6 +- .../agents/new_chat/middleware/permission.py | 25 ++-- .../agents/new_chat/middleware/retry_after.py | 14 +- .../new_chat/middleware/tool_call_repair.py | 19 +-- .../app/agents/new_chat/permissions.py | 9 +- .../app/agents/new_chat/plugin_loader.py | 9 +- .../new_chat/plugins/year_substituter.py | 10 +- .../app/agents/new_chat/prompts/composer.py | 21 ++- .../app/agents/new_chat/subagents/__init__.py | 15 ++- .../app/agents/new_chat/system_prompt.py | 15 ++- .../app/agents/new_chat/tools/invalid_tool.py | 5 +- .../app/agents/new_chat/tools/registry.py | 6 +- surfsense_backend/app/observability/otel.py | 4 +- .../app/routes/agent_revert_route.py | 8 +- .../agents/new_chat/prompts/test_composer.py | 2 +- .../unit/agents/new_chat/test_otel_span.py | 2 +- .../unit/agents/new_chat/test_permissions.py | 2 +- .../agents/new_chat/test_plugin_loader.py | 2 +- .../tests/unit/observability/test_otel.py | 2 +- .../unit/services/test_revert_service.py | 2 +- 34 files changed, 274 insertions(+), 232 deletions(-) diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index e133a2bc5..c1bfcc538 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -250,12 +250,12 @@ LANGSMITH_PROJECT=surfsense # ============================================================================= -# OPTIONAL: New-chat agent feature flags (OpenCode-port) +# OPTIONAL: New-chat agent feature flags # ============================================================================= # Master kill-switch — when true, every flag below is forced OFF. # SURFSENSE_DISABLE_NEW_AGENT_STACK=false -# Tier 1 — Agent quality +# Agent quality # SURFSENSE_ENABLE_CONTEXT_EDITING=false # SURFSENSE_ENABLE_COMPACTION_V2=false # SURFSENSE_ENABLE_RETRY_AFTER=false @@ -265,24 +265,24 @@ LANGSMITH_PROJECT=surfsense # SURFSENSE_ENABLE_TOOL_CALL_REPAIR=false # SURFSENSE_ENABLE_DOOM_LOOP=false # leave OFF until UI handles permission='doom_loop' -# Tier 2 — Safety +# Safety # SURFSENSE_ENABLE_PERMISSION=false # SURFSENSE_ENABLE_BUSY_MUTEX=false # SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call -# Tier 3b — Observability (also requires OTEL_EXPORTER_OTLP_ENDPOINT) +# Observability — OTel (also requires OTEL_EXPORTER_OTLP_ENDPOINT) # SURFSENSE_ENABLE_OTEL=false -# Tier 4 — Skills + subagents +# Skills + subagents # SURFSENSE_ENABLE_SKILLS=false # SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=false # SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=false -# Tier 5 — Snapshot / revert +# Snapshot / revert # SURFSENSE_ENABLE_ACTION_LOG=false # SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships -# Tier 6 — Plugins +# Plugins # SURFSENSE_ENABLE_PLUGIN_LOADER=false # Comma-separated allowlist of plugin entry-point names # SURFSENSE_ALLOWED_PLUGINS=year_substituter diff --git a/surfsense_backend/alembic/versions/130_add_agent_action_log.py b/surfsense_backend/alembic/versions/130_add_agent_action_log.py index 2f06b8ddd..f86a8a3b5 100644 --- a/surfsense_backend/alembic/versions/130_add_agent_action_log.py +++ b/surfsense_backend/alembic/versions/130_add_agent_action_log.py @@ -4,8 +4,10 @@ Revision ID: 130 Revises: 129 Create Date: 2026-04-28 -Tier 5.2 in the OpenCode-port plan. Adds the append-only ``agent_action_log`` -table that :class:`ActionLogMiddleware` writes to after every tool call. +Adds the append-only ``agent_action_log`` table that +:class:`ActionLogMiddleware` writes to after every tool call. Each row +optionally carries a ``reverse_descriptor`` payload used by +``POST /api/threads/{thread_id}/revert/{action_id}`` to undo the action. """ from __future__ import annotations diff --git a/surfsense_backend/alembic/versions/131_add_document_revisions.py b/surfsense_backend/alembic/versions/131_add_document_revisions.py index 46c6991b6..95ce0e032 100644 --- a/surfsense_backend/alembic/versions/131_add_document_revisions.py +++ b/surfsense_backend/alembic/versions/131_add_document_revisions.py @@ -4,7 +4,7 @@ Revision ID: 131 Revises: 130 Create Date: 2026-04-28 -Tier 5.1 in the OpenCode-port plan. Adds two snapshot tables: +Adds two snapshot tables that back the per-action revert flow: * ``document_revisions``: pre-mutation snapshot of NOTE/FILE/EXTENSION docs. * ``folder_revisions``: pre-mutation snapshot of folder mkdir/move/delete. diff --git a/surfsense_backend/alembic/versions/132_add_agent_permission_rules.py b/surfsense_backend/alembic/versions/132_add_agent_permission_rules.py index 0e81eacb5..ff5b52e18 100644 --- a/surfsense_backend/alembic/versions/132_add_agent_permission_rules.py +++ b/surfsense_backend/alembic/versions/132_add_agent_permission_rules.py @@ -4,11 +4,10 @@ Revision ID: 132 Revises: 131 Create Date: 2026-04-28 -Tier 2.1 in the OpenCode-port plan. Adds the persistent ``agent_permission_rules`` -table consumed by :class:`PermissionMiddleware` at agent build time. Rules -can be scoped at search-space (``user_id`` / ``thread_id`` NULL), -user-wide (``user_id`` set, ``thread_id`` NULL), or per-thread -(``thread_id`` set). +Adds the persistent ``agent_permission_rules`` table consumed by +:class:`PermissionMiddleware` at agent build time. Rules can be scoped +at search-space (``user_id`` / ``thread_id`` NULL), user-wide +(``user_id`` set, ``thread_id`` NULL), or per-thread (``thread_id`` set). """ from __future__ import annotations diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 3ca44dd4f..bfb94ba2d 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -353,11 +353,12 @@ async def create_surfsense_deep_agent( additional_tools=list(additional_tools) if additional_tools else None, ) - # Tier 1.6: register `invalid` tool. It is dispatched only when - # ToolCallNameRepairMiddleware rewrites a malformed call. We - # intentionally append it AFTER ``build_tools_async`` so it never - # appears in the system-prompt tool list (which is built from the - # registry, not the bound tool list). + # Register the ``invalid`` tool only when tool-call repair is on. It + # is dispatched only when :class:`ToolCallNameRepairMiddleware` + # rewrites a malformed call. We intentionally append it AFTER + # ``build_tools_async`` so it never appears in the system-prompt + # tool list (which is built from the registry, not the bound tool + # list). _flags: AgentFeatureFlags = get_flags() if _flags.enable_tool_call_repair and INVALID_TOOL_NAME not in { t.name for t in tools @@ -455,10 +456,10 @@ async def create_surfsense_deep_agent( return agent -# Tier 1.1: tools whose output is too costly / lossy to discard. Keep -# this conservative — anything listed here is *never* pruned by -# ContextEditingMiddleware. The list is filtered against actually-bound -# tool names so disabled connectors don't show up here. +# Tools whose output is too costly / lossy to discard. Keep this +# conservative — anything listed here is *never* pruned by +# :class:`ContextEditingMiddleware`. The list is filtered against +# actually-bound tool names so disabled connectors don't show up here. _PRUNE_PROTECTED_TOOL_NAMES: frozenset[str] = frozenset( { "generate_report", @@ -485,11 +486,12 @@ def _safe_exclude_tools(tools: Sequence[BaseTool]) -> tuple[str, ...]: return tuple(name for name in _PRUNE_PROTECTED_TOOL_NAMES if name in enabled) -# Tier 2.1 / cleanup: opencode `Permission.disabled` parity. Replaces the -# legacy binary ``_CONNECTOR_TYPE_TO_SEARCHABLE``-based gating with a -# declarative pass over :data:`BUILTIN_TOOLS`. Each tool that declares a -# ``required_connector`` not present in ``available_connectors`` gets a -# deny rule so any execution attempt short-circuits with permission_denied. +# Connector gating: any tool whose ``ToolDefinition.required_connector`` +# isn't actually wired up gets a synthesized permission deny rule so +# execution attempts short-circuit with ``permission_denied`` instead of +# bubbling up provider-specific 401/404 errors. Mirrors OpenCode's +# ``Permission.disabled`` (declarative, per-tool gating) — replaces the +# legacy binary ``_CONNECTOR_TYPE_TO_SEARCHABLE`` substring-heuristic. def _synthesize_connector_deny_rules( *, available_connectors: list[str] | None, @@ -503,11 +505,6 @@ def _synthesize_connector_deny_rules( 1. It is currently bound (``enabled_tool_names``). 2. It declares a ``required_connector``. 3. That connector is *not* in ``available_connectors``. - - This expresses the OpenCode ``Permission.disabled`` semantics - declaratively, replacing the substring-heuristic binary gating - that used to consult the hardcoded ``_CONNECTOR_TYPE_TO_SEARCHABLE`` - map. """ available = set(available_connectors or []) deny: list[Rule] = [] @@ -581,7 +578,7 @@ def _build_compiled_agent_blocking( "middleware": gp_middleware, } - # Tier 4.3: specialized user-facing subagents (explore, report_writer, + # Specialized user-facing subagents (explore, report_writer, # connector_negotiator). Registered through SubAgentMiddleware alongside # the general-purpose spec so the parent's `task` tool can address them # by name. Off by default until the flag flips so existing deployments @@ -629,14 +626,13 @@ def _build_compiled_agent_blocking( # ``wrap_model_call`` ordering: the FIRST middleware in the list is the # OUTERMOST wrapper. To ensure prune executes before summarization, # place ``SpillingContextEditingMiddleware`` before - # ``SurfSenseCompactionMiddleware`` (Tier 1.1 + 1.3). - # Compaction is the canonical token-budget defense after the - # cleanup tier removed ``SafeSummarizationMiddleware``. The Bedrock - # buffer-empty defense is folded into ``SurfSenseCompactionMiddleware``. + # ``SurfSenseCompactionMiddleware``. Compaction is the canonical + # token-budget defense; the Bedrock buffer-empty defense is folded + # into ``SurfSenseCompactionMiddleware``. summarization_mw = create_surfsense_compaction_middleware(llm, StateBackend) _ = flags.enable_compaction_v2 # historical flag; retained for telemetry parity - # Tier 1.1: ContextEditing prune. Trigger at 55% of model_max_input, + # ContextEditing prune. Trigger at 55% of ``max_input_tokens``, # earlier than summarization (~85%). When disabled, no edit runs. context_edit_mw = None if ( @@ -664,7 +660,10 @@ def _build_compiled_agent_blocking( backend_resolver=backend_resolver, ) - # Tier 1.4 / 1.8 / 1.9 / 1.10: built-in retry/fallback/limits. + # Resilience knobs: header-aware retry, model fallback, and + # per-thread / per-run call-count limits. The fallback / limit + # middlewares are vanilla LangChain primitives; ``RetryAfter`` is + # SurfSense's header-aware variant (see its module docstring). retry_mw = ( RetryAfterMiddleware(max_retries=3) if flags.enable_retry_after and not flags.disable_new_agent_stack @@ -700,14 +699,16 @@ def _build_compiled_agent_blocking( else None ) - # Tier 1.5: provider-compat _noop injection. + # Provider-compat ``_noop`` injection (mirrors OpenCode's + # ``llm.ts`` workaround for providers that reject empty assistant + # turns or alternating-role constraints). noop_mw = ( NoopInjectionMiddleware() if flags.enable_compaction_v2 and not flags.disable_new_agent_stack else None ) - # Tier 1.7: tool-call name repair (lowercase + invalid fallback). + # Tool-call name repair (lowercase + ``invalid`` fallback). # # ``registered_tool_names`` MUST cover every tool the model can legitimately # call. That includes the bound ``tools`` list AND every tool provided by @@ -737,18 +738,22 @@ def _build_compiled_agent_blocking( } repair_mw = ToolCallNameRepairMiddleware( registered_tool_names=registered_names, - fuzzy_match_threshold=None, # opencode parity: no fuzzy step + # Disable fuzzy matching to avoid silent rewrites; the + # lowercase + ``invalid`` fallback alone covers >95% of + # observed model errors. + fuzzy_match_threshold=None, ) - # Tier 1.11: doom-loop detector. Off by default until UI handles. + # Doom-loop detector. Off by default until the frontend handles + # ``permission == "doom_loop"`` interrupts. doom_loop_mw = ( DoomLoopMiddleware(threshold=3) if flags.enable_doom_loop and not flags.disable_new_agent_stack else None ) - # Tier 2.1: PermissionMiddleware. Layers, earliest -> latest (last - # match wins per opencode): + # PermissionMiddleware. Layers, earliest -> latest (last match wins, + # same evaluation order as OpenCode's ``permission/index.ts``): # # 1. ``surfsense_defaults`` — single ``allow */*`` rule. SurfSense # already runs per-tool HITL (see ``tools/hitl.py``) for mutating @@ -778,11 +783,11 @@ def _build_compiled_agent_blocking( ], ) - # Tier 5.2: ActionLogMiddleware. Off by default until the - # ``agent_action_log`` table is migrated. When enabled, persists one - # row per tool call with optional reverse_descriptor for - # /api/threads/{thread_id}/revert/{action_id}. Sits inside permission - # so denied calls aren't logged as completions. + # ActionLogMiddleware. Off by default until the ``agent_action_log`` + # table is migrated. When enabled, persists one row per tool call + # with optional reverse_descriptor for + # ``POST /api/threads/{thread_id}/revert/{action_id}``. Sits inside + # ``permission`` so denied calls aren't logged as completions. action_log_mw: ActionLogMiddleware | None = None if ( flags.enable_action_log @@ -804,23 +809,24 @@ def _build_compiled_agent_blocking( ) action_log_mw = None - # Tier 2.2: per-thread busy mutex. + # Per-thread busy mutex (refuse a second concurrent turn on the same + # thread; see :class:`BusyMutexMiddleware` docstring). busy_mutex_mw: BusyMutexMiddleware | None = ( BusyMutexMiddleware() if flags.enable_busy_mutex and not flags.disable_new_agent_stack else None ) - # Tier 3b: OpenTelemetry spans (model.call + tool.call). Lives just - # inside BusyMutex so it spans every retry/fallback attempt of the - # current turn but never wraps a queued/blocked turn. + # OpenTelemetry spans (model.call + tool.call). Lives just inside + # BusyMutex so it spans every retry/fallback attempt of the current + # turn but never wraps a queued/blocked turn. otel_mw: OtelSpanMiddleware | None = ( OtelSpanMiddleware() if flags.enable_otel and not flags.disable_new_agent_stack else None ) - # Tier 6: plugin entry-point loader. Off by default; opt-in via the + # Plugin entry-point loader. Off by default; opt-in via the # ``SURFSENSE_ENABLE_PLUGIN_LOADER`` flag. The allowlist is read from # the ``SURFSENSE_ALLOWED_PLUGINS`` env var (comma-separated). A future # PR can wire it through ``global_llm_config.yaml``. @@ -845,10 +851,10 @@ def _build_compiled_agent_blocking( ) plugin_middlewares = [] - # Tier 4.1: SkillsMiddleware. Loads built-in + space-authored skills - # via a CompositeBackend. Sources are layered: built-in first, space - # last, so a search-space-authored skill of the same name overrides - # the bundled one. + # SkillsMiddleware (deepagents) loads built-in + space-authored + # skills via a CompositeBackend. Sources are layered: built-in first, + # space last, so a search-space-authored skill of the same name + # overrides the bundled one. skills_mw: SkillsMiddleware | None = None if flags.enable_skills and not flags.disable_new_agent_stack: try: @@ -865,7 +871,8 @@ def _build_compiled_agent_blocking( logging.warning("SkillsMiddleware init failed; skipping: %s", exc) skills_mw = None - # Tier 2.5: LLM-driven tool selection for >30 tools. + # LangChain's LLM-driven tool selection — only enabled for stacks + # large enough to need narrowing (>30 tools). selector_mw: LLMToolSelectorMiddleware | None = None if ( flags.enable_llm_tool_selector @@ -934,12 +941,12 @@ def _build_compiled_agent_blocking( ) if filesystem_mode == FilesystemMode.CLOUD else None, - # Tier 4.1: skill loader. Placed before SubAgentMiddleware so - # subagents inherit the same skill metadata (subagent specs reference - # the same source paths via `default_skills_sources()`). + # Skill loader. Placed before SubAgentMiddleware so subagents + # inherit the same skill metadata (subagent specs reference the + # same source paths via ``default_skills_sources()``). skills_mw, SubAgentMiddleware(backend=StateBackend, subagents=subagent_specs), - # Tier 2.5: tool selection (only when >30 tools and flag on). + # Tool selection (only when >30 tools and flag on). selector_mw, # Defensive caps, then prune, then summarize. model_call_limit_mw, @@ -954,19 +961,19 @@ def _build_compiled_agent_blocking( # Tool-call repair must run after model emits but before # permission / dedup / doom-loop interpret the calls. repair_mw, - # Tier 2.1: deny/ask BEFORE the calls are forwarded to tool nodes. + # Permission deny/ask BEFORE the calls are forwarded to tool nodes. permission_mw, doom_loop_mw, - # Tier 5.2: action log sits inside permission so denied calls - # don't appear as completions, and outside dedup so each unique - # tool invocation gets its own row. + # Action log sits inside permission so denied calls don't appear + # as completions, and outside dedup so each unique tool invocation + # gets its own row. action_log_mw, PatchToolCallsMiddleware(), DedupHITLToolCallsMiddleware(agent_tools=list(tools)), - # Tier 6: plugin slot — sits just before AnthropicCache so plugin-side - # transforms see the final tool result and run before any caching - # heuristics. Multiple plugins in declared order; loader filtered by - # the admin allowlist already. + # Plugin slot — sits just before AnthropicCache so plugin-side + # transforms see the final tool result and run before any + # caching heuristics. Multiple plugins in declared order; loader + # filtered by the admin allowlist already. *plugin_middlewares, AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), ] diff --git a/surfsense_backend/app/agents/new_chat/errors.py b/surfsense_backend/app/agents/new_chat/errors.py index b7bac4536..a17333acc 100644 --- a/surfsense_backend/app/agents/new_chat/errors.py +++ b/surfsense_backend/app/agents/new_chat/errors.py @@ -2,10 +2,10 @@ Typed error taxonomy for the SurfSense agent stack. Used by: -- :class:`RetryAfterMiddleware` (Tier 1.4) — its ``retry_on`` callable - consults the error code to decide whether a retry is appropriate. -- :class:`PermissionMiddleware` (Tier 2.1) — emits - ``code="permission_denied"`` errors when a deny rule trips. +- :class:`RetryAfterMiddleware` — its ``retry_on`` callable consults + the error code to decide whether a retry is appropriate. +- :class:`PermissionMiddleware` — emits ``code="permission_denied"`` + errors when a deny rule trips. - All tools — return :class:`StreamingError` payloads in ``ToolMessage.additional_kwargs["error"]`` so the model and the retry/permission layers share a contract. diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py index 89c4fb14f..55525abc5 100644 --- a/surfsense_backend/app/agents/new_chat/feature_flags.py +++ b/surfsense_backend/app/agents/new_chat/feature_flags.py @@ -1,9 +1,10 @@ """ Feature flags for the SurfSense new_chat agent stack. -These flags control rollout of OpenCode-pattern middleware ported into -SurfSense. They follow a "default-OFF for risky things, default-ON for -safe upgrades, master kill-switch for everything new" model. +These flags gate the newer agent middleware (some ported from OpenCode, +some sourced from ``langchain.agents.middleware`` / ``deepagents``, some +SurfSense-native). They follow a "default-OFF for risky things, +default-ON for safe upgrades, master kill-switch for everything new" model. All new middleware checks its flag at agent build time. If the master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new @@ -57,7 +58,7 @@ class AgentFeatureFlags: # regardless of its env value. Used for rapid rollback. disable_new_agent_stack: bool = False - # Tier 1 — Agent quality + # Agent quality — context budget, retry/limits, name-repair, doom-loop enable_context_editing: bool = False enable_compaction_v2: bool = False enable_retry_after: bool = False @@ -69,26 +70,26 @@ class AgentFeatureFlags: False # Default OFF until UI handles permission='doom_loop' ) - # Tier 2 — Safety + # Safety — permissions, concurrency, tool-set narrowing enable_permission: bool = False # Default OFF for first deploy enable_busy_mutex: bool = False enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost - # Tier 4 — Skills + subagents + # Skills + subagents enable_skills: bool = False enable_specialized_subagents: bool = False enable_kb_planner_runnable: bool = False - # Tier 5 — Snapshot / revert + # Snapshot / revert enable_action_log: bool = False enable_revert_route: bool = ( False # Backend ships before UI; route returns 503 until this flips ) - # Tier 6 — Plugins + # Plugins enable_plugin_loader: bool = False - # Tier 3b — OTel (orthogonal: also requires OTEL_EXPORTER_OTLP_ENDPOINT) + # Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT) enable_otel: bool = False @classmethod @@ -108,7 +109,7 @@ class AgentFeatureFlags: return cls( disable_new_agent_stack=False, - # Tier 1 + # Agent quality enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False), enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False), enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False), @@ -121,13 +122,13 @@ class AgentFeatureFlags: "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False ), enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False), - # Tier 2 + # Safety enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False), enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False), enable_llm_tool_selector=_env_bool( "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False ), - # Tier 4 + # Skills + subagents enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False), enable_specialized_subagents=_env_bool( "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False @@ -135,12 +136,12 @@ class AgentFeatureFlags: enable_kb_planner_runnable=_env_bool( "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False ), - # Tier 5 + # Snapshot / revert enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False), enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False), - # Tier 6 + # Plugins enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False), - # Tier 3b + # Observability enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False), ) diff --git a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py index 1d95638d0..c57d85004 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py +++ b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py @@ -1,11 +1,16 @@ """ BusyMutexMiddleware — per-thread asyncio lock + cancel token. -Tier 2.2 in the OpenCode-port plan. Mirrors opencode's -``Stream.scoped(AbortController)`` pattern (single-process, in-memory -lock + cooperative cancellation). For multi-worker deployments a -distributed lock backend (Redis or PostgreSQL advisory locks) is a -phase-2 follow-up. +LangChain has no built-in concept of "this thread is already running a +turn — refuse the second concurrent request". Without it, a user +double-clicking "send" or refreshing the page mid-stream can spawn two +turns racing on the same checkpoint, producing duplicated tool calls +and mangled state. + +Ported from OpenCode's ``Stream.scoped(AbortController)`` pattern: a +single-process, in-memory lock + cooperative cancellation token keyed by +``thread_id``. For multi-worker deployments a distributed lock backend +(Redis or PostgreSQL advisory locks) is a phase-2 follow-up. What this provides: - A ``WeakValueDictionary[str, asyncio.Lock]`` keyed by ``thread_id``; diff --git a/surfsense_backend/app/agents/new_chat/middleware/compaction.py b/surfsense_backend/app/agents/new_chat/middleware/compaction.py index b0a1a7ec5..16361e16b 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/compaction.py +++ b/surfsense_backend/app/agents/new_chat/middleware/compaction.py @@ -5,21 +5,22 @@ Subclasses :class:`deepagents.middleware.summarization.SummarizationMiddleware` to add SurfSense-specific behavior: 1. **Structured summary template** (OpenCode-style ``## Goal / Constraints / - Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``). + Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``) + — see :data:`SURFSENSE_SUMMARY_PROMPT` below. The base + ``SummarizationMiddleware`` only ships a freeform "summarize this" + prompt; the structured template is ported from OpenCode's + ``compaction.ts``. 2. **Protect SurfSense-specific SystemMessages** so injected hints (````, ````, ````, ````, ````, ````, ````) are *not* summarized away and are kept verbatim in the post-summary - message list. + message list. Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy + (some message types are part of the agent's contract and must survive + compaction unchanged). 3. **Sanitize ``content=None``** when feeding messages into ``get_buffer_string`` (Azure OpenAI / LiteLLM defense — when a provider streams an AIMessage containing only tool_calls and no text, ``content`` can be ``None`` and - ``get_buffer_string`` crashes iterating over ``None``). This used to live in - ``safe_summarization.py``; folded in here. - -This replaces ``app.agents.new_chat.middleware.safe_summarization``. - -Tier 1.3 in the OpenCode-port plan. + ``get_buffer_string`` crashes iterating over ``None``). SurfSense-specific. """ from __future__ import annotations @@ -42,7 +43,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# OpenCode-faithful structured summary template. Mirrors +# Structured summary template ported from OpenCode's # ``opencode/packages/opencode/src/session/compaction.ts:40-75``. Kept as a # module-level constant so unit tests can assert on its sections. SURFSENSE_SUMMARY_PROMPT = """ diff --git a/surfsense_backend/app/agents/new_chat/middleware/context_editing.py b/surfsense_backend/app/agents/new_chat/middleware/context_editing.py index 360e3e28f..39bc57c8b 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/context_editing.py +++ b/surfsense_backend/app/agents/new_chat/middleware/context_editing.py @@ -1,15 +1,15 @@ """ SpillToBackendEdit + SpillingContextEditingMiddleware. -Mirrors OpenCode's spill-to-disk behavior in -``opencode/packages/opencode/src/tool/truncate.ts``. Before -``ClearToolUsesEdit`` rewrites old ``ToolMessage.content`` to a placeholder, -we capture the full original content and write it to the runtime backend -under ``/tool_outputs/{thread_id}/{message_id}.txt``. The placeholder is -upgraded to ``"[cleared — full output at /tool_outputs/.../{id}.txt; ask the -explore subagent to read it]"`` so the agent can recover it on demand. - -Tier 1.2 in the OpenCode-port plan. +LangChain's :class:`ClearToolUsesEdit` discards old ``ToolMessage.content`` +when the context-editing budget triggers, replacing the body with a fixed +placeholder. That's lossy: anything the agent might want to revisit is +gone. The spill-to-disk pattern (originally from OpenCode's +``opencode/packages/opencode/src/tool/truncate.ts``) keeps the prune +behavior but writes the full original payload to the runtime backend +under ``/tool_outputs/{thread_id}/{message_id}.txt`` first. The +placeholder is then upgraded to point at the spill path so the agent +(or a subagent) can read it back on demand. Why this is a middleware subclass instead of a plain ``ContextEdit``: ``ContextEdit.apply`` is sync, but writing to the backend is async. We diff --git a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py index 3aff524fe..c55347284 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py +++ b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py @@ -9,11 +9,10 @@ the duplicate call is stripped from the AIMessage that gets checkpointed. That means it is also safe across LangGraph ``interrupt()`` boundaries: the removed call will never appear on graph resume. -Dedup-key resolution order (Tier 2.3 / cleanup in the OpenCode-port plan): +Dedup-key resolution order: 1. :class:`ToolDefinition.dedup_key` — callable provided by the registry - entry. This is the canonical mechanism after the cleanup-tier removal - of the legacy ``PRIMARY_ARG`` map. + entry. This is the canonical mechanism. 2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg name; used by MCP / Composio tools whose schemas the registry doesn't see. @@ -72,9 +71,8 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] The dedup-resolver map is built from two sources, in priority order: 1. ``tool.metadata["dedup_key"]`` — callable provided by the registry's - ``ToolDefinition.dedup_key`` (Tier 2.3). Receives the args dict - and returns a string signature. This is the canonical mechanism - after the cleanup-tier removal of the legacy ``PRIMARY_ARG`` map. + ``ToolDefinition.dedup_key``. Receives the args dict and returns + a string signature. This is the canonical mechanism. 2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg name; primarily used by MCP / Composio tools. """ diff --git a/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py b/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py index 1dde87752..850ecd1d2 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py +++ b/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py @@ -1,17 +1,19 @@ """ DoomLoopMiddleware — pattern-based detector for repeated identical tool calls. -Mirrors ``opencode/packages/opencode/src/session/processor.ts`` doom-loop -behavior. When the same tool with the same arguments is called N times -in a row, the agent has likely entered an infinite loop. We surface this -to the user as an interrupt with ``permission="doom_loop"`` so the UI -can render an "Are you stuck? Continue / cancel?" affordance. +LangChain has :class:`ToolCallLimitMiddleware` which caps the *total* number +of tool calls per turn — but it can't tell apart "10 distinct, useful +calls" from "the same call 10 times in a row". This middleware fills that +gap with a sliding-window check on tool-call signatures, ported from +OpenCode's ``packages/opencode/src/session/processor.ts``. -Tier 1.11 in the OpenCode-port plan. +When the same tool with the same arguments is called N times in a row, +the agent has likely entered an infinite loop. We surface this to the +user as an interrupt with ``permission="doom_loop"`` so the UI can +render an "Are you stuck? Continue / cancel?" affordance. This ships **OFF by default** until the frontend explicitly handles -``context.permission == "doom_loop"`` interrupts (the plan flips -``SURFSENSE_ENABLE_DOOM_LOOP=true`` once the UI is ready). +``context.permission == "doom_loop"`` interrupts. Wire format: uses SurfSense's existing ``interrupt()`` payload shape (see ``app/agents/new_chat/tools/hitl.py``): @@ -69,7 +71,7 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon Args: threshold: How many consecutive identical signatures count as a - doom loop. Default 3 (opencode parity). + doom loop. Default 3 (matches OpenCode's processor.ts). """ def __init__(self, *, threshold: int = 3) -> None: @@ -182,7 +184,7 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon signatures[-1] if signatures else "", ) - # Tier 3b: interrupt.raised span with permission=doom_loop attribute + # Open an interrupt.raised span with permission=doom_loop attribute # so dashboards can break out doom-loop interrupts from regular # permission asks via the ``interrupt.permission`` attribute. with ot.interrupt_span( diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py index 08ca8e18b..0820e8c3e 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py @@ -592,10 +592,11 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] self.available_document_types = available_document_types self.top_k = top_k self.mentioned_document_ids = mentioned_document_ids or [] - # Tier 4.2: build the kb-planner private Runnable ONCE here so we - # don't pay the create_agent compile cost (50-200ms) on every turn. - # Disabled by default behind ``enable_kb_planner_runnable``; when off - # the planner falls back to the legacy ``self.llm.ainvoke`` path. + # Build the kb-planner private Runnable ONCE here so we don't pay + # the ``create_agent`` compile cost (50-200ms) on every turn. + # Disabled by default behind ``enable_kb_planner_runnable``; when + # off the planner falls back to the legacy ``self.llm.ainvoke`` + # path. self._planner: Runnable | None = None self._planner_compile_failed = False @@ -608,9 +609,9 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] lazily on first call, then memoized via ``self._planner``. The compiled agent is constructed without tools — the planner's - contract is "answer with structured JSON" — but with ``RetryAfter`` - + the OpenCode-port retry/limit middleware so it shares the parent - agent's resilience guarantees. + contract is "answer with structured JSON" — but it inherits the + :class:`RetryAfterMiddleware` so transient rate-limit errors + from the planner LLM call don't fail the whole turn. """ if self._planner is not None or self._planner_compile_failed: return self._planner @@ -658,9 +659,9 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] loop = asyncio.get_running_loop() t0 = loop.time() - # Tier 4.2: prefer the compiled-once planner Runnable when enabled; - # otherwise fall back to ``self.llm.ainvoke``. The ``surfsense:internal`` - # tag is preserved on both paths so ``_stream_agent_events`` still + # Prefer the compiled-once planner Runnable when enabled; otherwise + # fall back to ``self.llm.ainvoke``. The ``surfsense:internal`` tag + # is preserved on both paths so ``_stream_agent_events`` still # suppresses the planner's intermediate events from the UI. planner = self._build_kb_planner_runnable() try: diff --git a/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py b/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py index 8628479c7..503c73ccc 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py +++ b/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py @@ -1,18 +1,23 @@ """ ``_noop`` provider-compatibility tool + injection middleware. -OpenCode injects a ``_noop`` tool for LiteLLM/Bedrock/Copilot when the -model call has empty tools but message history includes prior -``tool_calls`` — some providers 400 in that shape (see -``opencode/packages/opencode/src/session/llm.ts:209-228``). SurfSense uses -LiteLLM, and the compaction summarize call (no tools, history full of -tool calls) hits this. Tier 1.5 in the OpenCode-port plan. +Some providers (LiteLLM, Bedrock, Copilot) 400 when a model call has +empty ``tools`` but the message history includes prior ``tool_calls`` — +they treat that shape as malformed even though it's perfectly valid +LangChain. SurfSense hits this on the compaction summarize call (no +tools, history full of tool calls). + +Ported from OpenCode's ``packages/opencode/src/session/llm.ts:209-228``, +which discovered and codified the workaround: inject a no-op tool *only* +on those provider shapes so the request validates without ever being +called. Operation: a :class:`NoopInjectionMiddleware` ``wrap_model_call`` checks if the request has zero tools but the last AI message in history includes -``tool_calls``. If yes, it injects the ``_noop`` tool only — never globally, -mirroring opencode's gating exactly. The :func:`noop_tool` returns empty -content when called (which it should never be in practice). +``tool_calls``. If yes, it injects the ``_noop`` tool only — never +globally — mirroring OpenCode's gating exactly. The :func:`noop_tool` +returns empty content when called (which it should never be in +practice). """ from __future__ import annotations @@ -45,8 +50,9 @@ def noop_tool() -> str: # Provider markers that benefit from ``_noop`` injection. These match -# opencode's gating list. We also accept any string containing one of -# these substrings (so e.g. ``litellm`` matches ``ChatLiteLLM``). +# OpenCode's gating list (``llm.ts:209-228``). We also accept any string +# containing one of these substrings so e.g. ``litellm`` matches +# ``ChatLiteLLM``. _NOOP_NEEDED_PROVIDERS: tuple[str, ...] = ( "litellm", "bedrock", diff --git a/surfsense_backend/app/agents/new_chat/middleware/otel_span.py b/surfsense_backend/app/agents/new_chat/middleware/otel_span.py index f51d2f7bb..cfe1edae4 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/otel_span.py +++ b/surfsense_backend/app/agents/new_chat/middleware/otel_span.py @@ -3,14 +3,14 @@ OpenTelemetry span middleware for the SurfSense ``new_chat`` agent. Wraps both ``model.call`` (LLM invocations) and ``tool.call`` (tool executions) with OTel spans, attaching low-cardinality span names and -high-cardinality identifiers as attributes (per the Tier 3b plan). +high-cardinality identifiers as attributes. This middleware is intentionally a thin adapter over :mod:`app.observability.otel`; when OTel is not configured all spans collapse to no-ops and the wrapper adds <1µs overhead per call. When OTel **is** configured (``OTEL_EXPORTER_OTLP_ENDPOINT`` set), every -model and tool call gets a span with the standard attributes the -plan's dashboards expect. +model and tool call gets a span with the standard attributes our +dashboards expect. """ from __future__ import annotations diff --git a/surfsense_backend/app/agents/new_chat/middleware/permission.py b/surfsense_backend/app/agents/new_chat/middleware/permission.py index 6e1f42baf..37719e96a 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/permission.py +++ b/surfsense_backend/app/agents/new_chat/middleware/permission.py @@ -1,10 +1,15 @@ """ PermissionMiddleware — pattern-based allow/deny/ask with HITL fallback. -Mirrors ``opencode/packages/opencode/src/permission/index.ts`` but uses -SurfSense's existing ``interrupt({type, action, context})`` payload shape -(see ``app/agents/new_chat/tools/hitl.py``) so the frontend keeps -working unchanged. Tier 2.1 in the OpenCode-port plan. +LangChain's :class:`HumanInTheLoopMiddleware` only supports a static +"this tool always asks" decision per tool. There's no rule-based +allow/deny/ask layered ruleset, no glob patterns, no per-search-space or +per-thread overrides, and no auto-deny synthesis. + +This middleware ports OpenCode's ``packages/opencode/src/permission/index.ts`` +ruleset model on top of SurfSense's existing ``interrupt({type, action, +context})`` payload shape (see ``app/agents/new_chat/tools/hitl.py``) so +the frontend keeps working unchanged. Operation: 1. ``aafter_model`` inspects the latest ``AIMessage.tool_calls``. @@ -24,9 +29,9 @@ Operation: The middleware also performs a *pre-model* tool-filter step (the ``before_model`` hook) so globally denied tools are stripped from the -exposed tool list before the model gets to see them. This is -opencode's ``Permission.disabled`` equivalent and dramatically reduces -the chance the model emits a deny-only call. +exposed tool list before the model gets to see them. This mirrors +OpenCode's ``Permission.disabled`` and dramatically reduces the chance +the model emits a deny-only call. """ from __future__ import annotations @@ -117,7 +122,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] self._emit_interrupt = always_emit_interrupt_payload # ------------------------------------------------------------------ - # Tool-filter step (opencode `Permission.disabled` equivalent) + # Tool-filter step (mirrors OpenCode's ``Permission.disabled``) # ------------------------------------------------------------------ def _globally_denied(self, tool_name: str) -> bool: @@ -197,8 +202,8 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] "always": patterns, }, } - # Tier 3b: permission.asked + interrupt.raised spans (no-op when - # OTel is disabled). Both fire here so dashboards can correlate + # Open ``permission.asked`` + ``interrupt.raised`` OTel spans + # (no-op when OTel is disabled) so dashboards can correlate # "we asked X" with "interrupt was actually delivered". with ( ot.permission_asked_span( diff --git a/surfsense_backend/app/agents/new_chat/middleware/retry_after.py b/surfsense_backend/app/agents/new_chat/middleware/retry_after.py index 394bb0371..0c3d3d017 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/retry_after.py +++ b/surfsense_backend/app/agents/new_chat/middleware/retry_after.py @@ -1,10 +1,16 @@ """ RetryAfterMiddleware — Header-aware retry with custom backoff and SSE eventing. -Why standalone instead of subclassing ``ModelRetryMiddleware``: the upstream -class calls module-level ``calculate_delay`` inline (no overridable -``_calculate_delay`` hook), so a subclass cannot inject Retry-After header -delays without rewriting the loop. Tier 1.4 in the OpenCode-port plan. +LangChain's :class:`ModelRetryMiddleware` retries on exceptions but ignores +the ``Retry-After`` HTTP header — it just runs its own exponential backoff. +That wastes time when a provider has explicitly told us how long to wait. +This middleware honors the header (mirroring OpenCode's +``packages/opencode/src/session/llm.ts`` retry pathway) and emits an SSE +event so the UI can show "rate-limited, retrying in Ns". + +We can't subclass ``ModelRetryMiddleware`` cleanly because its loop calls a +module-level ``calculate_delay`` inline (no overridable +``_calculate_delay`` hook), so this is a standalone implementation. Behaviour: - Extracts ``Retry-After`` / ``retry-after-ms`` from diff --git a/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py b/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py index 54df0cc60..9f81a168b 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py +++ b/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py @@ -1,10 +1,6 @@ """ ToolCallNameRepairMiddleware — two-stage tool-name repair. -Mirrors ``opencode/packages/opencode/src/session/llm.ts:339-358`` plus -``opencode/packages/opencode/src/tool/invalid.ts``. Tier 1.7 in the -OpenCode-port plan. - Operation: 1. **Stage 1 — lowercase repair:** if a tool call's ``name`` is not in the registry but ``name.lower()`` is, rewrite in place. Catches @@ -14,9 +10,13 @@ Operation: so the registered :func:`invalid_tool` returns the error to the model for self-correction. -Distinct from :class:`deepagents.middleware.PatchToolCallsMiddleware`, -which patches *dangling* tool calls (no matching ToolMessage) — that -class does not handle the wrong-name case at all. +Ported from OpenCode's ``packages/opencode/src/session/llm.ts:339-358`` ++ ``packages/opencode/src/tool/invalid.ts``. LangChain has no equivalent: +:class:`deepagents.middleware.PatchToolCallsMiddleware` patches +*dangling* tool calls (no matching ToolMessage) but does nothing about +wrong names, and the model framework's default behavior on an unknown +name is to crash the turn rather than route to a self-correction +fallback. """ from __future__ import annotations @@ -61,7 +61,8 @@ class ToolCallNameRepairMiddleware( ``invalid`` should be in this set so the fallback dispatches. fuzzy_match_threshold: Optional ``difflib`` ratio (0-1) for the fuzzy-match step that runs *between* lowercase and invalid. - Set to ``None`` to disable fuzzy matching (opencode parity). + Set to ``None`` to disable fuzzy matching (default in + OpenCode; we mirror that to avoid silent rewrites). """ def __init__( @@ -106,7 +107,7 @@ class ToolCallNameRepairMiddleware( call["response_metadata"] = metadata return call - # Optional fuzzy step (off by default for opencode parity) + # Optional fuzzy step (off by default — see class docstring) if self._fuzzy_threshold is not None: close = difflib.get_close_matches( name, registered, n=1, cutoff=self._fuzzy_threshold diff --git a/surfsense_backend/app/agents/new_chat/permissions.py b/surfsense_backend/app/agents/new_chat/permissions.py index 50a0cfbdc..523deb11f 100644 --- a/surfsense_backend/app/agents/new_chat/permissions.py +++ b/surfsense_backend/app/agents/new_chat/permissions.py @@ -1,21 +1,20 @@ """ Wildcard pattern matching + rule evaluation for the SurfSense permission system. -Mirrors ``opencode/packages/opencode/src/permission/evaluate.ts`` and -``opencode/packages/opencode/src/util/wildcard.ts`` precisely: +Ported from OpenCode's ``packages/opencode/src/permission/evaluate.ts`` and +``packages/opencode/src/util/wildcard.ts``. LangChain has no rule-based +permission evaluator, so we keep OpenCode's semantics intact: - ``Wildcard.match`` matches both the ``permission`` and the ``pattern`` fields of a rule against the requested ``(permission, pattern)`` pair. ``*`` matches any segment, ``**`` matches across separators. - The evaluator runs ``findLast`` over the **flattened** list of rules from all rulesets — last matching rule wins. -- The default fallback is ``ask`` (NOT deny), matching opencode. +- The default fallback is ``ask`` (NOT deny), matching OpenCode. - Multi-pattern requests AND together: if ANY pattern resolves to ``deny``, the whole request is denied; if ANY needs ``ask``, an interrupt is raised; only when all patterns ``allow`` does the request proceed. - -Tier 2.1 in the OpenCode-port plan. """ from __future__ import annotations diff --git a/surfsense_backend/app/agents/new_chat/plugin_loader.py b/surfsense_backend/app/agents/new_chat/plugin_loader.py index 426e28041..c52620d40 100644 --- a/surfsense_backend/app/agents/new_chat/plugin_loader.py +++ b/surfsense_backend/app/agents/new_chat/plugin_loader.py @@ -1,9 +1,10 @@ """Entry-point based plugin loader for SurfSense agent middleware. -The realization in the Tier 6 plan: LangChain's :class:`AgentMiddleware` ABC -already covers the practical surface most plugins need (``before_agent`` / -``before_model`` / ``wrap_tool_call`` / their async counterparts), so a -SurfSense-specific plugin protocol is unnecessary. +LangChain's :class:`AgentMiddleware` ABC already covers the practical +surface most plugins need (``before_agent`` / ``before_model`` / +``wrap_tool_call`` / their async counterparts), so a SurfSense-specific +plugin protocol would be redundant. We just need a way to discover and +admit third-party middleware safely. A plugin is therefore just an installable Python package that registers a factory callable under the ``surfsense.plugins`` entry-point group: diff --git a/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py b/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py index 3e2e631d2..2b7781b90 100644 --- a/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py +++ b/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py @@ -1,10 +1,10 @@ """Reference plugin: substitute ``{{year}}`` in tool descriptions. -Mirrors the OpenCode ``chat.system.transform`` example. Demonstrates the -:meth:`AgentMiddleware.awrap_tool_call` hook -- the plugin sees every tool -invocation and can rewrite the request *or* the result. This particular -plugin is read-only and only transforms the *description* the user might -see in error messages (no request mutation). +Demonstrates the :meth:`AgentMiddleware.awrap_tool_call` hook -- the +plugin sees every tool invocation and can rewrite the request *or* the +result. This particular plugin is read-only and only transforms the +*description* the user might see in error messages (no request +mutation). The plugin is built as a factory function so the entry-point loader can inject :class:`PluginContext` (containing the agent's LLM, search-space diff --git a/surfsense_backend/app/agents/new_chat/prompts/composer.py b/surfsense_backend/app/agents/new_chat/prompts/composer.py index 77b86aeef..42f8303e6 100644 --- a/surfsense_backend/app/agents/new_chat/prompts/composer.py +++ b/surfsense_backend/app/agents/new_chat/prompts/composer.py @@ -14,7 +14,13 @@ under :mod:`app.agents.new_chat.prompts`. It replaces the monolithic examples/ # one ``.md`` per tool with call examples routing/ # connector-specific routing notes (linear, slack, …) -Tier 3a in the OpenCode-port plan. +The model-family dispatch step (see :func:`detect_provider_variant`) +mirrors OpenCode's ``packages/opencode/src/session/system.ts`` — different +model families respond best to differently-styled prompts (Claude likes +XML/narrative, GPT-5 wants channel-aware pragmatic, Codex needs +terse/file:line, Gemini wants formal numbered steps, etc.). LangChain's +``dynamic_prompt`` helper supports per-call prompt swaps but ships no +out-of-the-box family classifier, so we keep our own. Backwards compatibility ======================= @@ -42,10 +48,11 @@ from app.db import ChatVisibility # When adding a new variant, also drop a matching ``providers/.md`` # file in this package and (if appropriate) extend the regex matchers below. # -# Stylistic clusters mirror OpenCode's prompt-per-family layout but adapted -# to SurfSense's "supplemental hints" architecture (each fragment is a -# focused style nudge, NOT a full system prompt — the main prompt is -# already assembled from base/ + tools/ + routing/). +# Stylistic clusters: each variant is a focused style nudge, NOT a full +# system prompt — the main prompt is already assembled from base/ + +# tools/ + routing/. The clustering itself (which models map to which +# style) follows OpenCode's ``system.ts`` family table; see the module +# docstring for credits. ProviderVariant = str # Known values: # "anthropic" — Claude family (XML-friendly, narrative todos) @@ -82,8 +89,8 @@ def detect_provider_variant(model_name: str | None) -> ProviderVariant: Order is significant: more-specific patterns are tried first so ``gpt-5-codex`` routes to ``"openai_codex"`` rather than - ``"openai_reasoning"`` (mirrors OpenCode's - ``packages/opencode/src/session/system.ts`` dispatch). + ``"openai_reasoning"`` — same dispatch order as OpenCode's + ``packages/opencode/src/session/system.ts``. """ if not model_name: return "default" diff --git a/surfsense_backend/app/agents/new_chat/subagents/__init__.py b/surfsense_backend/app/agents/new_chat/subagents/__init__.py index b9f21a0d2..7d678ec79 100644 --- a/surfsense_backend/app/agents/new_chat/subagents/__init__.py +++ b/surfsense_backend/app/agents/new_chat/subagents/__init__.py @@ -1,14 +1,17 @@ """Specialized user-facing subagents for the SurfSense agent. -Each subagent is a :class:`deepagents.SubAgent` typed-dict spec passed to -:class:`deepagents.SubAgentMiddleware`, which materializes them as ephemeral -runnables invoked via the ``task`` tool. +The :class:`deepagents.SubAgentMiddleware` already provides the +materialization machinery (each :class:`deepagents.SubAgent` typed-dict +spec is compiled into an ephemeral runnable invoked via the ``task`` +tool); what's specific to SurfSense is the *seeding* of those subagents +with declarative deny rules. Per-subagent permission rules are injected as a :class:`PermissionMiddleware` entry inside the subagent's ``middleware`` -field, mirroring opencode ``tool/task.ts`` which seeds child sessions with -deny rules for tools the parent does not want them touching (e.g. -``task``/``todowrite`` recursion, write tools for read-only research roles). +field. The auto-deny pattern (e.g. forbid ``task``/``todowrite`` +recursion, block write tools for read-only research roles) is borrowed +from OpenCode's ``packages/opencode/src/tool/task.ts``, which has +analogous logic for restricting child sessions. """ from .config import ( diff --git a/surfsense_backend/app/agents/new_chat/system_prompt.py b/surfsense_backend/app/agents/new_chat/system_prompt.py index 3919527d9..56f838d7e 100644 --- a/surfsense_backend/app/agents/new_chat/system_prompt.py +++ b/surfsense_backend/app/agents/new_chat/system_prompt.py @@ -1,13 +1,14 @@ """ Thin compatibility wrapper around :mod:`app.agents.new_chat.prompts.composer`. -Tier 3a of the OpenCode-port plan replaced the monolithic prompt strings -in this module with a fragment tree under ``prompts/`` and a composer -function. This module preserves the public function surface -(``build_surfsense_system_prompt`` / ``build_configurable_system_prompt`` / -``get_default_system_instructions`` / ``SURFSENSE_SYSTEM_PROMPT``) so that -existing call sites — `chat_deepagent.py`, anonymous chat routes, and the -configurable-prompt admin path — keep working without churn. +The composer split the previous monolithic prompt string into a fragment +tree under ``prompts/`` plus a model-family dispatch step (see the +composer module docstring for credits). This module preserves the public +function surface (``build_surfsense_system_prompt`` / +``build_configurable_system_prompt`` / +``get_default_system_instructions`` / ``SURFSENSE_SYSTEM_PROMPT``) so +that existing call sites — `chat_deepagent.py`, anonymous chat routes, +and the configurable-prompt admin path — keep working without churn. For new call sites prefer importing ``compose_system_prompt`` directly from :mod:`app.agents.new_chat.prompts.composer`. diff --git a/surfsense_backend/app/agents/new_chat/tools/invalid_tool.py b/surfsense_backend/app/agents/new_chat/tools/invalid_tool.py index df10fcbe3..ea4bc0bc1 100644 --- a/surfsense_backend/app/agents/new_chat/tools/invalid_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/invalid_tool.py @@ -6,8 +6,9 @@ tool, :class:`ToolCallNameRepairMiddleware` rewrites the call to ``invalid`` with the original name and a parser/validation error string. This tool's execution then returns that error to the model so it can self-correct. -Mirrors ``opencode/packages/opencode/src/tool/invalid.ts``. Tier 1.6 in -the OpenCode-port plan. +Ported from OpenCode's ``packages/opencode/src/tool/invalid.ts`` — +LangChain has no equivalent fallback path; the default behavior on an +unknown tool name is a hard ``ToolNotFoundError`` which kills the turn. Critically, the :class:`ToolDefinition` for this tool is **excluded** from the system-prompt tool list and from ``LLMToolSelectorMiddleware`` selection diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index fce1bf872..e8bab36fd 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -132,12 +132,10 @@ class ToolDefinition: that must be in ``available_connectors`` for the tool to be enabled. dedup_key: Optional callable that maps a tool's ``args`` dict to a string signature used by :class:`DedupHITLToolCallsMiddleware` - to drop duplicate calls. Replaces the legacy hardcoded - ``_NATIVE_HITL_TOOL_DEDUP_KEYS`` map (Tier 2.3 in the - OpenCode-port plan). + to drop duplicate calls within a single LLM response. reverse: Optional callable that, given the tool's ``(args, result)``, returns a ``ReverseDescriptor`` describing the inverse tool - invocation. Consumed by the snapshot/revert pipeline (Tier 5). + invocation. Consumed by the snapshot/revert pipeline. """ diff --git a/surfsense_backend/app/observability/otel.py b/surfsense_backend/app/observability/otel.py index 4f2257ab7..6791ab499 100644 --- a/surfsense_backend/app/observability/otel.py +++ b/surfsense_backend/app/observability/otel.py @@ -1,12 +1,10 @@ """ OpenTelemetry instrumentation helpers for the SurfSense agent stack. -Tier 3b in the OpenCode-port plan. - Goals ===== -- Provide one tiny, ergonomic API for the spans listed in the plan +- Provide one tiny, ergonomic API for the spans we care about (``tool.call``, ``model.call``, ``kb.search``, ``kb.persist``, ``compaction.run``, ``interrupt.raised``, ``permission.asked``). - Keep span **names** low-cardinality (``tool.call`` rather than diff --git a/surfsense_backend/app/routes/agent_revert_route.py b/surfsense_backend/app/routes/agent_revert_route.py index cbe4e7417..12484ff53 100644 --- a/surfsense_backend/app/routes/agent_revert_route.py +++ b/surfsense_backend/app/routes/agent_revert_route.py @@ -1,9 +1,9 @@ """POST ``/api/threads/{thread_id}/revert/{action_id}``: undo an agent action. -Per the Tier 5 plan, the route ships **before** the UI lights up the per-message -"Undo from here" affordance. To prevent accidental usage during the gap we -return ``503 Service Unavailable`` until the -``SURFSENSE_ENABLE_REVERT_ROUTE`` flag flips. Once enabled, the route runs: +The route ships **before** the UI lights up the per-message "Undo from +here" affordance. To prevent accidental usage during the gap we return +``503 Service Unavailable`` until the ``SURFSENSE_ENABLE_REVERT_ROUTE`` +flag flips. Once enabled, the route runs: 1. Authentication via :func:`current_active_user`. 2. Action lookup; 404 if the action does not belong to the thread. diff --git a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py index aa0c215b9..397b1c787 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py +++ b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py @@ -1,4 +1,4 @@ -"""Tests for the prompt fragment composer (Tier 3a).""" +"""Tests for the prompt fragment composer.""" from __future__ import annotations diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py b/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py index e5b171612..55434c04d 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py @@ -1,4 +1,4 @@ -"""Tests for the OtelSpanMiddleware adapter (Tier 3b).""" +"""Tests for the OtelSpanMiddleware adapter.""" from __future__ import annotations diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_permissions.py b/surfsense_backend/tests/unit/agents/new_chat/test_permissions.py index 4924f2aee..8ec16617a 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_permissions.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_permissions.py @@ -1,4 +1,4 @@ -"""Tests for the wildcard matcher and rule evaluator (opencode evaluate.ts parity).""" +"""Tests for the wildcard matcher and rule evaluator (parity with OpenCode evaluate.ts).""" from __future__ import annotations diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py b/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py index c2118c697..5dbf765a7 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py @@ -1,4 +1,4 @@ -"""Unit tests for the SurfSense plugin entry-point loader (Tier 6).""" +"""Unit tests for the SurfSense plugin entry-point loader.""" from __future__ import annotations diff --git a/surfsense_backend/tests/unit/observability/test_otel.py b/surfsense_backend/tests/unit/observability/test_otel.py index 583142098..fc5813973 100644 --- a/surfsense_backend/tests/unit/observability/test_otel.py +++ b/surfsense_backend/tests/unit/observability/test_otel.py @@ -1,4 +1,4 @@ -"""Tests for the SurfSense OpenTelemetry shim (Tier 3b).""" +"""Tests for the SurfSense OpenTelemetry shim.""" from __future__ import annotations diff --git a/surfsense_backend/tests/unit/services/test_revert_service.py b/surfsense_backend/tests/unit/services/test_revert_service.py index e2cbe383a..a81e52041 100644 --- a/surfsense_backend/tests/unit/services/test_revert_service.py +++ b/surfsense_backend/tests/unit/services/test_revert_service.py @@ -1,4 +1,4 @@ -"""Unit tests for the agent revert service (Tier 5.3).""" +"""Unit tests for the agent revert service.""" from __future__ import annotations From 57db198919bbd1e7da8d8364aa90eba01525e7d0 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:14:56 +0530 Subject: [PATCH 226/299] feat(chat): add thread-level auto model pinning fields --- ...34_add_thread_auto_model_pinning_fields.py | 63 +++++++++++++++++++ surfsense_backend/app/db.py | 7 +++ 2 files changed, 70 insertions(+) create mode 100644 surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py diff --git a/surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py new file mode 100644 index 000000000..ab1643b02 --- /dev/null +++ b/surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py @@ -0,0 +1,63 @@ +"""134_add_thread_auto_model_pinning_fields + +Revision ID: 134 +Revises: 133 +Create Date: 2026-04-29 + +Add thread-level fields to persist Auto (Fastest) model pinning metadata: +- pinned_llm_config_id: concrete resolved config id used for this thread +- pinned_auto_mode: auto policy identifier (currently "auto_fastest") +- pinned_at: timestamp when the pin was created/refreshed +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "134" +down_revision: str | None = "133" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + "new_chat_threads", + sa.Column("pinned_llm_config_id", sa.Integer(), nullable=True), + ) + op.add_column( + "new_chat_threads", + sa.Column("pinned_auto_mode", sa.String(length=32), nullable=True), + ) + op.add_column( + "new_chat_threads", + sa.Column("pinned_at", sa.TIMESTAMP(timezone=True), nullable=True), + ) + + op.create_index( + "ix_new_chat_threads_pinned_llm_config_id", + "new_chat_threads", + ["pinned_llm_config_id"], + unique=False, + ) + op.create_index( + "ix_new_chat_threads_pinned_auto_mode", + "new_chat_threads", + ["pinned_auto_mode"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index("ix_new_chat_threads_pinned_auto_mode", table_name="new_chat_threads") + op.drop_index( + "ix_new_chat_threads_pinned_llm_config_id", table_name="new_chat_threads" + ) + + op.drop_column("new_chat_threads", "pinned_at") + op.drop_column("new_chat_threads", "pinned_auto_mode") + op.drop_column("new_chat_threads", "pinned_llm_config_id") diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 75342a8e1..f8b1390d9 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -638,6 +638,13 @@ class NewChatThread(BaseModel, TimestampMixin): default=False, server_default="false", ) + # Auto model pinning metadata: + # - pinned_llm_config_id stores the concrete resolved model config id. + # - pinned_auto_mode indicates which auto policy produced the pin. + # This allows Auto (Fastest) to resolve once per thread and stay stable. + pinned_llm_config_id = Column(Integer, nullable=True, index=True) + pinned_auto_mode = Column(String(32), nullable=True, index=True) + pinned_at = Column(TIMESTAMP(timezone=True), nullable=True) # Relationships search_space = relationship("SearchSpace", back_populates="new_chat_threads") From 41849fe10f5fbe9a4792ad665308f5fea4c37721 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:15:15 +0530 Subject: [PATCH 227/299] feat(chat): add auto model pin resolution service --- .../app/services/auto_model_pin_service.py | 205 ++++++++++++ .../services/test_auto_model_pin_service.py | 291 ++++++++++++++++++ 2 files changed, 496 insertions(+) create mode 100644 surfsense_backend/app/services/auto_model_pin_service.py create mode 100644 surfsense_backend/tests/unit/services/test_auto_model_pin_service.py diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py new file mode 100644 index 000000000..ce417a26d --- /dev/null +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -0,0 +1,205 @@ +"""Resolve and persist Auto (Fastest) model pins per chat thread. + +Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we +resolve that virtual mode to one concrete global LLM config exactly once and +persist the chosen config id on ``new_chat_threads`` so subsequent turns are +stable. +""" + +from __future__ import annotations + +import hashlib +import logging +from dataclasses import dataclass +from datetime import UTC, datetime +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import NewChatThread +from app.services.token_quota_service import TokenQuotaService + +logger = logging.getLogger(__name__) + +AUTO_FASTEST_ID = 0 +AUTO_FASTEST_MODE = "auto_fastest" + + +@dataclass +class AutoPinResolution: + resolved_llm_config_id: int + resolved_tier: str + from_existing_pin: bool + + +def _is_usable_global_config(cfg: dict) -> bool: + return bool( + cfg.get("id") is not None + and cfg.get("model_name") + and cfg.get("provider") + and cfg.get("api_key") + ) + + +def _global_candidates() -> list[dict]: + candidates = [cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg)] + return sorted(candidates, key=lambda c: int(c.get("id", 0))) + + +def _tier_of(cfg: dict) -> str: + return str(cfg.get("billing_tier", "free")).lower() + + +def _deterministic_pick(candidates: list[dict], thread_id: int) -> dict: + digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest() + idx = int.from_bytes(digest[:8], "big") % len(candidates) + return candidates[idx] + + +def _to_uuid(user_id: str | UUID | None) -> UUID | None: + if user_id is None: + return None + if isinstance(user_id, UUID): + return user_id + try: + return UUID(str(user_id)) + except Exception: + return None + + +async def _is_premium_eligible(session: AsyncSession, user_id: str | UUID | None) -> bool: + parsed = _to_uuid(user_id) + if parsed is None: + return False + usage = await TokenQuotaService.premium_get_usage(session, parsed) + return bool(usage.allowed) + + +async def resolve_or_get_pinned_llm_config_id( + session: AsyncSession, + *, + thread_id: int, + search_space_id: int, + user_id: str | UUID | None, + selected_llm_config_id: int, +) -> AutoPinResolution: + """Resolve Auto (Fastest) to one concrete config id and persist pin metadata. + + For non-auto selections, this function clears existing auto pin metadata and + returns the selected id as-is. + """ + thread = ( + ( + await session.execute( + select(NewChatThread) + .where(NewChatThread.id == thread_id) + .with_for_update(of=NewChatThread) + ) + ) + .unique() + .scalar_one_or_none() + ) + if thread is None: + raise ValueError(f"Thread {thread_id} not found") + if thread.search_space_id != search_space_id: + raise ValueError( + f"Thread {thread_id} does not belong to search space {search_space_id}" + ) + + # Explicit model selected: clear stale auto pin metadata. + if selected_llm_config_id != AUTO_FASTEST_ID: + if ( + thread.pinned_llm_config_id is not None + or thread.pinned_auto_mode is not None + or thread.pinned_at is not None + ): + thread.pinned_llm_config_id = None + thread.pinned_auto_mode = None + thread.pinned_at = None + await session.commit() + return AutoPinResolution( + resolved_llm_config_id=selected_llm_config_id, + resolved_tier="explicit", + from_existing_pin=False, + ) + + candidates = _global_candidates() + if not candidates: + raise ValueError("No usable global LLM configs are available for Auto mode") + candidate_by_id = {int(c["id"]): c for c in candidates} + + # Reuse existing valid pin without re-checking current quota (no silent tier switch). + pinned_id = thread.pinned_llm_config_id + if ( + thread.pinned_auto_mode == AUTO_FASTEST_MODE + and pinned_id is not None + and int(pinned_id) in candidate_by_id + ): + pinned_cfg = candidate_by_id[int(pinned_id)] + logger.info( + "auto_pin_reused thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s", + thread_id, + search_space_id, + pinned_id, + _tier_of(pinned_cfg), + ) + return AutoPinResolution( + resolved_llm_config_id=int(pinned_id), + resolved_tier=_tier_of(pinned_cfg), + from_existing_pin=True, + ) + if pinned_id is not None: + logger.info( + "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s pinned_auto_mode=%s", + thread_id, + search_space_id, + pinned_id, + thread.pinned_auto_mode, + ) + + premium_eligible = await _is_premium_eligible(session, user_id) + if premium_eligible: + eligible = candidates + else: + eligible = [c for c in candidates if _tier_of(c) != "premium"] + + if not eligible: + raise ValueError( + "Auto mode could not find an eligible LLM config for this user and quota state" + ) + + selected_cfg = _deterministic_pick(eligible, thread_id) + selected_id = int(selected_cfg["id"]) + selected_tier = _tier_of(selected_cfg) + + thread.pinned_llm_config_id = selected_id + thread.pinned_auto_mode = AUTO_FASTEST_MODE + thread.pinned_at = datetime.now(UTC) + await session.commit() + + if pinned_id is None: + logger.info( + "auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s", + thread_id, + search_space_id, + selected_id, + selected_tier, + premium_eligible, + ) + else: + logger.info( + "auto_pin_repaired thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s tier=%s premium_eligible=%s", + thread_id, + search_space_id, + pinned_id, + selected_id, + selected_tier, + premium_eligible, + ) + return AutoPinResolution( + resolved_llm_config_id=selected_id, + resolved_tier=selected_tier, + from_existing_pin=False, + ) diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py new file mode 100644 index 000000000..a9853c980 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace + +import pytest + +from app.services.auto_model_pin_service import ( + AUTO_FASTEST_MODE, + resolve_or_get_pinned_llm_config_id, +) + +pytestmark = pytest.mark.unit + + +@dataclass +class _FakeQuotaResult: + allowed: bool + + +class _FakeExecResult: + def __init__(self, thread): + self._thread = thread + + def unique(self): + return self + + def scalar_one_or_none(self): + return self._thread + + +class _FakeSession: + def __init__(self, thread): + self.thread = thread + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self.thread) + + async def commit(self): + self.commit_count += 1 + + +def _thread( + *, + search_space_id: int = 10, + pinned_llm_config_id: int | None = None, + pinned_auto_mode: str | None = None, +): + return SimpleNamespace( + id=1, + search_space_id=search_space_id, + pinned_llm_config_id=pinned_llm_config_id, + pinned_auto_mode=pinned_auto_mode, + pinned_at=None, + ) + + +@pytest.mark.asyncio +async def test_auto_first_turn_pins_one_model(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id in {-1, -2} + assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id + assert session.thread.pinned_auto_mode == AUTO_FASTEST_MODE + assert session.thread.pinned_at is not None + assert session.commit_count == 1 + + +@pytest.mark.asyncio +async def test_next_turn_reuses_existing_pin(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _must_not_call(*_args, **_kwargs): + raise AssertionError("premium_get_usage should not be called for valid pin reuse") + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _must_not_call, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True + assert session.commit_count == 0 + + +@pytest.mark.asyncio +async def test_premium_eligible_auto_can_pin_premium(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.resolved_tier == "premium" + + +@pytest.mark.asyncio +async def test_premium_ineligible_auto_pins_free_only(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.resolved_tier == "free" + + +@pytest.mark.asyncio +async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True + + +@pytest.mark.asyncio +async def test_explicit_user_model_change_clears_pin(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-2, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + ], + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=7, + ) + assert result.resolved_llm_config_id == 7 + assert session.thread.pinned_llm_config_id is None + assert session.thread.pinned_auto_mode is None + assert session.thread.pinned_at is None + assert session.commit_count == 1 + + +@pytest.mark.asyncio +async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-999, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert session.thread.pinned_llm_config_id == -2 + assert session.commit_count == 1 From 835bd9f65df2abfd80ecd8def501b2db8595c326 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:15:36 +0530 Subject: [PATCH 228/299] fix(chat): enforce pinned model quota flow and reset stale pins --- .../app/routes/search_spaces_routes.py | 25 +++- .../app/tasks/chat/stream_new_chat.py | 107 +++++++++++------- 2 files changed, 88 insertions(+), 44 deletions(-) diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 828137518..7944e7d66 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -3,7 +3,7 @@ import logging from fastapi import APIRouter, Depends, HTTPException from langchain_core.messages import HumanMessage from pydantic import BaseModel as PydanticBaseModel -from sqlalchemy import func +from sqlalchemy import func, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem from app.config import config from app.db import ( ImageGenerationConfig, + NewChatThread, NewLLMConfig, Permission, SearchSpace, @@ -790,9 +791,31 @@ async def update_llm_preferences( # Update preferences update_data = preferences.model_dump(exclude_unset=True) + previous_agent_llm_id = search_space.agent_llm_id for key, value in update_data.items(): setattr(search_space, key, value) + agent_llm_changed = ( + "agent_llm_id" in update_data + and update_data["agent_llm_id"] != previous_agent_llm_id + ) + if agent_llm_changed: + await session.execute( + update(NewChatThread) + .where(NewChatThread.search_space_id == search_space_id) + .values( + pinned_llm_config_id=None, + pinned_auto_mode=None, + pinned_at=None, + ) + ) + logger.info( + "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)", + search_space_id, + previous_agent_llm_id, + update_data["agent_llm_id"], + ) + await session.commit() await session.refresh(search_space) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index c254e66e2..1a56547ca 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -56,6 +56,7 @@ from app.db import ( shielded_async_session, ) from app.prompts import TITLE_GENERATION_PROMPT +from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id from app.services.chat_session_state_service import ( clear_ai_responding, set_ai_responding, @@ -1456,6 +1457,21 @@ async def stream_new_chat( agent_config: AgentConfig | None = None _t0 = time.perf_counter() + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=llm_config_id, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield streaming_service.format_error(str(pin_error)) + yield streaming_service.format_done() + return + if llm_config_id >= 0: # Positive ID: Load from NewLLMConfig database table agent_config = await load_agent_config( @@ -1491,12 +1507,11 @@ async def stream_new_chat( llm_config_id, ) - # Premium quota reservation — applies to explicitly premium configs - # AND Auto mode (which may route to premium models). + # Premium quota reservation for pinned premium model only. _needs_premium_quota = ( agent_config is not None and user_id - and (agent_config.is_premium or agent_config.is_auto_mode) + and agent_config.is_premium ) if _needs_premium_quota: import uuid as _uuid @@ -1519,16 +1534,18 @@ async def stream_new_chat( ) _premium_reserved = reserve_amount if not quota_result.allowed: - if agent_config.is_premium: - yield streaming_service.format_error( - "Premium token quota exceeded. Please purchase more tokens to continue using premium models." - ) - yield streaming_service.format_done() - return - # Auto mode: quota exhausted but we can still proceed - # (the router may pick a free model). Reset reservation. - _premium_request_id = None - _premium_reserved = 0 + logging.getLogger(__name__).info( + "premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s", + chat_id, + search_space_id, + user_id, + llm_config_id, + ) + yield streaming_service.format_error( + "Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin." + ) + yield streaming_service.format_done() + return if not llm: yield streaming_service.format_error("Failed to create LLM instance") @@ -1961,28 +1978,20 @@ async def stream_new_chat( ) # Finalize premium quota with actual tokens. - # For Auto mode, only count tokens from calls that used premium models. if _premium_request_id and user_id: try: from app.services.token_quota_service import TokenQuotaService - if agent_config and agent_config.is_auto_mode: - from app.services.llm_router_service import LLMRouterService - - actual_premium_tokens = LLMRouterService.compute_premium_tokens( - accumulator.calls - ) - else: - actual_premium_tokens = accumulator.grand_total - async with shielded_async_session() as quota_session: await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=UUID(user_id), request_id=_premium_request_id, - actual_tokens=actual_premium_tokens, + actual_tokens=accumulator.grand_total, reserved_tokens=_premium_reserved, ) + _premium_request_id = None + _premium_reserved = 0 except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s", @@ -2175,6 +2184,21 @@ async def stream_resume_chat( agent_config: AgentConfig | None = None _t0 = time.perf_counter() + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=llm_config_id, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield streaming_service.format_error(str(pin_error)) + yield streaming_service.format_done() + return + if llm_config_id >= 0: agent_config = await load_agent_config( session=session, @@ -2208,7 +2232,7 @@ async def stream_resume_chat( _resume_needs_premium = ( agent_config is not None and user_id - and (agent_config.is_premium or agent_config.is_auto_mode) + and agent_config.is_premium ) if _resume_needs_premium: import uuid as _uuid @@ -2231,14 +2255,18 @@ async def stream_resume_chat( ) _resume_premium_reserved = reserve_amount if not quota_result.allowed: - if agent_config.is_premium: - yield streaming_service.format_error( - "Premium token quota exceeded. Please purchase more tokens to continue using premium models." - ) - yield streaming_service.format_done() - return - _resume_premium_request_id = None - _resume_premium_reserved = 0 + logging.getLogger(__name__).info( + "premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s", + chat_id, + search_space_id, + user_id, + llm_config_id, + ) + yield streaming_service.format_error( + "Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin." + ) + yield streaming_service.format_done() + return if not llm: yield streaming_service.format_error("Failed to create LLM instance") @@ -2370,23 +2398,16 @@ async def stream_resume_chat( try: from app.services.token_quota_service import TokenQuotaService - if agent_config and agent_config.is_auto_mode: - from app.services.llm_router_service import LLMRouterService - - actual_premium_tokens = LLMRouterService.compute_premium_tokens( - accumulator.calls - ) - else: - actual_premium_tokens = accumulator.grand_total - async with shielded_async_session() as quota_session: await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=UUID(user_id), request_id=_resume_premium_request_id, - actual_tokens=actual_premium_tokens, + actual_tokens=accumulator.grand_total, reserved_tokens=_resume_premium_reserved, ) + _resume_premium_request_id = None + _resume_premium_reserved = 0 except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s (resume)", From d5ef0d2598573578d3abf0140c58da6d4e63401d Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:15:46 +0530 Subject: [PATCH 229/299] feat(ui): surface pinned premium quota alerts in chat thread --- .../new-chat/[[...chat_id]]/page.tsx | 81 +++++++++++++++++-- .../atoms/chat/premium-alert.atom.ts | 33 ++++++++ .../components/assistant-ui/thread.tsx | 44 +++++++++- 3 files changed, 148 insertions(+), 10 deletions(-) create mode 100644 surfsense_web/atoms/chat/premium-alert.atom.ts diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 7773a438a..a5461e17f 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -19,6 +19,7 @@ import { currentThreadAtom, setTargetCommentIdAtom, } from "@/atoms/chat/current-thread.atom"; +import { setPremiumAlertForThreadAtom } from "@/atoms/chat/premium-alert.atom"; import { type MentionedDocumentInfo, mentionedDocumentIdsAtom, @@ -200,6 +201,19 @@ const BASE_TOOLS_WITH_UI = new Set([ // "write_todos", // Disabled for now ]); +const PINNED_PREMIUM_QUOTA_MESSAGE = "Premium token quota exceeded for this pinned model."; + +function getPinnedPremiumQuotaErrorMessage(error: unknown): string | null { + if (!(error instanceof Error)) return null; + if (!error.message.toLowerCase().includes("premium token quota exceeded")) { + return null; + } + if (!error.message.toLowerCase().includes("pinned model")) { + return null; + } + return error.message || PINNED_PREMIUM_QUOTA_MESSAGE; +} + export default function NewChatPage() { const params = useParams(); const queryClient = useQueryClient(); @@ -226,6 +240,7 @@ export default function NewChatPage() { const setMentionedDocuments = useSetAtom(mentionedDocumentsAtom); const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom); const setCurrentThreadState = useSetAtom(currentThreadAtom); + const setPremiumAlertForThread = useSetAtom(setPremiumAlertForThreadAtom); const setTargetCommentId = useSetAtom(setTargetCommentIdAtom); const clearTargetCommentId = useSetAtom(clearTargetCommentIdAtom); const closeReportPanel = useSetAtom(closeReportPanelAtom); @@ -951,6 +966,7 @@ export default function NewChatPage() { return; } console.error("[NewChatPage] Chat error:", error); + const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); // Track chat error trackChatError( @@ -959,7 +975,15 @@ export default function NewChatPage() { error instanceof Error ? error.message : "Unknown error" ); - toast.error("Failed to get response. Please try again."); + if (premiumQuotaAlertMessage) { + setPremiumAlertForThread({ + threadId: currentThreadId, + message: premiumQuotaAlertMessage, + }); + toast.error(PINNED_PREMIUM_QUOTA_MESSAGE); + } else { + toast.error("Failed to get response. Please try again."); + } // Update assistant message with error setMessages((prev) => prev.map((m) => @@ -969,7 +993,9 @@ export default function NewChatPage() { content: [ { type: "text", - text: "Sorry, there was an error. Please try again.", + text: + premiumQuotaAlertMessage ?? + "Sorry, there was an error. Please try again.", }, ], } @@ -998,6 +1024,7 @@ export default function NewChatPage() { pendingUserImageUrls, setPendingUserImageUrls, toolsWithUI, + setPremiumAlertForThread, ] ); @@ -1257,13 +1284,29 @@ export default function NewChatPage() { return; } console.error("[NewChatPage] Resume error:", error); - toast.error("Failed to resume. Please try again."); + const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); + if (premiumQuotaAlertMessage) { + setPremiumAlertForThread({ + threadId: resumeThreadId, + message: premiumQuotaAlertMessage, + }); + toast.error(PINNED_PREMIUM_QUOTA_MESSAGE); + } else { + toast.error("Failed to resume. Please try again."); + } } finally { setIsRunning(false); abortControllerRef.current = null; } }, - [pendingInterrupt, messages, searchSpaceId, tokenUsageStore, toolsWithUI] + [ + pendingInterrupt, + messages, + searchSpaceId, + tokenUsageStore, + toolsWithUI, + setPremiumAlertForThread, + ] ); useEffect(() => { @@ -1584,18 +1627,34 @@ export default function NewChatPage() { } batcher.dispose(); console.error("[NewChatPage] Regeneration error:", error); + const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); trackChatError( searchSpaceId, threadId, error instanceof Error ? error.message : "Unknown error" ); - toast.error("Failed to regenerate response. Please try again."); + if (premiumQuotaAlertMessage) { + setPremiumAlertForThread({ + threadId, + message: premiumQuotaAlertMessage, + }); + toast.error(PINNED_PREMIUM_QUOTA_MESSAGE); + } else { + toast.error("Failed to regenerate response. Please try again."); + } setMessages((prev) => prev.map((m) => m.id === assistantMsgId ? { ...m, - content: [{ type: "text", text: "Sorry, there was an error. Please try again." }], + content: [ + { + type: "text", + text: + premiumQuotaAlertMessage ?? + "Sorry, there was an error. Please try again.", + }, + ], } : m ) @@ -1605,7 +1664,15 @@ export default function NewChatPage() { abortControllerRef.current = null; } }, - [threadId, searchSpaceId, messages, disabledTools, tokenUsageStore, toolsWithUI] + [ + threadId, + searchSpaceId, + messages, + disabledTools, + tokenUsageStore, + toolsWithUI, + setPremiumAlertForThread, + ] ); // Handle editing a message - truncates history and regenerates with new query diff --git a/surfsense_web/atoms/chat/premium-alert.atom.ts b/surfsense_web/atoms/chat/premium-alert.atom.ts new file mode 100644 index 000000000..c0efc174f --- /dev/null +++ b/surfsense_web/atoms/chat/premium-alert.atom.ts @@ -0,0 +1,33 @@ +import { atom } from "jotai"; + +export type PremiumAlertState = { + message: string; +}; + +export const premiumAlertByThreadAtom = atom>({}); + +export const setPremiumAlertForThreadAtom = atom( + null, + ( + get, + set, + payload: { + threadId: number; + message: string; + } + ) => { + const current = get(premiumAlertByThreadAtom); + set(premiumAlertByThreadAtom, { + ...current, + [payload.threadId]: { message: payload.message }, + }); + } +); + +export const clearPremiumAlertForThreadAtom = atom(null, (get, set, threadId: number) => { + const current = get(premiumAlertByThreadAtom); + if (!(threadId in current)) return; + const next = { ...current }; + delete next[threadId]; + set(premiumAlertByThreadAtom, next); +}); diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index cf99598f1..06f25f5fb 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -37,10 +37,13 @@ import { toggleToolAtom, } from "@/atoms/agent-tools/agent-tools.atoms"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; -import { - mentionedDocumentsAtom, -} from "@/atoms/chat/mentioned-documents.atom"; +import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; +import { mentionedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom"; import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; +import { + clearPremiumAlertForThreadAtom, + premiumAlertByThreadAtom, +} from "@/atoms/chat/premium-alert.atom"; import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms"; import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; import { membersAtom } from "@/atoms/members/members-query.atoms"; @@ -134,6 +137,9 @@ const ThreadContent: FC = () => { style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }} > + !thread.isEmpty}> + + !thread.isEmpty}> @@ -143,6 +149,38 @@ const ThreadContent: FC = () => { ); }; +const PremiumQuotaPinnedAlert: FC = () => { + const currentThreadState = useAtomValue(currentThreadAtom); + const alertsByThread = useAtomValue(premiumAlertByThreadAtom); + const clearPremiumAlertForThread = useSetAtom(clearPremiumAlertForThreadAtom); + + const currentThreadId = currentThreadState?.id; + if (!currentThreadId) return null; + + const alert = alertsByThread[currentThreadId]; + if (!alert) return null; + + return ( +
+
+ +
+

Premium quota exhausted

+

{alert.message}

+
+ +
+
+ ); +}; + const ThreadScrollToBottom: FC = () => { return ( From c110f5b9551d8593a2e9a207282153c30003da86 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Wed, 29 Apr 2026 07:20:31 -0700 Subject: [PATCH 230/299] feat: improved agent streaming --- surfsense_backend/.env.example | 8 + .../versions/134_relax_revision_fks.py | 139 +++ .../135_action_log_correlation_ids.py | 82 ++ .../versions/136_new_chat_message_turn_id.py | 52 ++ .../137_unique_reverse_of_in_action_log.py | 74 ++ .../app/agents/new_chat/chat_deepagent.py | 72 +- .../app/agents/new_chat/feature_flags.py | 14 + .../app/agents/new_chat/filesystem_state.py | 69 +- .../agents/new_chat/middleware/action_log.py | 77 +- .../agents/new_chat/middleware/filesystem.py | 421 ++++++++- .../new_chat/middleware/kb_persistence.py | 870 +++++++++++++++++- .../middleware/kb_postgres_backend.py | 109 ++- .../new_chat/middleware/knowledge_tree.py | 40 +- .../middleware/local_folder_backend.py | 68 ++ .../multi_root_local_folder_backend.py | 28 + .../app/agents/new_chat/state_reducers.py | 4 + .../app/agents/new_chat/subagents/config.py | 2 + .../app/agents/new_chat/tools/hitl.py | 42 + surfsense_backend/app/db.py | 36 +- .../app/routes/agent_action_log_route.py | 9 + .../app/routes/agent_revert_route.py | 386 +++++++- .../app/routes/new_chat_routes.py | 486 +++++++++- surfsense_backend/app/schemas/new_chat.py | 44 + .../app/services/new_streaming_service.py | 86 +- .../app/services/revert_service.py | 440 ++++++++- .../app/tasks/chat/stream_new_chat.py | 324 ++++++- .../unit/agents/new_chat/test_action_log.py | 110 +++ .../new_chat/test_desktop_safety_rules.py | 122 +++ .../agents/new_chat/test_hitl_auto_approve.py | 111 +++ .../agents/new_chat/test_rm_rmdir_cloud.py | 333 +++++++ .../agents/new_chat/test_state_reducers.py | 44 + surfsense_backend/tests/unit/db/__init__.py | 0 .../db/test_relax_revision_fks_migration.py | 83 ++ .../middleware/test_filesystem_middleware.py | 16 + .../test_kb_persistence_revisions.py | 309 +++++++ .../unit/middleware/test_knowledge_tree.py | 139 +++ .../middleware/test_local_folder_backend.py | 71 ++ .../tests/unit/routes/__init__.py | 0 .../routes/test_regenerate_from_message_id.py | 143 +++ .../unit/routes/test_revert_turn_route.py | 530 +++++++++++ .../services/test_revert_filesystem_tools.py | 370 ++++++++ .../tests/unit/tasks/__init__.py | 0 .../tests/unit/tasks/chat/__init__.py | 0 .../tasks/chat/test_extract_chunk_parts.py | 185 ++++ .../new-chat/[[...chat_id]]/page.tsx | 580 ++++++++++-- .../atoms/chat/agent-actions.atom.ts | 194 ++++ .../assistant-ui/assistant-message.tsx | 13 + .../assistant-ui/edit-message-dialog.tsx | 106 +++ .../assistant-ui/reasoning-message-part.tsx | 81 ++ .../assistant-ui/revert-turn-button.tsx | 232 +++++ .../assistant-ui/step-separator.tsx | 27 + .../components/assistant-ui/tool-fallback.tsx | 118 ++- .../components/free-chat/free-chat-page.tsx | 52 +- .../public-chat/public-chat-view.tsx | 2 + .../components/public-chat/public-thread.tsx | 2 + surfsense_web/contracts/enums/toolIcons.tsx | 85 ++ .../lib/apis/agent-actions-api.service.ts | 56 ++ surfsense_web/lib/chat/message-utils.ts | 6 +- surfsense_web/lib/chat/streaming-state.ts | 252 ++++- surfsense_web/lib/chat/thread-persistence.ts | 17 +- 60 files changed, 8068 insertions(+), 303 deletions(-) create mode 100644 surfsense_backend/alembic/versions/134_relax_revision_fks.py create mode 100644 surfsense_backend/alembic/versions/135_action_log_correlation_ids.py create mode 100644 surfsense_backend/alembic/versions/136_new_chat_message_turn_id.py create mode 100644 surfsense_backend/alembic/versions/137_unique_reverse_of_in_action_log.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_rm_rmdir_cloud.py create mode 100644 surfsense_backend/tests/unit/db/__init__.py create mode 100644 surfsense_backend/tests/unit/db/test_relax_revision_fks_migration.py create mode 100644 surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py create mode 100644 surfsense_backend/tests/unit/middleware/test_knowledge_tree.py create mode 100644 surfsense_backend/tests/unit/routes/__init__.py create mode 100644 surfsense_backend/tests/unit/routes/test_regenerate_from_message_id.py create mode 100644 surfsense_backend/tests/unit/routes/test_revert_turn_route.py create mode 100644 surfsense_backend/tests/unit/services/test_revert_filesystem_tools.py create mode 100644 surfsense_backend/tests/unit/tasks/__init__.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/__init__.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py create mode 100644 surfsense_web/atoms/chat/agent-actions.atom.ts create mode 100644 surfsense_web/components/assistant-ui/edit-message-dialog.tsx create mode 100644 surfsense_web/components/assistant-ui/reasoning-message-part.tsx create mode 100644 surfsense_web/components/assistant-ui/revert-turn-button.tsx create mode 100644 surfsense_web/components/assistant-ui/step-separator.tsx diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index c1bfcc538..a793f33d1 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -282,6 +282,14 @@ LANGSMITH_PROJECT=surfsense # SURFSENSE_ENABLE_ACTION_LOG=false # SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships +# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk +# content (typed reasoning blocks, tool-input deltas) and propagate the +# real tool_call_id to the SSE layer. When OFF, the stream falls back to +# the str-only text path and synthetic "call_" tool-call ids. +# Schema migrations 135/136 ship unconditionally because they are +# forward-compatible. +# SURFSENSE_ENABLE_STREAM_PARITY_V2=false + # Plugins # SURFSENSE_ENABLE_PLUGIN_LOADER=false # Comma-separated allowlist of plugin entry-point names diff --git a/surfsense_backend/alembic/versions/134_relax_revision_fks.py b/surfsense_backend/alembic/versions/134_relax_revision_fks.py new file mode 100644 index 000000000..99b665426 --- /dev/null +++ b/surfsense_backend/alembic/versions/134_relax_revision_fks.py @@ -0,0 +1,139 @@ +"""134_relax_revision_fks + +Revision ID: 134 +Revises: 133 +Create Date: 2026-04-29 + +Relax the parent FKs on ``document_revisions`` and ``folder_revisions`` so +revisions survive the deletes they describe. + +Why: the snapshot/revert pipeline writes a ``DocumentRevision`` BEFORE +hard-deleting a document via the ``rm`` tool (and likewise a +``FolderRevision`` before ``rmdir``). If the FK is ``ON DELETE CASCADE`` +the snapshot row is wiped at the exact moment we need it most — revert +then has nothing to read and the operation becomes irreversible. + +Migration: + +* ``document_revisions.document_id``: ``NOT NULL`` -> nullable; FK + ``ON DELETE CASCADE`` -> ``ON DELETE SET NULL``. +* ``folder_revisions.folder_id``: same treatment. + +The ``search_space_id`` FK on both tables is left unchanged (still +``ON DELETE CASCADE``). When a search space is deleted, all documents, +folders, AND their revisions go together — that's the correct teardown +story. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy import inspect + +from alembic import op + +revision: str = "134" +down_revision: str | None = "133" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def _fk_name(bind, table: str, column: str) -> str | None: + """Return the (single) FK constraint name on ``table.column``, if any.""" + inspector = inspect(bind) + for fk in inspector.get_foreign_keys(table): + cols = fk.get("constrained_columns") or [] + if cols == [column]: + return fk.get("name") + return None + + +def upgrade() -> None: + bind = op.get_bind() + + # --- document_revisions.document_id -> nullable + SET NULL --------------- + fk_name = _fk_name(bind, "document_revisions", "document_id") + if fk_name: + op.drop_constraint(fk_name, "document_revisions", type_="foreignkey") + op.alter_column( + "document_revisions", + "document_id", + existing_type=sa.Integer(), + nullable=True, + ) + op.create_foreign_key( + "document_revisions_document_id_fkey", + "document_revisions", + "documents", + ["document_id"], + ["id"], + ondelete="SET NULL", + ) + + # --- folder_revisions.folder_id -> nullable + SET NULL ------------------- + fk_name = _fk_name(bind, "folder_revisions", "folder_id") + if fk_name: + op.drop_constraint(fk_name, "folder_revisions", type_="foreignkey") + op.alter_column( + "folder_revisions", + "folder_id", + existing_type=sa.Integer(), + nullable=True, + ) + op.create_foreign_key( + "folder_revisions_folder_id_fkey", + "folder_revisions", + "folders", + ["folder_id"], + ["id"], + ondelete="SET NULL", + ) + + +def downgrade() -> None: + bind = op.get_bind() + + # Reinstating NOT NULL + CASCADE requires draining orphan rows first + # (any revision whose parent doc/folder has already been deleted). + op.execute("DELETE FROM document_revisions WHERE document_id IS NULL") + op.execute("DELETE FROM folder_revisions WHERE folder_id IS NULL") + + # --- document_revisions.document_id -> NOT NULL + CASCADE --------------- + fk_name = _fk_name(bind, "document_revisions", "document_id") + if fk_name: + op.drop_constraint(fk_name, "document_revisions", type_="foreignkey") + op.alter_column( + "document_revisions", + "document_id", + existing_type=sa.Integer(), + nullable=False, + ) + op.create_foreign_key( + "document_revisions_document_id_fkey", + "document_revisions", + "documents", + ["document_id"], + ["id"], + ondelete="CASCADE", + ) + + # --- folder_revisions.folder_id -> NOT NULL + CASCADE ------------------- + fk_name = _fk_name(bind, "folder_revisions", "folder_id") + if fk_name: + op.drop_constraint(fk_name, "folder_revisions", type_="foreignkey") + op.alter_column( + "folder_revisions", + "folder_id", + existing_type=sa.Integer(), + nullable=False, + ) + op.create_foreign_key( + "folder_revisions_folder_id_fkey", + "folder_revisions", + "folders", + ["folder_id"], + ["id"], + ondelete="CASCADE", + ) diff --git a/surfsense_backend/alembic/versions/135_action_log_correlation_ids.py b/surfsense_backend/alembic/versions/135_action_log_correlation_ids.py new file mode 100644 index 000000000..9ae368b81 --- /dev/null +++ b/surfsense_backend/alembic/versions/135_action_log_correlation_ids.py @@ -0,0 +1,82 @@ +"""135_action_log_correlation_ids + +Revision ID: 135 +Revises: 134 +Create Date: 2026-04-29 + +Action-log correlation-id cleanup. + +Background +---------- +``agent_action_log.turn_id`` is misnamed. ``ActionLogMiddleware`` writes +the LangChain ``tool_call.id`` into that column today (see +``action_log.py:_resolve_turn_id``), and ``kb_persistence._find_action_ids_batch`` +joins on it as such. The real chat-turn id (``f"{chat_id}:{ms}"`` from +``stream_new_chat.py``) lives in ``config.configurable.turn_id`` and was +never persisted. + +This migration introduces two new, correctly-named columns: + +* ``tool_call_id`` (LangChain tool-call id, what ``turn_id`` actually held) +* ``chat_turn_id`` (the per-turn correlation id from + ``configurable.turn_id`` — used by the per-turn ``revert-turn`` route). + +Backfill copies the current ``turn_id`` values into ``tool_call_id`` so +existing joins keep working. The old ``turn_id`` column is left in place +for one release as a deprecated alias to give safe rollback. ``ActionLogMiddleware`` +keeps writing it (= ``tool_call_id``) for the same reason. + +Indexes +------- + +* ``ix_agent_action_log_tool_call_id`` — required by + ``_find_action_ids_batch`` (was on ``turn_id``). +* ``ix_agent_action_log_chat_turn_id`` — required by the + ``revert-turn/{chat_turn_id}`` query. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "135" +down_revision: str | None = "134" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + "agent_action_log", + sa.Column("tool_call_id", sa.String(length=64), nullable=True), + ) + op.add_column( + "agent_action_log", + sa.Column("chat_turn_id", sa.String(length=64), nullable=True), + ) + + op.create_index( + "ix_agent_action_log_tool_call_id", + "agent_action_log", + ["tool_call_id"], + ) + op.create_index( + "ix_agent_action_log_chat_turn_id", + "agent_action_log", + ["chat_turn_id"], + ) + + op.execute( + "UPDATE agent_action_log SET tool_call_id = turn_id WHERE tool_call_id IS NULL" + ) + + +def downgrade() -> None: + op.drop_index("ix_agent_action_log_chat_turn_id", table_name="agent_action_log") + op.drop_index("ix_agent_action_log_tool_call_id", table_name="agent_action_log") + op.drop_column("agent_action_log", "chat_turn_id") + op.drop_column("agent_action_log", "tool_call_id") diff --git a/surfsense_backend/alembic/versions/136_new_chat_message_turn_id.py b/surfsense_backend/alembic/versions/136_new_chat_message_turn_id.py new file mode 100644 index 000000000..8d4350424 --- /dev/null +++ b/surfsense_backend/alembic/versions/136_new_chat_message_turn_id.py @@ -0,0 +1,52 @@ +"""136_new_chat_message_turn_id + +Revision ID: 136 +Revises: 135 +Create Date: 2026-04-29 + +Persist the per-turn correlation id on each chat message. + +Background +---------- +LangGraph's checkpointer stores user-provided ``configurable.turn_id`` +in checkpoint metadata (see +``langgraph/checkpoint/base/__init__.py:get_checkpoint_metadata``). To +support edit-from-arbitrary-position, the regenerate route needs to map +a ``message_id`` -> ``turn_id`` -> checkpoint at request time. Without +this column the mapping doesn't exist anywhere, so regenerate would +have to hardcode the "last 2 messages" rewind heuristic. + +This migration adds a nullable ``turn_id`` column to ``new_chat_messages`` +plus an index. Legacy rows have NULL — the regenerate route degrades +gracefully to the reload-last-two heuristic for those. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "136" +down_revision: str | None = "135" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + "new_chat_messages", + sa.Column("turn_id", sa.String(length=64), nullable=True), + ) + op.create_index( + "ix_new_chat_messages_turn_id", + "new_chat_messages", + ["turn_id"], + ) + + +def downgrade() -> None: + op.drop_index("ix_new_chat_messages_turn_id", table_name="new_chat_messages") + op.drop_column("new_chat_messages", "turn_id") diff --git a/surfsense_backend/alembic/versions/137_unique_reverse_of_in_action_log.py b/surfsense_backend/alembic/versions/137_unique_reverse_of_in_action_log.py new file mode 100644 index 000000000..d606a00f9 --- /dev/null +++ b/surfsense_backend/alembic/versions/137_unique_reverse_of_in_action_log.py @@ -0,0 +1,74 @@ +"""137_unique_reverse_of_in_action_log + +Revision ID: 137 +Revises: 136 +Create Date: 2026-04-29 + +Protect ``agent_action_log.reverse_of`` against double inserts. Two +concurrent revert calls (single-action route + the per-turn batch +route, or two batch routes racing) both pass the +``_was_already_reverted`` SELECT and both insert their own +``_revert:*`` rows. The application-level idempotency check is racy +because there's no DB constraint backing it. + +This migration adds a partial unique index on ``reverse_of`` (PostgreSQL +``WHERE reverse_of IS NOT NULL``) so the second concurrent insert raises +``IntegrityError`` and the route can translate it to ``"already_reverted"`` +deterministically. + +The plain ``UniqueConstraint`` flavour can't be used because most +existing rows have ``reverse_of = NULL`` (only revert rows fill it), +and Postgres does treat NULL as distinct in unique indexes — but a +partial index is the cleanest expression of intent and works even on +older Postgres releases that distinguish NULL handling. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op + +revision: str = "137" +down_revision: str | None = "136" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +_INDEX_NAME = "ux_agent_action_log_reverse_of" + + +def upgrade() -> None: + # Defensively de-dup any pre-existing double-revert rows before + # adding the unique index. Keeps the OLDEST row (smallest id) and + # NULLs out the duplicates' ``reverse_of`` so they survive as audit + # trail but no longer claim to be the canonical revert. We do NOT + # delete them — operators can still inspect them via /actions. + op.execute( + """ + WITH dups AS ( + SELECT id, + reverse_of, + ROW_NUMBER() OVER ( + PARTITION BY reverse_of ORDER BY id ASC + ) AS rn + FROM agent_action_log + WHERE reverse_of IS NOT NULL + ) + UPDATE agent_action_log + SET reverse_of = NULL + WHERE id IN (SELECT id FROM dups WHERE rn > 1) + """ + ) + + op.create_index( + _INDEX_NAME, + "agent_action_log", + ["reverse_of"], + unique=True, + postgresql_where="reverse_of IS NOT NULL", + ) + + +def downgrade() -> None: + op.drop_index(_INDEX_NAME, table_name="agent_action_log") diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index bfb94ba2d..fdd72ea92 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -724,7 +724,8 @@ def _build_compiled_agent_blocking( repair_mw = None if flags.enable_tool_call_repair and not flags.disable_new_agent_stack: registered_names: set[str] = {t.name for t in tools} - # Tools owned by the standard deepagents middleware stack. + # Tools owned by the standard deepagents middleware stack and the + # SurfSense filesystem extension. registered_names |= { "write_todos", "ls", @@ -735,6 +736,14 @@ def _build_compiled_agent_blocking( "grep", "execute", "task", + "mkdir", + "cd", + "pwd", + "move_file", + "rm", + "rmdir", + "list_tree", + "execute_code", } repair_mw = ToolCallNameRepairMiddleware( registered_tool_names=registered_names, @@ -763,25 +772,51 @@ def _build_compiled_agent_blocking( # on every safe read-only call (``ls``, ``read_file``, ``grep``, # ``glob``, ``web_search`` …) and, on resume, replay the previous # reject decision into innocent calls. - # 2. ``connector_synthesized`` — deny rules for tools whose required - # connector is not connected to this space. Overrides #1. - # 3. (future) user-defined rules from ``agent_permission_rules`` table - # via the Agent Permissions UI. Loaded last so they override both. + # 2. ``desktop_safety`` — ``ask`` for destructive filesystem ops when + # the agent is operating against the user's real disk. Cloud mode + # has full revision-based revert via ``revert_service``, but + # desktop mode hits disk immediately with no undo, so an + # accidental ``rm`` / ``rmdir`` / ``move_file`` / ``edit_file`` / + # ``write_file`` is unrecoverable. This layer is forced on in + # desktop mode regardless of ``enable_permission`` because the + # safety net is non-negotiable. + # 3. ``connector_synthesized`` — deny rules for tools whose required + # connector is not connected to this space. Overrides #1/#2. + # 4. (future) user-defined rules from ``agent_permission_rules`` table + # via the Agent Permissions UI. Loaded last so they override all. permission_mw: PermissionMiddleware | None = None - if flags.enable_permission and not flags.disable_new_agent_stack: - synthesized = _synthesize_connector_deny_rules( - available_connectors=available_connectors, - enabled_tool_names={t.name for t in tools}, - ) - permission_mw = PermissionMiddleware( - rulesets=[ + is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER + permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack + # Build the middleware whenever it has work to do: either the user + # opted into the rule engine, OR we're in desktop mode and need the + # safety rules unconditionally. + if permission_enabled or is_desktop_fs: + rulesets: list[Ruleset] = [ + Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", + ), + ] + if is_desktop_fs: + rulesets.append( Ruleset( - rules=[Rule(permission="*", pattern="*", action="allow")], - origin="surfsense_defaults", - ), - Ruleset(rules=synthesized, origin="connector_synthesized"), - ], - ) + rules=[ + Rule(permission="rm", pattern="*", action="ask"), + Rule(permission="rmdir", pattern="*", action="ask"), + Rule(permission="move_file", pattern="*", action="ask"), + Rule(permission="edit_file", pattern="*", action="ask"), + Rule(permission="write_file", pattern="*", action="ask"), + ], + origin="desktop_safety", + ) + ) + if permission_enabled: + synthesized = _synthesize_connector_deny_rules( + available_connectors=available_connectors, + enabled_tool_names={t.name for t in tools}, + ) + rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized")) + permission_mw = PermissionMiddleware(rulesets=rulesets) # ActionLogMiddleware. Off by default until the ``agent_action_log`` # table is migrated. When enabled, persists one row per tool call @@ -938,6 +973,7 @@ def _build_compiled_agent_blocking( search_space_id=search_space_id, created_by_id=user_id, filesystem_mode=filesystem_mode, + thread_id=thread_id, ) if filesystem_mode == FilesystemMode.CLOUD else None, diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py index 55525abc5..f58bf0dd7 100644 --- a/surfsense_backend/app/agents/new_chat/feature_flags.py +++ b/surfsense_backend/app/agents/new_chat/feature_flags.py @@ -23,6 +23,7 @@ Local development (recommended for trying everything except doom-loop / selector SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false + SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events Master kill-switch (overrides everything else): @@ -86,6 +87,15 @@ class AgentFeatureFlags: False # Backend ships before UI; route returns 503 until this flips ) + # Streaming parity v2 — opt in to LangChain's structured + # ``AIMessageChunk`` content (typed reasoning blocks, tool-input + # deltas) and propagate the real ``tool_call_id`` to the SSE layer. + # When OFF the ``stream_new_chat`` task falls back to the str-only + # text path and the synthetic ``call_`` tool-call id (no + # ``langchainToolCallId`` propagation). Schema migrations 135/136 + # ship unconditionally because they're forward-compatible. + enable_stream_parity_v2: bool = False + # Plugins enable_plugin_loader: bool = False @@ -139,6 +149,10 @@ class AgentFeatureFlags: # Snapshot / revert enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False), enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False), + # Streaming parity v2 + enable_stream_parity_v2=_env_bool( + "SURFSENSE_ENABLE_STREAM_PARITY_V2", False + ), # Plugins enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False), # Observability diff --git a/surfsense_backend/app/agents/new_chat/filesystem_state.py b/surfsense_backend/app/agents/new_chat/filesystem_state.py index 18952ed6f..f54ada76e 100644 --- a/surfsense_backend/app/agents/new_chat/filesystem_state.py +++ b/surfsense_backend/app/agents/new_chat/filesystem_state.py @@ -5,9 +5,14 @@ extra fields needed to implement Postgres-backed virtual filesystem semantics: * ``cwd`` — current working directory (per-thread checkpointed). * ``staged_dirs`` — pending mkdir requests (cloud only). +* ``staged_dir_tool_calls`` — sidecar map ``path -> tool_call_id`` for staged dirs. * ``pending_moves`` — pending move_file requests (cloud only). +* ``pending_deletes`` — pending ``rm`` requests (cloud only). +* ``pending_dir_deletes`` — pending ``rmdir`` requests (cloud only). * ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads. * ``dirty_paths`` — paths whose state file content differs from DB. +* ``dirty_path_tool_calls`` — sidecar map ``path -> latest tool_call_id`` for + dirty paths; used to bind the per-path snapshot to an action_id. * ``kb_priority`` — top-K priority hints rendered into a system message. * ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting. * ``kb_anon_doc`` — Redis-loaded anonymous document (if any). @@ -32,12 +37,31 @@ from app.agents.new_chat.state_reducers import ( ) -class PendingMove(TypedDict): - """A staged move_file operation pending end-of-turn commit.""" +class PendingMove(TypedDict, total=False): + """A staged move_file operation pending end-of-turn commit. + + ``tool_call_id`` is optional for backward compatibility with checkpoints + written before the snapshot/revert pipeline was wired up; new entries + always include it so the persistence body can resolve an action_id. + """ source: str dest: str overwrite: bool + tool_call_id: str + + +class PendingDelete(TypedDict, total=False): + """A staged ``rm`` or ``rmdir`` operation pending end-of-turn commit. + + ``tool_call_id`` is required for new entries (it's the binding key used + by :class:`KnowledgeBasePersistenceMiddleware` to find the matching + :class:`AgentActionLog` row and bind the snapshot to it). Marked + ``total=False`` only to tolerate older checkpoint payloads. + """ + + path: str + tool_call_id: str class KbPriorityEntry(TypedDict, total=False): @@ -76,9 +100,38 @@ class SurfSenseFilesystemState(FilesystemState): staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]] """mkdir paths staged for end-of-turn folder creation (cloud only).""" + staged_dir_tool_calls: NotRequired[ + Annotated[dict[str, str], _dict_merge_with_tombstones_reducer] + ] + """``path -> tool_call_id`` sidecar for ``staged_dirs``. + + Used by :class:`KnowledgeBasePersistenceMiddleware` to bind the + :class:`FolderRevision` snapshot to the originating ``mkdir`` action. + Kept separate from ``staged_dirs`` (which stays a unique-string list) + to avoid breaking ``_add_unique_reducer`` semantics. + """ + pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]] """move_file ops staged for end-of-turn commit (cloud only).""" + pending_deletes: NotRequired[Annotated[list[PendingDelete], _list_append_reducer]] + """``rm`` ops staged for end-of-turn ``DELETE FROM documents`` (cloud only). + + Each entry is a dict ``{"path": ..., "tool_call_id": ...}``. Per-path + uniqueness is enforced inside the commit body, not the reducer (we keep + ``tool_call_id`` per occurrence so snapshot binding works). + """ + + pending_dir_deletes: NotRequired[ + Annotated[list[PendingDelete], _list_append_reducer] + ] + """``rmdir`` ops staged for end-of-turn ``DELETE FROM folders`` (cloud only). + + Same shape as :data:`pending_deletes`. Commit body re-verifies the + folder is empty (in-DB AND with this turn's pending changes accounted + for) before issuing the DELETE. + """ + doc_id_by_path: NotRequired[ Annotated[dict[str, int], _dict_merge_with_tombstones_reducer] ] @@ -92,6 +145,17 @@ class SurfSenseFilesystemState(FilesystemState): dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]] """Paths whose ``state["files"]`` content has been modified this turn.""" + dirty_path_tool_calls: NotRequired[ + Annotated[dict[str, str], _dict_merge_with_tombstones_reducer] + ] + """``path -> latest tool_call_id`` sidecar for ``dirty_paths``. + + The persistence body coalesces multiple writes/edits to the same path + into one snapshot per turn. This map captures the most-recent + ``tool_call_id`` so the resulting :class:`DocumentRevision` is bound + to the latest action_id (the one the user is most likely to revert). + """ + kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]] """Top-K priority hints rendered as a system message before the user turn.""" @@ -108,6 +172,7 @@ class SurfSenseFilesystemState(FilesystemState): __all__ = [ "KbAnonDoc", "KbPriorityEntry", + "PendingDelete", "PendingMove", "SurfSenseFilesystemState", ] diff --git a/surfsense_backend/app/agents/new_chat/middleware/action_log.py b/surfsense_backend/app/agents/new_chat/middleware/action_log.py index 3675064e8..716a1616c 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/action_log.py +++ b/surfsense_backend/app/agents/new_chat/middleware/action_log.py @@ -30,6 +30,7 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any from langchain.agents.middleware import AgentMiddleware +from langchain_core.callbacks import adispatch_custom_event from langchain_core.messages import ToolMessage from app.agents.new_chat.feature_flags import get_flags @@ -144,11 +145,19 @@ class ActionLogMiddleware(AgentMiddleware): result=result, ) + tool_call_id = _resolve_tool_call_id(request) + chat_turn_id = _resolve_chat_turn_id(request) + row = AgentActionLog( thread_id=self._thread_id, user_id=self._user_id, search_space_id=self._search_space_id, - turn_id=_resolve_turn_id(request), + # ``turn_id`` is the deprecated alias of ``tool_call_id`` + # kept for one release for safe rollback. New consumers + # should read ``tool_call_id`` directly. + turn_id=tool_call_id, + tool_call_id=tool_call_id, + chat_turn_id=chat_turn_id, message_id=_resolve_message_id(request), tool_name=tool_name, args=args_payload, @@ -160,11 +169,41 @@ class ActionLogMiddleware(AgentMiddleware): async with shielded_async_session() as session: session.add(row) await session.commit() + row_id = int(row.id) if row.id is not None else None + row_created_at = row.created_at except Exception: logger.warning( "ActionLogMiddleware failed to persist action log row", exc_info=True, ) + return + + # Surface a side-channel SSE event so the chat tool card can + # render a Revert button immediately after the row is durable. + # ``stream_new_chat`` translates this into a + # ``data-action-log`` SSE event. We DO NOT include the + # ``reverse_descriptor`` payload here; only a presence flag. + try: + await adispatch_custom_event( + "action_log", + { + "id": row_id, + "lc_tool_call_id": tool_call_id, + "chat_turn_id": chat_turn_id, + "tool_name": tool_name, + "reversible": bool(reversible), + "reverse_descriptor_present": reverse_descriptor is not None, + "created_at": row_created_at.isoformat() + if row_created_at + else None, + "error": error_payload is not None, + }, + ) + except Exception: + logger.debug( + "ActionLogMiddleware failed to dispatch action_log event", + exc_info=True, + ) def _render_reverse( self, @@ -254,7 +293,8 @@ def _resolve_args_payload(request: Any) -> dict[str, Any] | None: } -def _resolve_turn_id(request: Any) -> str | None: +def _resolve_tool_call_id(request: Any) -> str | None: + """Return the LangChain ``tool_call.id`` for this request, if any.""" try: call = getattr(request, "tool_call", None) or {} if isinstance(call, dict): @@ -266,9 +306,40 @@ def _resolve_turn_id(request: Any) -> str | None: return None +# Deprecated alias kept for one release. Old callers and tests treated +# ``turn_id`` as if it carried the LangChain tool_call id; the new column +# lives under ``tool_call_id``. Both resolve to the same value today. +_resolve_turn_id = _resolve_tool_call_id + + +def _resolve_chat_turn_id(request: Any) -> str | None: + """Return ``configurable.turn_id`` for this request, if accessible. + + ``ToolRuntime.config`` is exposed by LangGraph (see + ``langgraph/prebuilt/tool_node.py``); the chat-turn correlation id + lives at ``runtime.config["configurable"]["turn_id"]``. + """ + try: + runtime = getattr(request, "runtime", None) + if runtime is None: + return None + config = getattr(runtime, "config", None) + if not isinstance(config, dict): + return None + configurable = config.get("configurable") + if not isinstance(configurable, dict): + return None + value = configurable.get("turn_id") + if isinstance(value, str) and value: + return value + except Exception: # pragma: no cover - defensive + pass + return None + + def _resolve_message_id(request: Any) -> str | None: """Tool-call IDs serve as best-available message correlator at this layer.""" - return _resolve_turn_id(request) + return _resolve_tool_call_id(request) def _resolve_result_id(result: Any) -> str | None: diff --git a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py index 62316d69e..c46eb98a5 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py +++ b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py @@ -102,6 +102,8 @@ current working directory (`cwd`, default `/documents`). - cd(path): change the current working directory. - pwd(): print the current working directory. - move_file(source, dest): move/rename a file under `/documents/`. +- rm(path): delete a single file under `/documents/` (no `-r`). +- rmdir(path): delete an empty directory under `/documents/`. - list_tree(path, max_depth, page_size): recursively list files/folders. ## Persistence Rules @@ -112,8 +114,9 @@ current working directory (`cwd`, default `/documents`). `/documents/temp_scratch.md`) are **discarded** at end of turn — use this prefix for any scratch/working content you do NOT want saved. - All other paths (outside `/documents/` and not `temp_*`) are rejected. -- mkdir/move_file are staged this turn and committed at end of turn alongside - any new/edited documents. +- mkdir/move_file/rm/rmdir are staged this turn and committed at end of + turn alongside any new/edited documents. Snapshot/revert is enabled + for every destructive operation when action logging is on. ## Reading Documents Efficiently @@ -176,6 +179,8 @@ directory (`cwd`). - cd(path): change the current working directory. - pwd(): print the current working directory. - move_file(source, dest): move/rename a file. +- rm(path): delete a single file from disk (no `-r`). NOT reversible. +- rmdir(path): delete an empty directory from disk. NOT reversible. - list_tree(path, max_depth, page_size): recursively list files/folders. ## Workflow Tips @@ -184,6 +189,8 @@ directory (`cwd`). - For large trees, prefer `list_tree` then `grep` then `read_file` over brute-force directory traversal. - Cross-mount moves are not supported. +- Desktop deletes hit disk immediately and cannot be undone via the + agent's revert flow — confirm before calling `rm`/`rmdir`. """ ) @@ -355,6 +362,42 @@ Notes: - Parent folders are created as needed. """ +_CLOUD_RM_TOOL_DESCRIPTION = """Deletes a single file under `/documents/`. + +Mirrors POSIX `rm path` (no `-r`, no glob expansion). Stages the deletion +for end-of-turn commit; the row is removed only after the agent's turn +finishes successfully. + +Args: +- path: absolute or relative file path. Cannot point at a directory — use + `rmdir` for empty folders. Cannot target the root or `/documents`. + +Notes: +- The action is reversible via the per-action revert flow when action + logging is enabled. +- The anonymous uploaded document is read-only and cannot be deleted. +""" + +_CLOUD_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory under `/documents/`. + +Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive +deletion (`rm -r`) is intentionally NOT supported — clear contents with +`rm` first. + +Args: +- path: absolute or relative directory path. Cannot target the root, + `/documents`, the current cwd, or any ancestor of cwd (use `cd` to + move out first). + +Notes: +- Emptiness is evaluated against the post-staged view, so a same-turn + `rm /a/x.md` followed by `rmdir /a` is fine. +- If the directory was added in this same turn via `mkdir` and never + committed, the staged mkdir is dropped instead of issuing a delete. +- The action is reversible via the per-action revert flow when action + logging is enabled. +""" + # --- desktop-only ---------------------------------------------------------- _DESKTOP_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path. @@ -421,6 +464,28 @@ Notes: - Parent folders are created as needed. """ +_DESKTOP_RM_TOOL_DESCRIPTION = """Deletes a single file from disk. + +Mirrors POSIX `rm path` (no `-r`, no glob expansion). The deletion hits +disk immediately. Desktop deletes are NOT reversible via the agent's +revert flow. + +Args: +- path: absolute mount-prefixed file path. Cannot point at a directory — + use `rmdir` for empty folders. +""" + +_DESKTOP_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory from disk. + +Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive +deletion is NOT supported. The deletion hits disk immediately and is +NOT reversible via the agent's revert flow. + +Args: +- path: absolute mount-prefixed directory path. Cannot target the mount + root or any directory containing files/subfolders. +""" + def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]: """Pick the active-mode description for every filesystem tool.""" @@ -437,6 +502,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]: "mkdir": _CLOUD_MKDIR_TOOL_DESCRIPTION, "cd": SURFSENSE_CD_TOOL_DESCRIPTION, "pwd": SURFSENSE_PWD_TOOL_DESCRIPTION, + "rm": _CLOUD_RM_TOOL_DESCRIPTION, + "rmdir": _CLOUD_RMDIR_TOOL_DESCRIPTION, } return { "ls": _DESKTOP_LIST_FILES_TOOL_DESCRIPTION, @@ -450,6 +517,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]: "mkdir": _DESKTOP_MKDIR_TOOL_DESCRIPTION, "cd": SURFSENSE_CD_TOOL_DESCRIPTION, "pwd": SURFSENSE_PWD_TOOL_DESCRIPTION, + "rm": _DESKTOP_RM_TOOL_DESCRIPTION, + "rmdir": _DESKTOP_RMDIR_TOOL_DESCRIPTION, } @@ -476,6 +545,21 @@ def _basename(path: str) -> str: return path.rsplit("/", 1)[-1] +def _is_ancestor_of(candidate: str, target: str) -> bool: + """True iff ``candidate`` is a strict ancestor directory of ``target``. + + ``target`` itself is NOT considered an ancestor (use equality for that). + Both paths are assumed to be canonicalised, absolute, and free of + trailing slashes (except the root ``/``). + """ + if not candidate.startswith("/") or not target.startswith("/"): + return False + if candidate == target: + return False + prefix = candidate.rstrip("/") + "/" + return target.startswith(prefix) + + class SurfSenseFilesystemMiddleware(FilesystemMiddleware): """SurfSense-specific filesystem middleware (cloud + desktop).""" @@ -519,6 +603,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): self.tools.append(self._create_cd_tool()) self.tools.append(self._create_pwd_tool()) self.tools.append(self._create_move_file_tool()) + self.tools.append(self._create_rm_tool()) + self.tools.append(self._create_rmdir_tool()) self.tools.append(self._create_list_tree_tool()) if self._sandbox_available: self.tools.append(self._create_execute_code_tool()) @@ -941,6 +1027,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): } if self._is_cloud(): update["dirty_paths"] = [path] + update["dirty_path_tool_calls"] = {path: runtime.tool_call_id} return Command(update=update) def sync_write_file( @@ -1036,6 +1123,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): } if self._is_cloud(): update["dirty_paths"] = [path] + update["dirty_path_tool_calls"] = {path: runtime.tool_call_id} if doc_id_to_attach is not None: update["doc_id_by_path"] = {path: doc_id_to_attach} return Command(update=update) @@ -1103,6 +1191,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): return Command( update={ "staged_dirs": [validated], + "staged_dir_tool_calls": { + validated: runtime.tool_call_id, + }, "messages": [ ToolMessage( content=( @@ -1372,7 +1463,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): files_update: dict[str, Any] = {source: None, dest: source_file_data} update: dict[str, Any] = { "files": files_update, - "pending_moves": [{"source": source, "dest": dest, "overwrite": False}], + "pending_moves": [ + { + "source": source, + "dest": dest, + "overwrite": False, + "tool_call_id": runtime.tool_call_id, + } + ], "messages": [ ToolMessage( content=( @@ -1396,6 +1494,323 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): update["dirty_paths"] = new_dirty return Command(update=update) + # ------------------------------------------------------------------ tool: rm + + def _create_rm_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("rm") or _CLOUD_RM_TOOL_DESCRIPTION + ) + + async def async_rm( + path: Annotated[ + str, + "Absolute or relative path to the file to delete.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + if not path or not path.strip(): + return "Error: path is required." + + target = self._resolve_relative(path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + if self._is_cloud(): + if validated in ("/", DOCUMENTS_ROOT): + return f"Error: refusing to rm '{validated}'." + if not validated.startswith(DOCUMENTS_ROOT + "/"): + return ( + "Error: cloud rm must target a path under /documents/ " + f"(got '{validated}')." + ) + + anon = runtime.state.get("kb_anon_doc") or {} + if isinstance(anon, dict) and str(anon.get("path") or "") == validated: + return "Error: the anonymous uploaded document is read-only." + + # Refuse if the path looks like a directory. + staged_dirs = list(runtime.state.get("staged_dirs") or []) + if validated in staged_dirs: + return ( + f"Error: '{validated}' is a directory. Use rmdir for " + "empty directories." + ) + pending_dir_deletes = list( + runtime.state.get("pending_dir_deletes") or [] + ) + if any( + isinstance(d, dict) and d.get("path") == validated + for d in pending_dir_deletes + ): + return f"Error: '{validated}' is already queued for rmdir." + + backend = self._get_backend(runtime) + if isinstance(backend, KBPostgresBackend): + # Detect "is a directory" via `ls`: if the path lists + # children we know it's a folder. Otherwise we still + # need to confirm it's a real file before staging. + children = await backend.als_info(validated) + if children: + return ( + f"Error: '{validated}' is a directory. Use rmdir for " + "empty directories." + ) + + # Already queued for delete this turn? + pending_deletes = list(runtime.state.get("pending_deletes") or []) + if any( + isinstance(d, dict) and d.get("path") == validated + for d in pending_deletes + ): + return f"'{validated}' is already queued for deletion." + + # Resolve doc_id (best-effort): file in state or DB. + files_state = runtime.state.get("files") or {} + doc_id_by_path = runtime.state.get("doc_id_by_path") or {} + resolved_doc_id: int | None = doc_id_by_path.get(validated) + if ( + validated not in files_state + and resolved_doc_id is None + and isinstance(backend, KBPostgresBackend) + ): + loaded = await backend._load_file_data(validated) + if loaded is None: + return f"Error: file '{validated}' not found." + _, resolved_doc_id = loaded + + files_update: dict[str, Any] = {validated: None} + update: dict[str, Any] = { + "pending_deletes": [ + { + "path": validated, + "tool_call_id": runtime.tool_call_id, + } + ], + "files": files_update, + "doc_id_by_path": {validated: None}, + "messages": [ + ToolMessage( + content=( + f"Staged delete of '{validated}' (will commit at " + "end of turn)." + ), + tool_call_id=runtime.tool_call_id, + ) + ], + } + + # Drop the path from dirty_paths so a same-turn write+rm + # doesn't recreate the doc at commit time. + dirty_paths = list(runtime.state.get("dirty_paths") or []) + if validated in dirty_paths: + new_dirty: list[Any] = [_CLEAR] + for entry in dirty_paths: + if entry != validated: + new_dirty.append(entry) + update["dirty_paths"] = new_dirty + update["dirty_path_tool_calls"] = {validated: None} + + return Command(update=update) + + # Desktop mode — hit disk immediately. + backend = self._get_backend(runtime) + adelete = getattr(backend, "adelete_file", None) + if not callable(adelete): + return "Error: rm is not supported by the active backend." + res: WriteResult = await adelete(validated) + if res.error: + return res.error + update_desktop: dict[str, Any] = { + "files": {validated: None}, + "messages": [ + ToolMessage( + content=f"Deleted file '{res.path or validated}'", + tool_call_id=runtime.tool_call_id, + ) + ], + } + return Command(update=update_desktop) + + def sync_rm( + path: Annotated[ + str, + "Absolute or relative path to the file to delete.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + return self._run_async_blocking(async_rm(path, runtime)) + + return StructuredTool.from_function( + name="rm", + description=tool_description, + func=sync_rm, + coroutine=async_rm, + ) + + # ------------------------------------------------------------------ tool: rmdir + + def _create_rmdir_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("rmdir") or _CLOUD_RMDIR_TOOL_DESCRIPTION + ) + + async def async_rmdir( + path: Annotated[ + str, + "Absolute or relative path of the empty directory to delete.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + if not path or not path.strip(): + return "Error: path is required." + + target = self._resolve_relative(path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + if self._is_cloud(): + if validated in ("/", DOCUMENTS_ROOT): + return f"Error: refusing to rmdir '{validated}'." + if not validated.startswith(DOCUMENTS_ROOT + "/"): + return ( + "Error: cloud rmdir must target a path under /documents/ " + f"(got '{validated}')." + ) + + cwd = self._current_cwd(runtime) + if validated == cwd or _is_ancestor_of(validated, cwd): + return ( + f"Error: cannot rmdir '{validated}' because the current " + "cwd is at or under it. cd out first." + ) + + staged_dirs = list(runtime.state.get("staged_dirs") or []) + pending_dir_deletes = list( + runtime.state.get("pending_dir_deletes") or [] + ) + if any( + isinstance(d, dict) and d.get("path") == validated + for d in pending_dir_deletes + ): + return f"'{validated}' is already queued for deletion." + + backend = self._get_backend(runtime) + + # The path must currently exist either in DB folder paths or + # in staged_dirs. We rely on KBPostgresBackend.als_info (which + # already accounts for pending deletes/moves) to evaluate + # both existence and emptiness against the post-staged view. + exists_in_staged = validated in staged_dirs + children: list[Any] = [] + if isinstance(backend, KBPostgresBackend): + children = list(await backend.als_info(validated)) + + # Detect "is a file" — if als_info returns no children but + # the path is actually a file, we should reject. We use + # _load_file_data to disambiguate file vs missing folder. + if ( + isinstance(backend, KBPostgresBackend) + and not children + and not exists_in_staged + ): + loaded = await backend._load_file_data(validated) + if loaded is not None: + return ( + f"Error: '{validated}' is a file. Use rm to delete files." + ) + # Confirm folder exists in DB by checking the parent listing. + parent = posixpath.dirname(validated) or "/" + parent_listing = await backend.als_info(parent) + parent_has_dir = any( + info.get("path") == validated and info.get("is_dir") + for info in parent_listing + ) + if not parent_has_dir: + return f"Error: directory '{validated}' not found." + + if children: + return ( + f"Error: directory '{validated}' is not empty. " + "Remove contents first." + ) + + # Same-turn mkdir un-stage: drop the staged_dirs entry + # entirely and skip queuing a DB delete (nothing was ever + # committed). + if exists_in_staged: + rest = [d for d in staged_dirs if d != validated] + return Command( + update={ + "staged_dirs": [_CLEAR, *rest], + "staged_dir_tool_calls": {validated: None}, + "messages": [ + ToolMessage( + content=(f"Un-staged directory '{validated}'."), + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + + return Command( + update={ + "pending_dir_deletes": [ + { + "path": validated, + "tool_call_id": runtime.tool_call_id, + } + ], + "messages": [ + ToolMessage( + content=( + f"Staged rmdir of '{validated}' (will commit " + "at end of turn)." + ), + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + + # Desktop mode — hit disk immediately. + backend = self._get_backend(runtime) + armdir = getattr(backend, "armdir", None) + if not callable(armdir): + return "Error: rmdir is not supported by the active backend." + res: WriteResult = await armdir(validated) + if res.error: + return res.error + return Command( + update={ + "messages": [ + ToolMessage( + content=f"Deleted directory '{res.path or validated}'", + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + + def sync_rmdir( + path: Annotated[ + str, + "Absolute or relative path of the empty directory to delete.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + return self._run_async_blocking(async_rmdir(path, runtime)) + + return StructuredTool.from_function( + name="rmdir", + description=tool_description, + func=sync_rmdir, + coroutine=async_rmdir, + ) + # ------------------------------------------------------------------ tool: list_tree def _create_list_tree_tool(self) -> BaseTool: diff --git a/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py b/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py index 378b83950..d577441dd 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py +++ b/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py @@ -1,16 +1,29 @@ """End-of-turn persistence for the cloud-mode SurfSense filesystem. This middleware runs ``aafter_agent`` once per turn (cloud only). It commits -all staged folder creations, file moves, and content writes/edits to -Postgres in a single ordered pass: +all staged folder creations, file moves, content writes/edits, file deletes +(``rm``), and directory deletes (``rmdir``) to Postgres in a single ordered +pass: 1. Materialize ``staged_dirs`` into ``Folder`` rows. 2. Apply ``pending_moves`` in order (chained moves resolved via ``doc_id_by_path``). 3. Normalize ``dirty_paths`` through ``pending_moves`` so write-then-move - sequences commit at the final path. + sequences commit at the final path. Paths queued for ``rm`` this turn + are dropped here so a write+rm sequence doesn't recreate the doc. 4. Commit content writes / edits for ``/documents/*`` paths, skipping ``temp_*`` basenames. +5. Apply ``pending_deletes`` (``rm``) — file deletes run BEFORE directory + deletes so a same-turn ``rm /a/x.md`` + ``rmdir /a`` sequence works. +6. Apply ``pending_dir_deletes`` (``rmdir``); re-verifies emptiness against + the post-step-5 DB state. + +When ``flags.enable_action_log`` is on every destructive op also writes a +``DocumentRevision`` / ``FolderRevision`` snapshot bound to the +originating ``AgentActionLog`` row via ``tool_call_id``. ``rm``/``rmdir`` +share a single ``SAVEPOINT`` with their snapshot — if the snapshot fails +the DELETE rolls back and we surface the error rather than silently +making the data irreversible. The commit body is exposed as a free function ``commit_staged_filesystem_state`` so the optional stream-task fallback (``stream_new_chat.py``) can call the @@ -25,12 +38,13 @@ from typing import Any from fractional_indexing import generate_key_between from langchain.agents.middleware import AgentMiddleware, AgentState -from langchain_core.callbacks import dispatch_custom_event +from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event from langgraph.runtime import Runtime -from sqlalchemy import delete, select +from sqlalchemy import delete, select, update from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.filesystem_selection import FilesystemMode from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState from app.agents.new_chat.path_resolver import ( @@ -41,10 +55,13 @@ from app.agents.new_chat.path_resolver import ( ) from app.agents.new_chat.state_reducers import _CLEAR from app.db import ( + AgentActionLog, Chunk, Document, + DocumentRevision, DocumentType, Folder, + FolderRevision, shielded_async_session, ) from app.indexing_pipeline.document_chunker import chunk_text @@ -123,6 +140,47 @@ async def _ensure_folder_hierarchy( return parent_id +async def _resolve_folder_id( + session: AsyncSession, + *, + search_space_id: int, + folder_parts: list[str], +) -> int | None: + """Look up an existing folder chain without creating anything. + + Returns ``None`` if any segment is missing. Used by ``rmdir`` snapshot + capture and by parent-folder lookup at ``rmdir`` commit time. + """ + if not folder_parts: + return None + parent_id: int | None = None + for raw in folder_parts: + name = safe_folder_segment(str(raw)) + query = select(Folder).where( + Folder.search_space_id == search_space_id, + Folder.name == name, + ) + query = ( + query.where(Folder.parent_id.is_(None)) + if parent_id is None + else query.where(Folder.parent_id == parent_id) + ) + result = await session.execute(query) + folder = result.scalar_one_or_none() + if folder is None: + return None + parent_id = folder.id + return parent_id + + +def _split_folder_path(folder_path: str) -> list[str]: + """Return the folder segments under ``/documents/`` for a path.""" + if not folder_path.startswith(DOCUMENTS_ROOT): + return [] + rel = folder_path[len(DOCUMENTS_ROOT) :].strip("/") + return [p for p in rel.split("/") if p] + + # --------------------------------------------------------------------------- # Document helpers # --------------------------------------------------------------------------- @@ -331,6 +389,298 @@ async def _apply_move( return {"id": document.id, "source": source, "dest": dest, "title": new_title} +# --------------------------------------------------------------------------- +# Action log binding helpers +# --------------------------------------------------------------------------- + + +async def _find_action_ids_batch( + session: AsyncSession, + *, + thread_id: int | None, + tool_call_ids: set[str], +) -> dict[str, int]: + """Resolve ``tool_call_id -> AgentActionLog.id`` in a single query. + + Returns an empty dict when ``thread_id`` or ``tool_call_ids`` are + missing — callers treat that as "no binding available" and write the + revision with ``agent_action_id = NULL``. + """ + if thread_id is None or not tool_call_ids: + return {} + rows = await session.execute( + select(AgentActionLog.id, AgentActionLog.tool_call_id).where( + AgentActionLog.thread_id == thread_id, + AgentActionLog.tool_call_id.in_(list(tool_call_ids)), + ) + ) + mapping: dict[str, int] = {} + for row in rows.all(): + if row.tool_call_id and row.id: + mapping[str(row.tool_call_id)] = int(row.id) + return mapping + + +async def _mark_action_reversible( + session: AsyncSession, + *, + action_id: int | None, +) -> None: + """Flip ``agent_action_log.reversible = TRUE`` for ``action_id``. + + Best-effort: caller may invoke from inside a SAVEPOINT and treat + failure as a soft demotion (snapshot persists, just no Revert button). + + Callers should also call ``_dispatch_reversibility_update`` (defined + below) AFTER the enclosing SAVEPOINT block exits successfully so the + chat tool card can light up its Revert button without + re-fetching ``GET /threads/.../actions``. Dispatching from inside the + SAVEPOINT would risk emitting "reversible=true" for rows whose + update gets rolled back if the surrounding destructive op fails. + """ + if action_id is None: + return + await session.execute( + update(AgentActionLog) + .where(AgentActionLog.id == action_id) + .values(reversible=True) + ) + + +async def _dispatch_reversibility_update(action_id: int | None) -> None: + """Best-effort dispatch of an ``action_log_updated`` custom event. + + Surfaces the post-SAVEPOINT reversibility flip to the SSE layer so + the chat tool card can flip its Revert button live. Defensive: + failures are logged at debug level and swallowed; the + REST endpoint ``GET /threads/.../actions`` is still authoritative. + + .. warning:: + Inside :func:`commit_staged_filesystem_state` we DEFER all + dispatches until the outer ``session.commit()`` succeeds — see + the ``deferred_dispatches`` queue in that function. Dispatching + from inside a SAVEPOINT block while the outer transaction is + still pending would emit ``reversible=true`` for rows whose + snapshots get rolled back if the outer commit fails. Direct + callers (e.g. the optional stream-task fallback) that own the + full session lifetime can still call this helper inline. + """ + if action_id is None: + return + try: + await adispatch_custom_event( + "action_log_updated", + {"id": int(action_id), "reversible": True}, + ) + except Exception: + logger.debug( + "kb_persistence.aafter_agent failed to dispatch action_log_updated", + exc_info=True, + ) + + +# --------------------------------------------------------------------------- +# Snapshot helpers +# --------------------------------------------------------------------------- +# +# Best-effort helpers swallow + log so a snapshot failure can never break +# the destructive op for non-destructive tools (write/edit/move/mkdir). +# Strict helpers run inside the SAME ``begin_nested()`` SAVEPOINT as the +# destructive DELETE — failure aborts the savepoint and leaves the doc / +# folder intact, so revertable ops never become irreversible silently. + + +def _doc_revision_payload( + doc: Document, + *, + chunks_before: list[dict[str, str]] | None = None, +) -> dict[str, Any]: + """Pre-mutation field map for ``DocumentRevision``.""" + metadata = dict(doc.document_metadata or {}) + return { + "content_before": doc.content, + "title_before": doc.title, + "folder_id_before": doc.folder_id, + "chunks_before": chunks_before, + "metadata_before": metadata or None, + } + + +async def _load_chunks_for_snapshot( + session: AsyncSession, *, doc_id: int +) -> list[dict[str, str]]: + rows = await session.execute( + select(Chunk.content).where(Chunk.document_id == doc_id).order_by(Chunk.id) + ) + return [{"content": row.content} for row in rows.all() if row.content is not None] + + +async def _snapshot_document_pre_write( + session: AsyncSession, + *, + doc: Document, + action_id: int | None, + search_space_id: int, + turn_id: str | None = None, + deferred_dispatches: list[int] | None = None, +) -> int | None: + """Best-effort snapshot ahead of an in-place ``write_file``/``edit_file``. + + When ``deferred_dispatches`` is provided, on success the action id + is APPENDED to it and the SSE dispatch is left to the caller (so it + can be flushed only after the outer ``session.commit()`` succeeds). + """ + try: + async with session.begin_nested(): + chunks = await _load_chunks_for_snapshot(session, doc_id=doc.id) + payload = _doc_revision_payload(doc, chunks_before=chunks) + rev = DocumentRevision( + document_id=doc.id, + search_space_id=search_space_id, + created_by_turn_id=turn_id, + agent_action_id=action_id, + **payload, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + rev_id = rev.id + if deferred_dispatches is None: + await _dispatch_reversibility_update(action_id) + elif action_id is not None: + deferred_dispatches.append(int(action_id)) + return rev_id + except Exception as exc: # pragma: no cover - defensive + logger.warning( + "kb_persistence: pre-write snapshot for doc=%s failed: %s", + doc.id, + exc, + ) + return None + + +async def _snapshot_document_pre_create( + session: AsyncSession, + *, + action_id: int | None, + search_space_id: int, + turn_id: str | None = None, + deferred_dispatches: list[int] | None = None, +) -> int | None: + """Best-effort placeholder revision for a fresh ``write_file`` create. + + ``document_id`` is patched in by the caller after the new doc is + flushed and gets an ID; the placeholder lets us bind the action_id + even though no parent row exists yet. + """ + try: + async with session.begin_nested(): + rev = DocumentRevision( + document_id=None, + search_space_id=search_space_id, + content_before=None, + title_before=None, + folder_id_before=None, + chunks_before=None, + metadata_before=None, + created_by_turn_id=turn_id, + agent_action_id=action_id, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + rev_id = rev.id + if deferred_dispatches is None: + await _dispatch_reversibility_update(action_id) + elif action_id is not None: + deferred_dispatches.append(int(action_id)) + return rev_id + except Exception as exc: # pragma: no cover - defensive + logger.warning("kb_persistence: pre-create snapshot failed: %s", exc) + return None + + +async def _snapshot_document_pre_move( + session: AsyncSession, + *, + doc: Document, + action_id: int | None, + search_space_id: int, + turn_id: str | None = None, + deferred_dispatches: list[int] | None = None, +) -> int | None: + """Best-effort snapshot ahead of a ``move_file``.""" + try: + async with session.begin_nested(): + payload = _doc_revision_payload(doc, chunks_before=None) + rev = DocumentRevision( + document_id=doc.id, + search_space_id=search_space_id, + created_by_turn_id=turn_id, + agent_action_id=action_id, + **payload, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + rev_id = rev.id + if deferred_dispatches is None: + await _dispatch_reversibility_update(action_id) + elif action_id is not None: + deferred_dispatches.append(int(action_id)) + return rev_id + except Exception as exc: # pragma: no cover - defensive + logger.warning( + "kb_persistence: pre-move snapshot for doc=%s failed: %s", + doc.id, + exc, + ) + return None + + +async def _snapshot_folder_pre_mkdir( + session: AsyncSession, + *, + folder: Folder, + action_id: int | None, + search_space_id: int, + turn_id: str | None = None, + deferred_dispatches: list[int] | None = None, +) -> int | None: + """Best-effort placeholder for an ``mkdir`` (revert deletes the folder). + + The "before" state is "did not exist", so all ``*_before`` fields are + NULL — revert routes by ``tool_name == "mkdir"`` and DELETEs. + """ + try: + async with session.begin_nested(): + rev = FolderRevision( + folder_id=folder.id, + search_space_id=search_space_id, + name_before=None, + parent_id_before=None, + position_before=None, + created_by_turn_id=turn_id, + agent_action_id=action_id, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + rev_id = rev.id + if deferred_dispatches is None: + await _dispatch_reversibility_update(action_id) + elif action_id is not None: + deferred_dispatches.append(int(action_id)) + return rev_id + except Exception as exc: # pragma: no cover - defensive + logger.warning( + "kb_persistence: pre-mkdir snapshot for folder=%s failed: %s", + folder.id, + exc, + ) + return None + + # --------------------------------------------------------------------------- # Commit body # --------------------------------------------------------------------------- @@ -342,12 +692,20 @@ async def commit_staged_filesystem_state( search_space_id: int, created_by_id: str | None, filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, + thread_id: int | None = None, dispatch_events: bool = True, ) -> dict[str, Any] | None: """Commit all staged filesystem changes; return the state delta for reducers. Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent` and the optional stream-task fallback. + + When ``flags.enable_action_log`` is on every destructive op also writes + a ``DocumentRevision`` / ``FolderRevision`` snapshot bound to the + originating ``AgentActionLog`` row via ``tool_call_id``. Snapshot + durability is best-effort for non-destructive ops and STRICT for + ``rm``/``rmdir`` (snapshot + DELETE share a SAVEPOINT — snapshot + failure aborts the delete). """ if filesystem_mode != FilesystemMode.CLOUD: return None @@ -360,8 +718,20 @@ async def commit_staged_filesystem_state( files: dict[str, Any] = state_dict.get("files") or {} staged_dirs: list[str] = list(state_dict.get("staged_dirs") or []) + staged_dir_tool_calls: dict[str, str] = dict( + state_dict.get("staged_dir_tool_calls") or {} + ) pending_moves: list[dict[str, Any]] = list(state_dict.get("pending_moves") or []) + pending_deletes: list[dict[str, Any]] = list( + state_dict.get("pending_deletes") or [] + ) + pending_dir_deletes: list[dict[str, Any]] = list( + state_dict.get("pending_dir_deletes") or [] + ) dirty_paths: list[str] = list(state_dict.get("dirty_paths") or []) + dirty_path_tool_calls: dict[str, str] = dict( + state_dict.get("dirty_path_tool_calls") or {} + ) doc_id_by_path: dict[str, int] = dict(state_dict.get("doc_id_by_path") or {}) kb_anon_doc = state_dict.get("kb_anon_doc") @@ -374,32 +744,112 @@ async def commit_staged_filesystem_state( return { "dirty_paths": [_CLEAR], "staged_dirs": [_CLEAR], + "staged_dir_tool_calls": {_CLEAR: True}, "pending_moves": [_CLEAR], + "pending_deletes": [_CLEAR], + "pending_dir_deletes": [_CLEAR], + "dirty_path_tool_calls": {_CLEAR: True}, "files": dict.fromkeys(temp_paths), } - if not (staged_dirs or pending_moves or dirty_paths): + if not ( + staged_dirs + or pending_moves + or dirty_paths + or pending_deletes + or pending_dir_deletes + ): return None + flags = get_flags() + snapshot_enabled = flags.enable_action_log + + # De-duplicate pending deletes per-path while preserving the latest + # tool_call_id (the one the user is most likely to revert via the UI). + file_delete_paths: dict[str, str] = {} + for entry in pending_deletes: + if not isinstance(entry, dict): + continue + path = str(entry.get("path") or "") + if path: + file_delete_paths[path] = str(entry.get("tool_call_id") or "") + dir_delete_paths: dict[str, str] = {} + for entry in pending_dir_deletes: + if not isinstance(entry, dict): + continue + path = str(entry.get("path") or "") + if path: + dir_delete_paths[path] = str(entry.get("tool_call_id") or "") + committed_creates: list[dict[str, Any]] = [] committed_updates: list[dict[str, Any]] = [] + committed_deletes: list[dict[str, Any]] = [] + committed_folder_deletes: list[dict[str, Any]] = [] discarded: list[str] = [] applied_moves: list[dict[str, Any]] = [] doc_id_path_tombstones: dict[str, int | None] = {} tree_changed = False + # Reversibility-flip dispatches are deferred until AFTER the outer + # ``session.commit()`` succeeds. Dispatching from inside the + # SAVEPOINT chain while the outer transaction is still pending + # would emit ``reversible=true`` for rows whose snapshots get rolled + # back if the final commit raises. Snapshot helpers append on + # success; we drain this list after commit and silently abandon it + # on rollback so the UI stays consistent with durable state. + deferred_dispatches: list[int] = [] try: async with shielded_async_session() as session: + # ------------------------------------------------------------------ + # Resolve action-id bindings up front. One SELECT per turn for all + # tool_call_ids, NOT one per op — important because a turn that + # touches 50 paths would otherwise issue 50 lookups. + # ------------------------------------------------------------------ + action_id_by_call: dict[str, int] = {} + if snapshot_enabled and thread_id is not None: + tool_call_ids: set[str] = set() + tool_call_ids.update( + tcid for tcid in staged_dir_tool_calls.values() if tcid + ) + for move in pending_moves: + tcid = str(move.get("tool_call_id") or "") + if tcid: + tool_call_ids.add(tcid) + tool_call_ids.update( + tcid for tcid in dirty_path_tool_calls.values() if tcid + ) + tool_call_ids.update( + tcid for tcid in file_delete_paths.values() if tcid + ) + tool_call_ids.update(tcid for tcid in dir_delete_paths.values() if tcid) + action_id_by_call = await _find_action_ids_batch( + session, + thread_id=thread_id, + tool_call_ids=tool_call_ids, + ) + + def _action_id_for(tool_call_id: str | None) -> int | None: + if not snapshot_enabled or not tool_call_id: + return None + return action_id_by_call.get(str(tool_call_id)) + + turn_id_for_revision = ( + next(iter(action_id_by_call), None) if action_id_by_call else None + ) + + # ------------------------------------------------------------------ + # 1. staged_dirs -> Folder rows. Snapshot post-flush so the new + # folder_id is available for the FK. + # ------------------------------------------------------------------ for folder_path in staged_dirs: if not isinstance(folder_path, str): continue if not folder_path.startswith(DOCUMENTS_ROOT): continue - rel = folder_path[len(DOCUMENTS_ROOT) :].strip("/") - folder_parts_full = [p for p in rel.split("/") if p] + folder_parts_full = _split_folder_path(folder_path) if not folder_parts_full: continue - await _ensure_folder_hierarchy( + folder_id = await _ensure_folder_hierarchy( session, search_space_id=search_space_id, created_by_id=created_by_id, @@ -407,7 +857,61 @@ async def commit_staged_filesystem_state( ) tree_changed = True + if snapshot_enabled and folder_id is not None: + tcid = staged_dir_tool_calls.get(folder_path) + action_id = _action_id_for(tcid) + if action_id is not None: + # Re-read the folder for the snapshot. + result = await session.execute( + select(Folder).where(Folder.id == folder_id) + ) + folder_row = result.scalar_one_or_none() + if folder_row is not None: + await _snapshot_folder_pre_mkdir( + session, + folder=folder_row, + action_id=action_id, + search_space_id=search_space_id, + turn_id=tcid, + deferred_dispatches=deferred_dispatches, + ) + + # ------------------------------------------------------------------ + # 2. pending_moves. Snapshot pre-move (in-place restore on revert). + # ------------------------------------------------------------------ for move in pending_moves: + source = str(move.get("source") or "") + if snapshot_enabled and source: + tcid = str(move.get("tool_call_id") or "") + action_id = _action_id_for(tcid) + if action_id is not None: + # Resolve the doc to snapshot BEFORE we mutate it. + doc_id_pre = doc_id_by_path.get(source) + document_pre: Document | None = None + if doc_id_pre is not None: + res_pre = await session.execute( + select(Document).where( + Document.id == doc_id_pre, + Document.search_space_id == search_space_id, + ) + ) + document_pre = res_pre.scalar_one_or_none() + if document_pre is None: + document_pre = await virtual_path_to_doc( + session, + search_space_id=search_space_id, + virtual_path=source, + ) + if document_pre is not None: + await _snapshot_document_pre_move( + session, + doc=document_pre, + action_id=action_id, + search_space_id=search_space_id, + turn_id=tcid, + deferred_dispatches=deferred_dispatches, + ) + applied = await _apply_move( session, search_space_id=search_space_id, @@ -431,8 +935,13 @@ async def commit_staged_filesystem_state( path = move_alias[path] return path + # ------------------------------------------------------------------ + # 3. dirty_paths -> writes/edits. Skip any path queued for ``rm`` + # this turn so a write+rm sequence doesn't recreate the doc. + # ------------------------------------------------------------------ kb_dirty_seen: set[str] = set() kb_dirty: list[str] = [] + kb_dirty_origin: dict[str, str] = {} for raw in dirty_paths: if not isinstance(raw, str): continue @@ -441,8 +950,12 @@ async def commit_staged_filesystem_state( continue if final in kb_dirty_seen: continue + if final in file_delete_paths: + discarded.append(final) + continue kb_dirty_seen.add(final) kb_dirty.append(final) + kb_dirty_origin[final] = raw for path in kb_dirty: basename = _basename(path) @@ -454,6 +967,15 @@ async def commit_staged_filesystem_state( continue content = "\n".join(file_data.get("content") or []) doc_id = doc_id_by_path.get(path) + # Path ↔ tool_call_id binding: the dirty_paths list dedupes via + # _add_unique_reducer, so we look up the latest tool_call_id by + # path (or by the un-renamed origin). + origin = kb_dirty_origin.get(path, path) + tcid = dirty_path_tool_calls.get(path) or dirty_path_tool_calls.get( + origin + ) + action_id = _action_id_for(tcid) + if doc_id is None: # The in-memory ``doc_id_by_path`` is per-thread and starts # empty in every new chat. If the agent writes to a path @@ -470,6 +992,23 @@ async def commit_staged_filesystem_state( doc_id = existing.id doc_id_by_path[path] = existing.id if doc_id is not None: + if snapshot_enabled and action_id is not None: + result_doc = await session.execute( + select(Document).where( + Document.id == doc_id, + Document.search_space_id == search_space_id, + ) + ) + existing_doc = result_doc.scalar_one_or_none() + if existing_doc is not None: + await _snapshot_document_pre_write( + session, + doc=existing_doc, + action_id=action_id, + search_space_id=search_space_id, + turn_id=tcid, + deferred_dispatches=deferred_dispatches, + ) updated = await _update_document( session, doc_id=doc_id, @@ -492,12 +1031,21 @@ async def commit_staged_filesystem_state( } ) else: - # Wrap each create in a SAVEPOINT so a residual - # ``IntegrityError`` (e.g. a deployment that hasn't run - # migration 133 yet, where ``documents.content_hash`` - # still carries its legacy global UNIQUE constraint) - # rolls back only this one create instead of poisoning - # the whole turn's transaction. + # Fresh create. Wrap each create in a SAVEPOINT so a + # residual ``IntegrityError`` (e.g. a deployment that + # hasn't run migration 133 yet, where + # ``documents.content_hash`` still carries its legacy + # global UNIQUE constraint) rolls back only this one + # create instead of poisoning the whole turn. + placeholder_revision_id: int | None = None + if snapshot_enabled and action_id is not None: + placeholder_revision_id = await _snapshot_document_pre_create( + session, + action_id=action_id, + search_space_id=search_space_id, + turn_id=tcid, + deferred_dispatches=deferred_dispatches, + ) try: async with session.begin_nested(): new_doc = await _create_document( @@ -511,14 +1059,16 @@ async def commit_staged_filesystem_state( logger.warning( "kb_persistence: skipping %s create: %s", path, exc ) + # Roll back the placeholder revision since the create + # never happened. + if placeholder_revision_id is not None: + await session.execute( + delete(DocumentRevision).where( + DocumentRevision.id == placeholder_revision_id + ) + ) continue except IntegrityError as exc: - # The path-uniqueness check above already protected - # against ``unique_identifier_hash`` collisions, so - # the most likely culprit is the legacy - # ``ix_documents_content_hash`` UNIQUE constraint - # that migration 133 drops. Log loudly so operators - # know to run the migration; do NOT silently swallow. msg = str(exc.orig) if exc.orig is not None else str(exc) logger.error( "kb_persistence: IntegrityError creating %s: %s. " @@ -528,8 +1078,20 @@ async def commit_staged_filesystem_state( path, msg, ) + if placeholder_revision_id is not None: + await session.execute( + delete(DocumentRevision).where( + DocumentRevision.id == placeholder_revision_id + ) + ) continue doc_id_by_path[path] = new_doc.id + if placeholder_revision_id is not None: + await session.execute( + update(DocumentRevision) + .where(DocumentRevision.id == placeholder_revision_id) + .values(document_id=new_doc.id) + ) committed_creates.append( { "id": new_doc.id, @@ -545,13 +1107,234 @@ async def commit_staged_filesystem_state( ) tree_changed = True + # ------------------------------------------------------------------ + # 4. pending_deletes -> ``rm``. STRICT durability: snapshot + DELETE + # share a SAVEPOINT. If the snapshot insert fails, the DELETE + # rolls back too and we surface the error rather than silently + # making the data irreversible. + # ------------------------------------------------------------------ + for raw_path, tcid in file_delete_paths.items(): + final = _final_path(raw_path) + if not final.startswith(DOCUMENTS_ROOT + "/"): + continue + action_id = _action_id_for(tcid) + + # Resolve the doc. + doc_id_for_delete = doc_id_by_path.get(final) + document_to_delete: Document | None = None + if doc_id_for_delete is not None: + result = await session.execute( + select(Document).where( + Document.id == doc_id_for_delete, + Document.search_space_id == search_space_id, + ) + ) + document_to_delete = result.scalar_one_or_none() + if document_to_delete is None: + document_to_delete = await virtual_path_to_doc( + session, + search_space_id=search_space_id, + virtual_path=final, + ) + if document_to_delete is None: + logger.info( + "kb_persistence: skipping rm %s (target not found)", final + ) + continue + + doc_pk = document_to_delete.id + doc_title = document_to_delete.title + doc_folder_id = document_to_delete.folder_id + + try: + async with session.begin_nested(): + # Strict: snapshot first; failure aborts the delete. + if snapshot_enabled and action_id is not None: + chunks = await _load_chunks_for_snapshot( + session, doc_id=doc_pk + ) + payload = _doc_revision_payload( + document_to_delete, chunks_before=chunks + ) + rev = DocumentRevision( + document_id=doc_pk, + search_space_id=search_space_id, + created_by_turn_id=tcid, + agent_action_id=action_id, + **payload, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + await session.execute( + delete(Document).where(Document.id == doc_pk) + ) + except Exception as exc: + logger.exception( + "kb_persistence: strict rm SAVEPOINT for path=%s failed: %s", + final, + exc, + ) + continue + + # B1 — SAVEPOINT released. Defer the reversibility-flip + # dispatch until AFTER the outer commit succeeds so we + # never tell the UI a row is reversible if its snapshot + # gets rolled back. + if snapshot_enabled and action_id is not None: + deferred_dispatches.append(int(action_id)) + + doc_id_by_path.pop(final, None) + doc_id_path_tombstones[final] = None + committed_deletes.append( + { + "id": doc_pk, + "title": doc_title, + "documentType": DocumentType.NOTE.value, + "searchSpaceId": search_space_id, + "folderId": doc_folder_id, + "createdById": str(created_by_id) if created_by_id else None, + "virtualPath": final, + } + ) + tree_changed = True + + # ------------------------------------------------------------------ + # 5. pending_dir_deletes -> ``rmdir``. STRICT durability + final + # emptiness check (after step 4's deletes have run, an "empty + # mid-turn" directory really IS empty in DB now). + # ------------------------------------------------------------------ + for raw_path, tcid in dir_delete_paths.items(): + final = _final_path(raw_path) + if not final.startswith(DOCUMENTS_ROOT + "/"): + continue + action_id = _action_id_for(tcid) + + folder_parts = _split_folder_path(final) + if not folder_parts: + continue + folder_id = await _resolve_folder_id( + session, + search_space_id=search_space_id, + folder_parts=folder_parts, + ) + if folder_id is None: + logger.info( + "kb_persistence: skipping rmdir %s (folder not found)", final + ) + continue + + # Re-check emptiness against in-DB state. + docs_in_folder = await session.execute( + select(Document.id) + .where(Document.folder_id == folder_id) + .where(Document.search_space_id == search_space_id) + .limit(1) + ) + if docs_in_folder.scalar_one_or_none() is not None: + logger.warning( + "kb_persistence: refusing rmdir %s — non-empty at commit time", + final, + ) + continue + child_folders = await session.execute( + select(Folder.id) + .where(Folder.parent_id == folder_id) + .where(Folder.search_space_id == search_space_id) + .limit(1) + ) + if child_folders.scalar_one_or_none() is not None: + logger.warning( + "kb_persistence: refusing rmdir %s — has child folders " + "at commit time", + final, + ) + continue + + folder_to_delete_res = await session.execute( + select(Folder).where(Folder.id == folder_id) + ) + folder_to_delete = folder_to_delete_res.scalar_one_or_none() + if folder_to_delete is None: + continue + + folder_pk = folder_to_delete.id + folder_name = folder_to_delete.name + folder_parent_id = folder_to_delete.parent_id + folder_position = folder_to_delete.position + + try: + async with session.begin_nested(): + if snapshot_enabled and action_id is not None: + rev = FolderRevision( + folder_id=folder_pk, + search_space_id=search_space_id, + name_before=folder_name, + parent_id_before=folder_parent_id, + position_before=folder_position, + created_by_turn_id=tcid, + agent_action_id=action_id, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + await session.execute( + delete(Folder).where(Folder.id == folder_pk) + ) + except Exception as exc: + logger.exception( + "kb_persistence: strict rmdir SAVEPOINT for path=%s failed: %s", + final, + exc, + ) + continue + + # B1 — SAVEPOINT released. Defer the reversibility-flip + # dispatch until AFTER the outer commit succeeds so we + # never tell the UI a row is reversible if its snapshot + # gets rolled back. + if snapshot_enabled and action_id is not None: + deferred_dispatches.append(int(action_id)) + + committed_folder_deletes.append( + { + "id": folder_pk, + "name": folder_name, + "searchSpaceId": search_space_id, + "parentId": folder_parent_id, + "virtualPath": final, + } + ) + tree_changed = True + await session.commit() except Exception: # pragma: no cover - rollback safety net logger.exception( "kb_persistence: commit failed (search_space=%s)", search_space_id ) + # Outer commit raised — every SAVEPOINT-released change above + # (snapshots + reversibility flips) is now rolled back. Drop + # the deferred SSE dispatches so the UI stays consistent with + # durable state. + deferred_dispatches.clear() return None + # Outer commit succeeded; flush deferred reversibility-flip + # dispatches now so the chat tool card can light up its Revert + # button without re-fetching ``GET /threads/.../actions``. De-dup + # to avoid emitting the same id twice (e.g. write-then-rm in the + # same turn dispatches once for each snapshot site). + if deferred_dispatches and dispatch_events: + for action_id in dict.fromkeys(deferred_dispatches): + try: + await _dispatch_reversibility_update(action_id) + except Exception: + logger.debug( + "kb_persistence: deferred reversibility dispatch failed for action_id=%s", + action_id, + exc_info=True, + ) + if dispatch_events: for payload in committed_creates: try: @@ -567,11 +1350,34 @@ async def commit_staged_filesystem_state( logger.exception( "kb_persistence: failed to dispatch document_updated event" ) + for payload in committed_deletes: + try: + dispatch_custom_event("document_deleted", payload) + except Exception: + logger.exception( + "kb_persistence: failed to dispatch document_deleted event" + ) + for payload in committed_folder_deletes: + try: + dispatch_custom_event("folder_deleted", payload) + except Exception: + logger.exception( + "kb_persistence: failed to dispatch folder_deleted event" + ) temp_paths = [ p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX) ] + # Tombstone every committed-delete path so a stale ``state["files"]`` entry + # (which als_info would otherwise interpret as content) cannot survive into + # the next turn and make a now-empty folder look non-empty. + deleted_file_paths = [ + str(payload.get("virtualPath") or "") + for payload in committed_deletes + if payload.get("virtualPath") + ] + doc_id_update: dict[str, int | None] = {**doc_id_path_tombstones} for payload in committed_creates: doc_id_update[str(payload.get("virtualPath") or "")] = int(payload["id"]) @@ -579,23 +1385,38 @@ async def commit_staged_filesystem_state( delta: dict[str, Any] = { "dirty_paths": [_CLEAR], "staged_dirs": [_CLEAR], + "staged_dir_tool_calls": {_CLEAR: True}, "pending_moves": [_CLEAR], + "pending_deletes": [_CLEAR], + "pending_dir_deletes": [_CLEAR], + "dirty_path_tool_calls": {_CLEAR: True}, } + files_delta: dict[str, Any] = {} if temp_paths: - delta["files"] = dict.fromkeys(temp_paths) + files_delta.update(dict.fromkeys(temp_paths)) + for path in deleted_file_paths: + files_delta[path] = None + if files_delta: + delta["files"] = files_delta if doc_id_update: delta["doc_id_by_path"] = doc_id_update if tree_changed: delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1 + # Avoid 'unused' lint when turn_id_for_revision was only useful for + # diagnostic purposes inside the SAVEPOINT chain above. + _ = turn_id_for_revision + logger.info( "kb_persistence: commit (search_space=%s) creates=%d updates=%d " - "moves=%d staged_dirs=%d discarded=%d", + "moves=%d staged_dirs=%d deletes=%d folder_deletes=%d discarded=%d", search_space_id, len(committed_creates), len(committed_updates), len(applied_moves), len(staged_dirs), + len(committed_deletes), + len(committed_folder_deletes), len(discarded), ) return delta @@ -618,10 +1439,12 @@ class KnowledgeBasePersistenceMiddleware(AgentMiddleware): # type: ignore[type- search_space_id: int, created_by_id: str | None, filesystem_mode: FilesystemMode, + thread_id: int | None = None, ) -> None: self.search_space_id = search_space_id self.created_by_id = created_by_id self.filesystem_mode = filesystem_mode + self.thread_id = thread_id async def aafter_agent( # type: ignore[override] self, @@ -636,6 +1459,7 @@ class KnowledgeBasePersistenceMiddleware(AgentMiddleware): # type: ignore[type- search_space_id=self.search_space_id, created_by_id=self.created_by_id, filesystem_mode=self.filesystem_mode, + thread_id=self.thread_id, ) diff --git a/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py b/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py index ddb2d4af1..7cf3bf8cd 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py @@ -115,6 +115,12 @@ class KBPostgresBackend(BackendProtocol): def _pending_moves(self) -> list[dict[str, Any]]: return list(self.state.get("pending_moves") or []) + def _pending_deletes(self) -> list[dict[str, Any]]: + return list(self.state.get("pending_deletes") or []) + + def _pending_dir_deletes(self) -> list[dict[str, Any]]: + return list(self.state.get("pending_dir_deletes") or []) + def _kb_anon_doc(self) -> dict[str, Any] | None: anon = self.state.get("kb_anon_doc") return anon if isinstance(anon, dict) else None @@ -140,18 +146,28 @@ class KBPostgresBackend(BackendProtocol): return path return path.rstrip("/") if path != "/" else path - def _moved_view_paths( + def _pending_filesystem_view( self, existing: dict[str, dict[str, Any]], - ) -> tuple[set[str], dict[str, str]]: - """Apply ``pending_moves`` to a path set and return ``(removed, alias)``. + ) -> tuple[set[str], dict[str, str], set[str]]: + """Compute removed/aliased/dir-suppressed paths from staged ops. - Removed paths should disappear from listings; ``alias[source] = dest`` - means a virtual entry should appear at ``dest`` even if no DB row is - yet there. + Returns ``(removed, alias, deleted_dirs)`` where: + + * ``removed`` — paths to drop from listings (sources of pending moves + AND paths queued for ``rm``). + * ``alias`` — ``{source: dest}`` for pending moves; the dest should + appear as a virtual entry even when no DB row is at that path yet. + * ``deleted_dirs`` — folder paths queued for ``rmdir``; their entire + subtree (descendants) is suppressed from listings/glob/grep. + + Entries in ``existing`` (the ``files`` state cache) keyed by a + removed path are popped so a same-turn delete-after-write doesn't + leave a stale virtual file in listings. """ removed: set[str] = set() alias: dict[str, str] = {} + deleted_dirs: set[str] = set() for move in self._pending_moves(): src = move.get("source") dst = move.get("dest") @@ -160,7 +176,23 @@ class KBPostgresBackend(BackendProtocol): removed.add(src) alias[src] = dst existing.pop(src, None) - return removed, alias + for entry in self._pending_deletes(): + path = entry.get("path") if isinstance(entry, dict) else None + if not path: + continue + removed.add(path) + existing.pop(path, None) + for entry in self._pending_dir_deletes(): + path = entry.get("path") if isinstance(entry, dict) else None + if not path: + continue + deleted_dirs.add(path) + return removed, alias, deleted_dirs + + @staticmethod + def _is_dir_suppressed(path: str, deleted_dirs: set[str]) -> bool: + """Return True iff ``path`` is at-or-under any directory in ``deleted_dirs``.""" + return any(path == d or _is_under(path, d) for d in deleted_dirs) # ------------------------------------------------------------------ ls/read @@ -189,7 +221,7 @@ class KBPostgresBackend(BackendProtocol): seen.add(anon_path) files = self._state_files() - moved_removed, moved_alias = self._moved_view_paths(files) + moved_removed, moved_alias, deleted_dirs = self._pending_filesystem_view(files) if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/": try: @@ -203,7 +235,12 @@ class KBPostgresBackend(BackendProtocol): for info in db_infos: p = info.get("path", "") - if not p or p in seen or p in moved_removed: + if ( + not p + or p in seen + or p in moved_removed + or self._is_dir_suppressed(p, deleted_dirs) + ): continue infos.append(info) seen.add(p) @@ -212,6 +249,8 @@ class KBPostgresBackend(BackendProtocol): if src not in seen: if not _is_under(dst, normalized): continue + if self._is_dir_suppressed(dst, deleted_dirs): + continue rel = ( dst[len(normalized) :].lstrip("/") if normalized != "/" @@ -247,6 +286,8 @@ class KBPostgresBackend(BackendProtocol): continue if not _is_under(staged, normalized): continue + if self._is_dir_suppressed(staged, deleted_dirs): + continue rel = ( staged[len(normalized) :].lstrip("/") if normalized != "/" @@ -265,14 +306,26 @@ class KBPostgresBackend(BackendProtocol): for sub in sorted(subdir_paths): if sub in seen: continue + if self._is_dir_suppressed(sub, deleted_dirs): + continue infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at="")) seen.add(sub) for path_key, fd in files.items(): if not isinstance(path_key, str) or path_key in seen: continue + # Tombstones (None values) are deletion markers from `rm`. The + # deepagents reducer normally pops them, but a stale tombstone + # surviving a checkpoint must NOT be reported as a child here — + # otherwise rmdir mistakenly sees the deleted file as content. + if fd is None: + continue if not _is_under(path_key, normalized) or path_key == normalized: continue + if path_key in moved_removed or self._is_dir_suppressed( + path_key, deleted_dirs + ): + continue if normalized == "/": rel = path_key.lstrip("/") else: @@ -550,10 +603,12 @@ class KBPostgresBackend(BackendProtocol): seen: set[str] = set() files = self._state_files() - moved_removed, _ = self._moved_view_paths(files) + moved_removed, _, deleted_dirs = self._pending_filesystem_view(files) regex = re.compile(fnmatch.translate(pattern)) for path_key, fd in files.items(): - if path_key in moved_removed: + if path_key in moved_removed or self._is_dir_suppressed( + path_key, deleted_dirs + ): continue if not _is_under(path_key, normalized): continue @@ -595,7 +650,11 @@ class KBPostgresBackend(BackendProtocol): folder_id=row.folder_id, index=index, ) - if candidate in seen or candidate in moved_removed: + if ( + candidate in seen + or candidate in moved_removed + or self._is_dir_suppressed(candidate, deleted_dirs) + ): continue if not _is_under(candidate, normalized): continue @@ -634,10 +693,12 @@ class KBPostgresBackend(BackendProtocol): matches: list[GrepMatch] = [] files = self._state_files() - moved_removed, _ = self._moved_view_paths(files) + moved_removed, _, deleted_dirs = self._pending_filesystem_view(files) glob_re = re.compile(fnmatch.translate(glob)) if glob else None for path_key, fd in files.items(): - if path_key in moved_removed: + if path_key in moved_removed or self._is_dir_suppressed( + path_key, deleted_dirs + ): continue if not _is_under(path_key, normalized): continue @@ -695,7 +756,11 @@ class KBPostgresBackend(BackendProtocol): ) for doc_id, chunk_id, content in chunk_buffer: candidate = doc_id_to_path.get(doc_id) - if not candidate or candidate in moved_removed: + if ( + not candidate + or candidate in moved_removed + or self._is_dir_suppressed(candidate, deleted_dirs) + ): continue if not _is_under(candidate, normalized): continue @@ -769,7 +834,7 @@ class KBPostgresBackend(BackendProtocol): return {"entries": [], "truncated": False} files = self._state_files() - moved_removed, _ = self._moved_view_paths(files) + moved_removed, _, deleted_dirs = self._pending_filesystem_view(files) anon = self._kb_anon_doc() anon_path = str(anon.get("path") or "") if anon else "" @@ -795,6 +860,8 @@ class KBPostgresBackend(BackendProtocol): for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]): if not _is_under(fpath, normalized): continue + if self._is_dir_suppressed(fpath, deleted_dirs): + continue depth = _depth_of(fpath) if max_depth is not None and depth > max_depth: continue @@ -811,6 +878,8 @@ class KBPostgresBackend(BackendProtocol): for staged in self._staged_dirs(): if not _is_under(staged, normalized): continue + if self._is_dir_suppressed(staged, deleted_dirs): + continue depth = _depth_of(staged) if max_depth is not None and depth > max_depth: continue @@ -835,7 +904,9 @@ class KBPostgresBackend(BackendProtocol): folder_id=row.folder_id, index=index, ) - if candidate in moved_removed: + if candidate in moved_removed or self._is_dir_suppressed( + candidate, deleted_dirs + ): continue if not _is_under(candidate, normalized): continue @@ -875,6 +946,10 @@ class KBPostgresBackend(BackendProtocol): continue if not _is_under(path_key, normalized): continue + if path_key in moved_removed or self._is_dir_suppressed( + path_key, deleted_dirs + ): + continue if any(e["path"] == path_key for e in entries): continue if not ( diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py index 467d19747..e67be8221 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py @@ -201,6 +201,12 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg] ) all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT])) + # Pre-compute which folders have at least one descendant (folder or doc). + # A folder is "empty" iff no path in `all_paths` is strictly under it. + # Used to emit an explicit "(empty)" marker so the LLM doesn't have to + # infer emptiness from indentation alone. + non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths) + lines: list[str] = [] for path in all_paths: depth = ( @@ -214,7 +220,10 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg] path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents" ) if is_dir: - lines.append(f"{indent}{display}/") + if path != DOCUMENTS_ROOT and path not in non_empty_folders: + lines.append(f"{indent}{display}/ (empty)") + else: + lines.append(f"{indent}{display}/") else: lines.append(f"{indent}{display}") if len(lines) >= self.max_entries: @@ -235,6 +244,35 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg] return self._format_root_summary(folder_paths, doc_paths) + @staticmethod + def _compute_non_empty_folders( + folder_paths: list[str], doc_paths: list[str] + ) -> set[str]: + """Return the set of folder paths that contain at least one descendant. + + A folder is "non-empty" if any document path or any other folder path + is strictly under it. Documents propagate emptiness up to every + ancestor folder, while a sub-folder only marks its direct ancestors + non-empty (so a chain of empty folders all read ``(empty)``). + """ + non_empty: set[str] = set() + folder_set = set(folder_paths) + + for doc_path in doc_paths: + parent = doc_path.rsplit("/", 1)[0] + while parent and parent != DOCUMENTS_ROOT: + if parent in folder_set: + non_empty.add(parent) + parent = parent.rsplit("/", 1)[0] + + for child in folder_paths: + parent = child.rsplit("/", 1)[0] + while parent and parent != DOCUMENTS_ROOT and parent in folder_set: + non_empty.add(parent) + parent = parent.rsplit("/", 1)[0] + + return non_empty + def _format_root_summary( self, folder_paths: list[str], doc_paths: list[str] ) -> str: diff --git a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py index 565fcb48b..4db9943cb 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py @@ -360,6 +360,74 @@ class LocalFolderBackend: self.move, source_path, destination_path, overwrite ) + def delete_file(self, file_path: str) -> WriteResult: + """Hard-delete a single file under root. + + Refuses directories, root, and missing paths. Roughly mirrors POSIX + ``rm path``; ``-r`` recursion and glob expansion are explicitly + out of scope. + """ + try: + path = self._resolve_virtual(file_path) + except ValueError: + return WriteResult(error=f"Error: Invalid path '{file_path}'") + with self._lock_for(file_path): + if not path.exists(): + return WriteResult(error=f"Error: File '{file_path}' not found") + if path.is_dir(): + return WriteResult( + error=( + f"Error: '{file_path}' is a directory. " + "Use rmdir for empty directories." + ) + ) + try: + os.unlink(path) + except OSError as exc: + return WriteResult( + error=f"Error: failed to delete '{file_path}': {exc}" + ) + return WriteResult(path=file_path, files_update=None) + + async def adelete_file(self, file_path: str) -> WriteResult: + return await asyncio.to_thread(self.delete_file, file_path) + + def rmdir(self, dir_path: str) -> WriteResult: + """Hard-delete an empty directory under root. + + Refuses files, root, missing paths, and non-empty directories. + ``os.rmdir`` is naturally empty-only; we pre-check so the error is + clearer for the agent. + """ + try: + path = self._resolve_virtual(dir_path) + except ValueError: + return WriteResult(error=f"Error: Invalid path '{dir_path}'") + with self._lock_for(dir_path): + if not path.exists(): + return WriteResult(error=f"Error: Directory '{dir_path}' not found") + if not path.is_dir(): + return WriteResult(error=f"Error: '{dir_path}' is not a directory") + try: + next(path.iterdir()) + except StopIteration: + pass + else: + return WriteResult( + error=( + f"Error: directory '{dir_path}' is not empty. " + "Remove its contents first." + ) + ) + try: + os.rmdir(path) + except OSError as exc: + return WriteResult(error=f"Error: failed to rmdir '{dir_path}': {exc}") + return WriteResult(path=dir_path, files_update=None) + + async def armdir(self, dir_path: str) -> WriteResult: + return await asyncio.to_thread(self.rmdir, dir_path) + def edit( self, file_path: str, diff --git a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py index 93eabe6ff..a5add6248 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py @@ -285,6 +285,34 @@ class MultiRootLocalFolderBackend: overwrite, ) + def delete_file(self, file_path: str) -> WriteResult: + try: + mount, local_path = self._split_mount_path(file_path) + except ValueError as exc: + return WriteResult(error=f"Error: {exc}") + result = self._mount_to_backend[mount].delete_file(local_path) + if result.path: + result.path = self._prefix_mount_path(mount, result.path) + return result + + async def adelete_file(self, file_path: str) -> WriteResult: + return await asyncio.to_thread(self.delete_file, file_path) + + def rmdir(self, dir_path: str) -> WriteResult: + try: + mount, local_path = self._split_mount_path(dir_path) + except ValueError as exc: + return WriteResult(error=f"Error: {exc}") + if local_path == "/": + return WriteResult(error=f"Error: cannot rmdir mount root '{dir_path}'") + result = self._mount_to_backend[mount].rmdir(local_path) + if result.path: + result.path = self._prefix_mount_path(mount, result.path) + return result + + async def armdir(self, dir_path: str) -> WriteResult: + return await asyncio.to_thread(self.rmdir, dir_path) + def edit( self, file_path: str, diff --git a/surfsense_backend/app/agents/new_chat/state_reducers.py b/surfsense_backend/app/agents/new_chat/state_reducers.py index ce32406e6..89fc86367 100644 --- a/surfsense_backend/app/agents/new_chat/state_reducers.py +++ b/surfsense_backend/app/agents/new_chat/state_reducers.py @@ -181,9 +181,13 @@ def _initial_filesystem_state() -> dict[str, Any]: return { "cwd": "/documents", "staged_dirs": [], + "staged_dir_tool_calls": {}, "pending_moves": [], + "pending_deletes": [], + "pending_dir_deletes": [], "doc_id_by_path": {}, "dirty_paths": [], + "dirty_path_tool_calls": {}, "kb_priority": [], "kb_matched_chunk_ids": {}, "kb_anon_doc": None, diff --git a/surfsense_backend/app/agents/new_chat/subagents/config.py b/surfsense_backend/app/agents/new_chat/subagents/config.py index b36d35fa0..84ca516e0 100644 --- a/surfsense_backend/app/agents/new_chat/subagents/config.py +++ b/surfsense_backend/app/agents/new_chat/subagents/config.py @@ -84,6 +84,8 @@ WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = ( "write_file", "move_file", "mkdir", + "rm", + "rmdir", "update_memory", "update_memory_team", "update_memory_private", diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/new_chat/tools/hitl.py index 8480e57b1..92248c2c9 100644 --- a/surfsense_backend/app/agents/new_chat/tools/hitl.py +++ b/surfsense_backend/app/agents/new_chat/tools/hitl.py @@ -30,6 +30,35 @@ from langgraph.types import interrupt logger = logging.getLogger(__name__) +# Tools that mirror the safety profile of ``write_file`` against the +# SurfSense KB: each call creates ONE artifact in the user's own workspace +# with no external visibility (drafts aren't sent; new files aren't shared +# unless the user shares them later). These are auto-approved by default +# so the agent can compose drafts and seed scratch files without a popup +# on every call. +# +# Members of this set still call ``request_approval`` exactly as before; +# the function returns immediately with ``decision_type="auto_approved"`` +# and the original params untouched. This preserves the call-site shape +# (logging, metadata fetching, account fallbacks) so the only behavior +# change is "no interrupt fires". +# +# To re-enable prompting, the future per-search-space rules table +# (``agent_permission_rules``) takes precedence — see the ``# (future)`` +# layer-3 comment in :mod:`app.agents.new_chat.chat_deepagent`. +DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset( + { + "create_gmail_draft", + "update_gmail_draft", + "create_notion_page", + "create_confluence_page", + "create_google_drive_file", + "create_dropbox_file", + "create_onedrive_file", + } +) + + @dataclass(frozen=True, slots=True) class HITLResult: """Outcome of a human-in-the-loop approval request.""" @@ -119,6 +148,19 @@ def request_approval( logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name) return HITLResult(rejected=False, decision_type="trusted", params=dict(params)) + if tool_name in DEFAULT_AUTO_APPROVED_TOOLS: + # Default policy: low-stakes creation tools (drafts + new-file + # creates) skip HITL because they're as recoverable as a local + # ``write_file`` against the SurfSense KB. The user can still + # delete the artifact in <30s if it's wrong. + logger.info( + "Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL", + tool_name, + ) + return HITLResult( + rejected=False, decision_type="auto_approved", params=dict(params) + ) + approval = interrupt( { "type": action_type, diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 75342a8e1..91d19fb4f 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -689,6 +689,12 @@ class NewChatMessage(BaseModel, TimestampMixin): index=True, ) + # Per-turn correlation id sourced from ``configurable.turn_id`` at + # streaming time (``f"{chat_id}:{ms}"``). Nullable because legacy rows + # predate the column. Used by C1's edit-from-arbitrary-position to map + # a message back to the LangGraph checkpoint that produced its turn. + turn_id = Column(String(64), nullable=True, index=True) + # Relationships thread = relationship("NewChatThread", back_populates="messages") author = relationship("User") @@ -2292,7 +2298,13 @@ class AgentActionLog(BaseModel): nullable=False, index=True, ) + # ``turn_id`` historically held the LangChain ``tool_call.id``. It has + # been renamed to ``tool_call_id`` (with a parallel column kept for one + # release for back-compat). The real chat-turn id lives in + # ``chat_turn_id`` and is sourced from ``configurable.turn_id``. turn_id = Column(String(64), nullable=True, index=True) + tool_call_id = Column(String(64), nullable=True, index=True) + chat_turn_id = Column(String(64), nullable=True, index=True) message_id = Column(String(128), nullable=True, index=True) tool_name = Column(String(255), nullable=False, index=True) args = Column(JSONB, nullable=True) @@ -2318,6 +2330,16 @@ class AgentActionLog(BaseModel): __table_args__ = ( Index("ix_agent_action_log_thread_created", "thread_id", "created_at"), + # Partial unique index enforces "at most one revert per + # original action". Created in migration 137 with + # ``WHERE reverse_of IS NOT NULL`` so non-revert rows + # (the vast majority) are unaffected and NULLs don't collide. + Index( + "ux_agent_action_log_reverse_of", + "reverse_of", + unique=True, + postgresql_where=text("reverse_of IS NOT NULL"), + ), ) @@ -2332,10 +2354,13 @@ class DocumentRevision(BaseModel): __tablename__ = "document_revisions" + # ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the + # hard-delete it describes — without that, ``rm`` would wipe the row + # we'd need to undo it. See migration ``134_relax_revision_fks``. document_id = Column( Integer, - ForeignKey("documents.id", ondelete="CASCADE"), - nullable=False, + ForeignKey("documents.id", ondelete="SET NULL"), + nullable=True, index=True, ) search_space_id = Column( @@ -2370,10 +2395,13 @@ class FolderRevision(BaseModel): __tablename__ = "folder_revisions" + # ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the + # hard-delete it describes — without that, ``rmdir`` would wipe the + # row we'd need to undo it. See migration ``134_relax_revision_fks``. folder_id = Column( Integer, - ForeignKey("folders.id", ondelete="CASCADE"), - nullable=False, + ForeignKey("folders.id", ondelete="SET NULL"), + nullable=True, index=True, ) search_space_id = Column( diff --git a/surfsense_backend/app/routes/agent_action_log_route.py b/surfsense_backend/app/routes/agent_action_log_route.py index 458635761..2608aa3b1 100644 --- a/surfsense_backend/app/routes/agent_action_log_route.py +++ b/surfsense_backend/app/routes/agent_action_log_route.py @@ -65,6 +65,13 @@ class AgentActionRead(BaseModel): reverse_of: int | None reverted_by_action_id: int | None is_revert_action: bool + # Correlation ids added in migration 135. ``tool_call_id`` is the + # LangChain tool-call id (joinable to ``data-action-log`` SSE events + # via ``langchainToolCallId``). ``chat_turn_id`` is the per-turn id + # from ``configurable.turn_id`` (used by the + # ``revert-turn/{chat_turn_id}`` endpoint). + tool_call_id: str | None = None + chat_turn_id: str | None = None created_at: datetime @@ -172,6 +179,8 @@ async def list_thread_actions( reverse_of=row.reverse_of, reverted_by_action_id=revert_map.get(row.id), is_revert_action=row.reverse_of is not None, + tool_call_id=row.tool_call_id, + chat_turn_id=row.chat_turn_id, created_at=row.created_at, ) for row in rows diff --git a/surfsense_backend/app/routes/agent_revert_route.py b/surfsense_backend/app/routes/agent_revert_route.py index 12484ff53..711081b15 100644 --- a/surfsense_backend/app/routes/agent_revert_route.py +++ b/surfsense_backend/app/routes/agent_revert_route.py @@ -11,14 +11,25 @@ flag flips. Once enabled, the route runs: 4. Revert dispatch via :func:`app.services.revert_service.revert_action`. 5. Idempotent on retries: if the same action is reverted twice the second call returns 409 ``"already reverted"``. + +This module also hosts the per-turn batch endpoint +``POST /api/threads/{thread_id}/revert-turn/{chat_turn_id}``. It +walks every reversible action emitted during a chat turn in reverse +``created_at`` order and reverts each independently. Partial success is the +common case — the response always contains a per-action result list and a +``status`` of ``"ok"`` or ``"partial"``; we never collapse the batch into a +whole-batch 4xx. """ from __future__ import annotations import logging +from typing import Literal from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel from sqlalchemy import select +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.feature_flags import get_flags @@ -97,6 +108,16 @@ async def revert_agent_action( action=action, requester_user_id=str(user.id) if user is not None else None, ) + except IntegrityError: + # Partial unique index ``ux_agent_action_log_reverse_of`` caught + # a concurrent revert. Translate to the existing 409 "already + # reverted" contract so racing clients see consistent + # behaviour with the pre-flight TOCTOU check above. + await session.rollback() + raise HTTPException( + status_code=409, + detail="This action has already been reverted.", + ) from None except Exception as err: logger.exception("Revert dispatch raised for action_id=%s", action_id) await session.rollback() @@ -105,7 +126,16 @@ async def revert_agent_action( ) from err if outcome.status == "ok": - await session.commit() + try: + await session.commit() + except IntegrityError: + # Race lost on commit (constraint enforced at flush in some + # configs but at commit in others — defensive). + await session.rollback() + raise HTTPException( + status_code=409, + detail="This action has already been reverted.", + ) from None return { "status": "ok", "message": outcome.message, @@ -122,3 +152,357 @@ async def revert_agent_action( raise HTTPException(status_code=501, detail=outcome.message) # not_reversible raise HTTPException(status_code=409, detail=outcome.message) + + +# --------------------------------------------------------------------------- +# Per-turn revert batch endpoint +# --------------------------------------------------------------------------- + + +PerActionStatus = Literal[ + "reverted", + "already_reverted", + "not_reversible", + "permission_denied", + "failed", + "skipped", +] + + +class RevertTurnActionResult(BaseModel): + """Per-action outcome inside a ``revert-turn`` batch response.""" + + action_id: int + tool_name: str + status: PerActionStatus + message: str | None = None + new_action_id: int | None = None + error: str | None = None + + +class RevertTurnResponse(BaseModel): + """Top-level response for ``POST /threads/{id}/revert-turn/{chat_turn_id}``. + + ``status`` is ``"ok"`` only when every reversible row succeeded. Any + ``failed`` / ``not_reversible`` / ``permission_denied`` entry downgrades + it to ``"partial"``. Empty turns (no rows) return ``"ok"`` with an empty + ``results`` list — callers should treat that as a no-op. + + Counter invariant: + ``total == reverted + already_reverted + not_reversible + + permission_denied + failed + skipped`` + + Frontend toasts and the ``RevertTurnButton`` summary rely on this + invariant to display "X of Y reverted, Z could not be undone" without + silently dropping ``permission_denied`` or ``skipped`` rows. + """ + + status: Literal["ok", "partial"] + chat_turn_id: str + total: int + reverted: int + already_reverted: int + not_reversible: int + permission_denied: int = 0 + failed: int = 0 + skipped: int = 0 + results: list[RevertTurnActionResult] + + +def _classify_outcome(outcome: RevertOutcome) -> PerActionStatus: + if outcome.status == "ok": + return "reverted" + if outcome.status == "permission_denied": + return "permission_denied" + # ``not_found`` / ``tool_unavailable`` / ``reverse_not_implemented`` / + # ``not_reversible`` are all surfaced to the caller as "not_reversible" + # — they share the same UX (this row cannot be undone) and only the + # ``message`` differs. + return "not_reversible" + + +async def _was_already_reverted(session: AsyncSession, *, action_id: int) -> int | None: + """Return the id of an existing successful revert row, if any. + + Single-action variant — kept for the post-IntegrityError lookup + path where we already know we lost a race for one specific id. + """ + stmt = select(AgentActionLog.id).where(AgentActionLog.reverse_of == action_id) + result = await session.execute(stmt) + return result.scalars().first() + + +async def _was_already_reverted_batch( + session: AsyncSession, *, action_ids: list[int] +) -> dict[int, int]: + """Batch idempotency probe for the revert-turn loop. + + Replaces N individual ``SELECT id WHERE reverse_of = :id`` queries + (one per row in the turn) with a single ``SELECT id, reverse_of + WHERE reverse_of IN (:ids)``. The route still iterates rows in + reverse-chronological order, but the membership check is O(1) per + iteration after this query. For a turn with 30 actions that's 30 + fewer round-trips through asyncpg + a smaller transaction footprint. + + Returns a ``{original_action_id -> revert_action_id}`` map. Missing + keys mean "not yet reverted" — callers should treat them as + eligible for revert. + """ + if not action_ids: + return {} + stmt = select(AgentActionLog.id, AgentActionLog.reverse_of).where( + AgentActionLog.reverse_of.in_(action_ids) + ) + result = await session.execute(stmt) + return { + original_id: revert_id + for revert_id, original_id in result.all() + if original_id is not None + } + + +@router.post( + "/threads/{thread_id}/revert-turn/{chat_turn_id}", + response_model=RevertTurnResponse, +) +async def revert_agent_turn( + thread_id: int, + chat_turn_id: str, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> RevertTurnResponse: + """Revert every reversible action emitted during ``chat_turn_id``. + + Walks ``AgentActionLog`` rows for the turn in reverse ``created_at`` + order so dependencies (e.g. ``mkdir`` -> ``write_file`` inside the new + folder) unwind in the right sequence. Each action is reverted in its + own SAVEPOINT so a single failure does not poison the batch. + + Partial success is intentional and returned with HTTP 200. Callers + must inspect ``results[*].status`` to find rows that need attention. + """ + + flags = get_flags() + if flags.disable_new_agent_stack or not flags.enable_revert_route: + raise HTTPException( + status_code=503, + detail=( + "Revert is not available on this deployment yet. The route " + "ships before the UI; flip SURFSENSE_ENABLE_REVERT_ROUTE to " + "enable it." + ), + ) + + thread = await load_thread(session, thread_id=thread_id) + if thread is None: + raise HTTPException(status_code=404, detail="Thread not found.") + + # Reverse-chronological so the latest mutation in the turn unwinds + # first. ``id.desc()`` is the deterministic tiebreaker for actions + # written in the same millisecond. + rows_stmt = ( + select(AgentActionLog) + .where( + AgentActionLog.thread_id == thread_id, + AgentActionLog.chat_turn_id == chat_turn_id, + ) + .order_by(AgentActionLog.created_at.desc(), AgentActionLog.id.desc()) + ) + rows = (await session.execute(rows_stmt)).scalars().all() + + requester_user_id = str(user.id) if user is not None else None + results: list[RevertTurnActionResult] = [] + # Counters MUST be exhaustive so the response invariant + # ``total == sum(counters)`` always holds. Frontend toasts and + # ``RevertTurnButton`` rely on this for "X of Y reverted" math. + counts: dict[str, int] = { + "reverted": 0, + "already_reverted": 0, + "not_reversible": 0, + "permission_denied": 0, + "failed": 0, + "skipped": 0, + } + + # Single batched idempotency probe replaces the previous per-row + # SELECT. ``rows`` are filtered in the loop so we pre-collect only + # the original-action ids (skip rows that are themselves + # reverts). + eligible_ids = [r.id for r in rows if r.reverse_of is None] + already_reverted_map = await _was_already_reverted_batch( + session, action_ids=eligible_ids + ) + + for action in rows: + # Skip rows that ARE reverts of an earlier action — reverting a + # revert is meaningless inside a batch (the user wants to wipe + # the original effects, not chase tail). + if action.reverse_of is not None: + counts["skipped"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="skipped", + message="Row is itself a revert action; skipped.", + ) + ) + continue + + # Idempotency: surface "already_reverted" instead of failing. + existing_revert_id = already_reverted_map.get(action.id) + if existing_revert_id is not None: + counts["already_reverted"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="already_reverted", + new_action_id=existing_revert_id, + ) + ) + continue + + if not can_revert( + requester_user_id=requester_user_id, + action=action, + is_admin=False, + ): + counts["permission_denied"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="permission_denied", + message="You are not allowed to revert this action.", + ) + ) + continue + + # Per-row SAVEPOINT so one failed revert never poisons later + # successful ones. + try: + async with session.begin_nested(): + outcome = await revert_action( + session, + action=action, + requester_user_id=requester_user_id, + ) + if outcome.status != "ok": + raise _OutcomeRollbackError(outcome) + except _OutcomeRollbackError as rollback: + outcome = rollback.outcome + classified = _classify_outcome(outcome) + if classified == "permission_denied": + counts["permission_denied"] += 1 + else: + counts["not_reversible"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status=classified, + message=outcome.message, + ) + ) + continue + except IntegrityError: + # Partial unique index caught a concurrent revert that won + # the race against our pre-flight ``_was_already_reverted`` + # SELECT. Look up the winner so + # we can surface its ``new_action_id`` to the client. + existing_revert_id = await _was_already_reverted( + session, action_id=action.id + ) + counts["already_reverted"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="already_reverted", + new_action_id=existing_revert_id, + ) + ) + continue + except Exception as err: # pragma: no cover — defensive, logged + logger.exception( + "Unexpected revert failure inside batch for action_id=%s", + action.id, + ) + counts["failed"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="failed", + error=str(err) or err.__class__.__name__, + ) + ) + continue + + counts["reverted"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="reverted", + message=outcome.message, + new_action_id=outcome.new_action_id, + ) + ) + + # Single commit at the end — successful SAVEPOINTs above already + # released; failed ones rolled back to their savepoint. No row leaks + # across the boundary. + try: + await session.commit() + except Exception as err: # pragma: no cover — defensive + logger.exception( + "Final commit for revert-turn failed (thread=%s turn=%s)", + thread_id, + chat_turn_id, + ) + await session.rollback() + raise HTTPException( + status_code=500, + detail="Internal error while finalising revert-turn batch.", + ) from err + + has_partial = ( + counts["failed"] > 0 + or counts["not_reversible"] > 0 + or counts["permission_denied"] > 0 + ) + overall_status: Literal["ok", "partial"] = "partial" if has_partial else "ok" + + return RevertTurnResponse( + status=overall_status, + chat_turn_id=chat_turn_id, + total=len(rows), + reverted=counts["reverted"], + already_reverted=counts["already_reverted"], + not_reversible=counts["not_reversible"], + permission_denied=counts["permission_denied"], + failed=counts["failed"], + skipped=counts["skipped"], + results=results, + ) + + +class _OutcomeRollbackError(Exception): + """Sentinel raised inside the SAVEPOINT to roll back a non-OK outcome. + + ``revert_action`` writes a new ``agent_action_log`` row only on the + happy path, but on the failure paths it sometimes mutates the + ``DocumentRevision``/``Document`` tables before deciding the action + is not reversible. Wrapping each call in ``begin_nested`` and raising + this from the failure branch ensures we always discard partial + writes for failed rows. + """ + + def __init__(self, outcome: RevertOutcome) -> None: + self.outcome = outcome + super().__init__(outcome.message) + + +__all__ = ["router"] diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index b5560d90d..26c72bd45 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -11,6 +11,7 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui: """ import asyncio +import json import logging from datetime import UTC, datetime @@ -136,6 +137,260 @@ def _resolve_filesystem_selection( ) +def _find_pre_turn_checkpoint_id( + checkpoint_tuples: list, + *, + turn_id: str, +) -> str | None: + """Locate the LangGraph checkpoint immediately before ``turn_id`` started. + + ``checkpoint_tuples`` arrives newest-first from + ``checkpointer.alist(config)``. We walk OLDEST-first (``reversed``) + and remember the most recent checkpoint that does NOT belong to the + edited turn. As soon as we cross into the edited turn (a checkpoint + whose ``turn_id`` matches), we return the previously-tracked + checkpoint — that's the state immediately before ``turn_id`` began. + + The naive "newest-first, return first non-matching" approach is + INCORRECT when later turns exist after ``turn_id``: their + checkpoints also satisfy ``cp_turn_id != turn_id`` and would be + returned before the real pre-turn boundary is reached. + + Reads from ``cp_tuple.metadata`` (the durable surface promoted from + ``configurable`` at write time) rather than ``config["configurable"]`` + so the lookup is portable across checkpointer implementations. + + Returns ``None`` when no eligible pre-turn checkpoint exists (e.g. + the edited turn is the very first turn of the thread). Callers fall + back to the oldest available checkpoint in that case. + """ + + last_pre_turn_target: str | None = None + for cp_tuple in reversed(checkpoint_tuples): # oldest -> newest + metadata = getattr(cp_tuple, "metadata", None) or {} + cp_turn_id = metadata.get("turn_id") if isinstance(metadata, dict) else None + if cp_turn_id == turn_id: + # Crossed into the edited turn; the previous tracked + # checkpoint is the rewind target. May be ``None`` if we hit + # the edited turn on the very first iteration. + return last_pre_turn_target + try: + last_pre_turn_target = cp_tuple.config["configurable"]["checkpoint_id"] + except (KeyError, TypeError): + continue + return last_pre_turn_target + + +async def _revert_turns_for_regenerate( + *, + thread_id: int, + chat_turn_ids: list[str], + requester_user_id: str, +) -> dict: + """Best-effort revert pass for every ``chat_turn_id`` in ``chat_turn_ids``. + + Runs BEFORE the regenerate stream so the frontend can surface + partial-rollback feedback alongside the new assistant turn. Each + turn's actions are reverted in their own SAVEPOINTs (handled + inside :mod:`app.routes.agent_revert_route`'s helpers) so a single + failure never poisons the batch. + + Sequencing inside the request: revert THEN regenerate. The + operation is NOT atomic and partial state IS surfaced — see the + plan's "Sequencing inside the request" note. + """ + + from app.routes.agent_revert_route import ( + RevertTurnActionResult, + _classify_outcome, + _OutcomeRollbackError, + _was_already_reverted, + _was_already_reverted_batch, + ) + from app.services.revert_service import ( + can_revert, + revert_action, + ) + + aggregated_results: list[dict] = [] + # Exhaustive counters keep the response invariant + # ``total == sum(counters)`` true for ``data-revert-results``. + counts = { + "reverted": 0, + "already_reverted": 0, + "not_reversible": 0, + "permission_denied": 0, + "failed": 0, + "skipped": 0, + } + + # Local import keeps the route module's existing imports tidy and + # avoids a circular dependency at module-load time. + from app.db import AgentActionLog as _AgentActionLog + + async with shielded_async_session() as session: + for chat_turn_id in chat_turn_ids: + rows_stmt = ( + select(_AgentActionLog) + .where( + _AgentActionLog.thread_id == thread_id, + _AgentActionLog.chat_turn_id == chat_turn_id, + ) + .order_by( + _AgentActionLog.created_at.desc(), + _AgentActionLog.id.desc(), + ) + ) + rows = (await session.execute(rows_stmt)).scalars().all() + + # Batch idempotency probe across the turn (single SELECT + # instead of one per row). + eligible_ids = [r.id for r in rows if r.reverse_of is None] + already_reverted_map = await _was_already_reverted_batch( + session, action_ids=eligible_ids + ) + + for action in rows: + if action.reverse_of is not None: + counts["skipped"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="skipped", + message="Row is itself a revert action; skipped.", + ).model_dump() + ) + continue + + existing_revert_id = already_reverted_map.get(action.id) + if existing_revert_id is not None: + counts["already_reverted"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="already_reverted", + new_action_id=existing_revert_id, + ).model_dump() + ) + continue + + if not can_revert( + requester_user_id=requester_user_id, + action=action, + is_admin=False, + ): + counts["permission_denied"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="permission_denied", + message="You are not allowed to revert this action.", + ).model_dump() + ) + continue + + try: + async with session.begin_nested(): + outcome = await revert_action( + session, + action=action, + requester_user_id=requester_user_id, + ) + if outcome.status != "ok": + raise _OutcomeRollbackError(outcome) + except _OutcomeRollbackError as rollback: + outcome = rollback.outcome + classified = _classify_outcome(outcome) + if classified == "permission_denied": + counts["permission_denied"] += 1 + else: + counts["not_reversible"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status=classified, + message=outcome.message, + ).model_dump() + ) + continue + except IntegrityError: + # Concurrent revert won the race against the + # pre-flight ``_was_already_reverted`` SELECT. + # Surface the winning revert id so the client can + # treat this as a successful idempotent op. + existing_revert_id = await _was_already_reverted( + session, action_id=action.id + ) + counts["already_reverted"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="already_reverted", + new_action_id=existing_revert_id, + ).model_dump() + ) + continue + except Exception as err: # pragma: no cover — defensive + _logger.exception( + "Unexpected revert failure during regenerate batch " + "for action_id=%s", + action.id, + ) + counts["failed"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="failed", + error=str(err) or err.__class__.__name__, + ).model_dump() + ) + continue + + counts["reverted"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="reverted", + message=outcome.message, + new_action_id=outcome.new_action_id, + ).model_dump() + ) + + try: + await session.commit() + except Exception: + _logger.exception( + "[regenerate-revert] Final commit failed; rolling back batch." + ) + await session.rollback() + + has_partial = ( + counts["failed"] > 0 + or counts["not_reversible"] > 0 + or counts["permission_denied"] > 0 + ) + + return { + "status": "partial" if has_partial else "ok", + "chat_turn_ids": chat_turn_ids, + "total": len(aggregated_results), + "reverted": counts["reverted"], + "already_reverted": counts["already_reverted"], + "not_reversible": counts["not_reversible"], + "permission_denied": counts["permission_denied"], + "failed": counts["failed"], + "skipped": counts["skipped"], + "results": aggregated_results, + } + + def _try_delete_sandbox(thread_id: int) -> None: """Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked.""" from app.agents.new_chat.sandbox import ( @@ -574,6 +829,7 @@ async def get_thread_messages( token_usage=TokenUsageSummary.model_validate(msg.token_usage) if msg.token_usage else None, + turn_id=msg.turn_id, ) for msg in db_messages ] @@ -1006,12 +1262,24 @@ async def append_message( # Check thread-level access based on visibility await check_thread_access(session, thread, user) - # Create message + # Create message. ``turn_id`` is the per-turn correlation id from + # ``configurable.turn_id`` (added in migration 136) — when the + # client streams it back to ``appendMessage``, we persist it so + # C1's edit-from-arbitrary-position can later map this message + # back to the LangGraph checkpoint that produced its turn. + raw_turn_id = raw_body.get("turn_id") + turn_id_value = ( + str(raw_turn_id).strip() + if isinstance(raw_turn_id, str) and raw_turn_id.strip() + else None + ) + db_message = NewChatMessage( thread_id=thread_id, role=message_role, content=content, author_id=user.id, + turn_id=turn_id_value, ) session.add(db_message) @@ -1050,6 +1318,7 @@ async def append_message( created_at=db_message.created_at, author_id=db_message.author_id, token_usage=None, + turn_id=db_message.turn_id, ) except HTTPException: @@ -1373,43 +1642,123 @@ async def regenerate_response( user_query_to_use = request.user_query regenerate_image_urls: list[str] = [] - # Look through checkpoints to find the right one - # We want to find the checkpoint just before the last HumanMessage - for i, cp_tuple in enumerate(checkpoint_tuples): - # Access the checkpoint's channel_values which contains "messages" - checkpoint_data = cp_tuple.checkpoint - channel_values = checkpoint_data.get("channel_values", {}) - state_messages = channel_values.get("messages", []) + # --------------------------------------------------------------- + # Edit-from-arbitrary-position. When the client passes + # ``from_message_id`` we look up its persisted ``turn_id`` (added + # in migration 136) and pick the checkpoint immediately before + # that turn started. + # + # Legacy graceful-degradation contract: + # * Rows persisted BEFORE migration 136 have ``turn_id IS NULL``. + # Returning 400 in that case is the wrong UX — the user is + # editing an old message in an existing thread and just wants + # it to work. We instead skip the checkpoint rewind (the + # stream falls back to the latest state) and skip the revert + # pass (no chat_turn_id available to walk). Deletion still + # uses ``created_at``, so the messages-after-cursor slice is + # correct on both legacy and post-136 rows. + # --------------------------------------------------------------- + from_message_turn_id: str | None = None + from_message_created_at: datetime | None = None + legacy_from_message: bool = False + if request.from_message_id is not None: + from_msg_row = await session.execute( + select(NewChatMessage).filter( + NewChatMessage.id == request.from_message_id, + NewChatMessage.thread_id == thread_id, + ) + ) + from_msg = from_msg_row.scalars().first() + if from_msg is None: + raise HTTPException( + status_code=404, + detail="from_message_id not found in this thread.", + ) + from_message_created_at = from_msg.created_at + if not from_msg.turn_id: + # Legacy row — surface the degradation in logs but let + # the request proceed with the slice-based delete and a + # cold-start checkpoint. + legacy_from_message = True + _logger.warning( + "[regenerate] from_message_id=%s on thread=%s has no " + "turn_id (legacy row pre-migration-136). Falling back " + "to slice-based delete without checkpoint rewind. " + "revert_actions=%s will be ignored.", + request.from_message_id, + thread_id, + request.revert_actions, + ) + else: + from_message_turn_id = from_msg.turn_id - if state_messages: - last_msg = state_messages[-1] - # Find a checkpoint where the last message is NOT a HumanMessage - # This means we're at a state before the user's last message - if not isinstance(last_msg, HumanMessage): - # If no new user_query provided (reload), extract from a later checkpoint - if user_query_to_use is None and i > 0: - # Get the user query from a more recent checkpoint - for prev_cp_tuple in checkpoint_tuples[:i]: - prev_checkpoint_data = prev_cp_tuple.checkpoint - prev_channel_values = prev_checkpoint_data.get( - "channel_values", {} - ) - prev_messages = prev_channel_values.get("messages", []) - for msg in reversed(prev_messages): - if isinstance(msg, HumanMessage): - q, imgs = split_langchain_human_content(msg.content) - user_query_to_use = q - regenerate_image_urls = imgs - break - if user_query_to_use is not None and ( - str(user_query_to_use).strip() or regenerate_image_urls - ): - break - - target_checkpoint_id = cp_tuple.config["configurable"][ + # Walk oldest-to-newest and pick the LAST checkpoint whose + # ``turn_id`` differs from the edited turn — that's the state + # immediately before this turn started running. We read from + # ``metadata`` (the durable surface) rather than + # ``config["configurable"]`` so the lookup works across + # checkpointer implementations. + target_checkpoint_id = _find_pre_turn_checkpoint_id( + checkpoint_tuples, + turn_id=from_message_turn_id, + ) + if target_checkpoint_id is None and len(checkpoint_tuples) > 0: + # Fall back to the oldest checkpoint — better than + # 400ing when the agent didn't checkpoint pre-turn + # (e.g. very first turn of the thread). + target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][ "checkpoint_id" ] - break + + # Look through checkpoints to find the right one + # We want to find the checkpoint just before the last HumanMessage. + # We enter this branch when: + # * the client did NOT pin ``from_message_id`` (legacy reload/edit), OR + # * the client pinned ``from_message_id`` but the row is a + # legacy pre-migration-136 row with no ``turn_id`` (we + # downgraded to the same heuristic as a regular reload). + # We DO skip it when a real turn_id pinned ``target_checkpoint_id`` + # — that's the C1 happy path and the heuristic below would just + # re-derive a worse target. + if request.from_message_id is None or legacy_from_message: + for i, cp_tuple in enumerate(checkpoint_tuples): + # Access the checkpoint's channel_values which contains "messages" + checkpoint_data = cp_tuple.checkpoint + channel_values = checkpoint_data.get("channel_values", {}) + state_messages = channel_values.get("messages", []) + + if state_messages: + last_msg = state_messages[-1] + # Find a checkpoint where the last message is NOT a HumanMessage + # This means we're at a state before the user's last message + if not isinstance(last_msg, HumanMessage): + # If no new user_query provided (reload), extract from a later checkpoint + if user_query_to_use is None and i > 0: + # Get the user query from a more recent checkpoint + for prev_cp_tuple in checkpoint_tuples[:i]: + prev_checkpoint_data = prev_cp_tuple.checkpoint + prev_channel_values = prev_checkpoint_data.get( + "channel_values", {} + ) + prev_messages = prev_channel_values.get("messages", []) + for msg in reversed(prev_messages): + if isinstance(msg, HumanMessage): + q, imgs = split_langchain_human_content( + msg.content + ) + user_query_to_use = q + regenerate_image_urls = imgs + break + if user_query_to_use is not None and ( + str(user_query_to_use).strip() + or regenerate_image_urls + ): + break + + target_checkpoint_id = cp_tuple.config["configurable"][ + "checkpoint_id" + ] + break # If we couldn't find a good checkpoint, try alternative approaches if target_checkpoint_id is None and checkpoint_tuples: @@ -1472,18 +1821,51 @@ async def regenerate_response( detail="Could not determine user query for regeneration. Please provide a user_query.", ) - # Get the last two messages to delete AFTER streaming succeeds - # This prevents data loss if streaming fails - last_messages_result = await session.execute( - select(NewChatMessage) - .filter(NewChatMessage.thread_id == thread_id) - .order_by(NewChatMessage.created_at.desc()) - .limit(2) - ) + # Get the messages to delete AFTER streaming succeeds. + # This prevents data loss if streaming fails. + # + # When ``from_message_id`` is set we slice from that message + # forward (using ``created_at`` so we also catch any tool/system + # messages persisted into the same turn). Otherwise + # we keep the legacy "last 2 messages" rewind. + if request.from_message_id is not None and from_message_created_at is not None: + last_messages_result = await session.execute( + select(NewChatMessage) + .filter( + NewChatMessage.thread_id == thread_id, + NewChatMessage.created_at >= from_message_created_at, + ) + .order_by(NewChatMessage.created_at.desc()) + ) + else: + last_messages_result = await session.execute( + select(NewChatMessage) + .filter(NewChatMessage.thread_id == thread_id) + .order_by(NewChatMessage.created_at.desc()) + .limit(2) + ) messages_to_delete = list(last_messages_result.scalars().all()) message_ids_to_delete = [msg.id for msg in messages_to_delete] + # When revert_actions is requested, collect the set of + # ``chat_turn_id``s present in the slice we're about to delete. + # Each one will be reverted (best-effort) BEFORE the regenerate + # stream begins. Legacy rows have ``turn_id=None`` and silently + # contribute nothing — we already logged the degradation above. + revert_turn_ids: list[str] = [] + if ( + request.revert_actions + and request.from_message_id is not None + and not legacy_from_message + ): + seen_turns: set[str] = set() + for msg in messages_to_delete: + tid = msg.turn_id + if tid and tid not in seen_turns: + seen_turns.add(tid) + revert_turn_ids.append(tid) + # Get search space for LLM config search_space_result = await session.execute( select(SearchSpace).filter(SearchSpace.id == request.search_space_id) @@ -1507,6 +1889,24 @@ async def regenerate_response( # This prevents data loss if streaming fails (network error, LLM error, etc.) async def stream_with_cleanup(): streaming_completed = False + # Best-effort revert pass BEFORE the regenerate stream begins. + # Each turn is reverted independently (per-row SAVEPOINTs + # inside the route helper) and the per-action results are surfaced + # on a single ``data-revert-results`` SSE event so the frontend + # can render any failed rows alongside the new turn. Failures here + # do NOT abort the regeneration — partial rollback is documented + # behaviour. + if revert_turn_ids: + revert_results = await _revert_turns_for_regenerate( + thread_id=thread_id, + chat_turn_ids=revert_turn_ids, + requester_user_id=str(user.id), + ) + envelope = { + "type": "data-revert-results", + "data": revert_results, + } + yield f"data: {json.dumps(envelope, default=str)}\n\n".encode() try: async for chunk in stream_new_chat( user_query=str(user_query_to_use), diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index 477fdf2ca..c7284e901 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -51,6 +51,11 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel): author_display_name: str | None = None author_avatar_url: str | None = None token_usage: TokenUsageSummary | None = None + # Per-turn correlation id (``f"{chat_id}:{ms}"``) from + # ``configurable.turn_id`` at streaming time. Nullable because + # legacy rows predate the column; clients should treat NULL as + # "edit-from-this-message is unavailable". + turn_id: str | None = None model_config = ConfigDict(from_attributes=True) @@ -241,6 +246,15 @@ class RegenerateRequest(BaseModel): For edit, optional user_images (when not None) replaces image URLs resolved from checkpoint/DB so the client can send the full user turn (text and/or images). + + Edit-from-arbitrary-position. When ``from_message_id`` is provided + the route slices conversation history starting at that message (instead of + the legacy "last 2 messages" rewind), rewinds the LangGraph checkpoint by + matching ``configurable.turn_id`` stored on the message (added in migration 136), and + optionally reverts every reversible action emitted in turns at or after + ``from_message_id``. The revert step is best-effort and runs BEFORE the + regenerate stream — partial failures are surfaced via SSE + ``data-revert-results`` and do not abort the regeneration. """ search_space_id: int @@ -257,6 +271,28 @@ class RegenerateRequest(BaseModel): default=None, description="If set, use these images for the regenerated turn (edit); overrides checkpoint/DB", ) + from_message_id: int | None = Field( + default=None, + description=( + "Message id to rewind to. When set, history is sliced " + "from this message forward and the LangGraph checkpoint is " + "rewound to the state immediately preceding this turn. Legacy " + "rows that predate migration 136 have ``turn_id=None`` and " + "still process — the route logs a warning, skips the " + "checkpoint rewind, and ignores ``revert_actions`` (no " + "chat_turn_id available to walk)." + ), + ) + revert_actions: bool = Field( + default=False, + description=( + "When true, every reversible action emitted at or " + "after ``from_message_id`` is reverted before the regenerate " + "stream begins. Per-action results are surfaced via the " + "``data-revert-results`` SSE event. Partial failures DO NOT " + "abort the regeneration." + ), + ) @model_validator(mode="after") def _validate_regenerate_user_images(self) -> Self: @@ -264,6 +300,14 @@ class RegenerateRequest(BaseModel): raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed") return self + @model_validator(mode="after") + def _validate_revert_actions_requires_from_message(self) -> Self: + if self.revert_actions and self.from_message_id is None: + raise ValueError( + "revert_actions requires from_message_id; specify which message to rewind to" + ) + return self + # ============================================================================= # Agent Tools Schemas diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 52a215997..5dbae91c5 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -584,13 +584,24 @@ class VercelStreamingService: # Tool Parts # ========================================================================= - def format_tool_input_start(self, tool_call_id: str, tool_name: str) -> str: + def format_tool_input_start( + self, + tool_call_id: str, + tool_name: str, + *, + langchain_tool_call_id: str | None = None, + ) -> str: """ Format the start of tool input streaming. Args: - tool_call_id: The unique tool call identifier - tool_name: The name of the tool being called + tool_call_id: The unique tool call identifier (synthetic, derived + from LangGraph ``run_id`` so the frontend has a stable card id). + tool_name: The name of the tool being called. + langchain_tool_call_id: Optional authoritative LangChain + ``tool_call.id``. When set, surfaces as + ``langchainToolCallId`` so the frontend can join this card + to the action-log row written by ``ActionLogMiddleware``. Returns: str: SSE formatted tool input start part @@ -598,13 +609,14 @@ class VercelStreamingService: Example output: data: {"type":"tool-input-start","toolCallId":"call_abc123","toolName":"getWeather"} """ - return self._format_sse( - { - "type": "tool-input-start", - "toolCallId": tool_call_id, - "toolName": tool_name, - } - ) + payload: dict[str, Any] = { + "type": "tool-input-start", + "toolCallId": tool_call_id, + "toolName": tool_name, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return self._format_sse(payload) def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str: """ @@ -629,7 +641,12 @@ class VercelStreamingService: ) def format_tool_input_available( - self, tool_call_id: str, tool_name: str, input_data: dict[str, Any] + self, + tool_call_id: str, + tool_name: str, + input_data: dict[str, Any], + *, + langchain_tool_call_id: str | None = None, ) -> str: """ Format the completion of tool input. @@ -638,6 +655,8 @@ class VercelStreamingService: tool_call_id: The tool call identifier tool_name: The name of the tool input_data: The complete tool input parameters + langchain_tool_call_id: Optional authoritative LangChain + ``tool_call.id`` (see ``format_tool_input_start``). Returns: str: SSE formatted tool input available part @@ -645,22 +664,34 @@ class VercelStreamingService: Example output: data: {"type":"tool-input-available","toolCallId":"call_abc123","toolName":"getWeather","input":{"city":"SF"}} """ - return self._format_sse( - { - "type": "tool-input-available", - "toolCallId": tool_call_id, - "toolName": tool_name, - "input": input_data, - } - ) + payload: dict[str, Any] = { + "type": "tool-input-available", + "toolCallId": tool_call_id, + "toolName": tool_name, + "input": input_data, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return self._format_sse(payload) - def format_tool_output_available(self, tool_call_id: str, output: Any) -> str: + def format_tool_output_available( + self, + tool_call_id: str, + output: Any, + *, + langchain_tool_call_id: str | None = None, + ) -> str: """ Format tool execution output. Args: tool_call_id: The tool call identifier output: The tool execution result + langchain_tool_call_id: Optional authoritative LangChain + ``tool_call.id`` extracted from ``ToolMessage.tool_call_id``. + When set, the frontend can backfill any card whose + ``langchainToolCallId`` was not yet known at + ``tool-input-start`` time. Returns: str: SSE formatted tool output available part @@ -668,13 +699,14 @@ class VercelStreamingService: Example output: data: {"type":"tool-output-available","toolCallId":"call_abc123","output":{"weather":"sunny"}} """ - return self._format_sse( - { - "type": "tool-output-available", - "toolCallId": tool_call_id, - "output": output, - } - ) + payload: dict[str, Any] = { + "type": "tool-output-available", + "toolCallId": tool_call_id, + "output": output, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return self._format_sse(payload) # ========================================================================= # Step Parts diff --git a/surfsense_backend/app/services/revert_service.py b/surfsense_backend/app/services/revert_service.py index f3630e0b4..d02a31345 100644 --- a/surfsense_backend/app/services/revert_service.py +++ b/surfsense_backend/app/services/revert_service.py @@ -8,7 +8,9 @@ Operation outcomes mirror the plan: * **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from :class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows - written before the original mutation. + written before the original mutation. ``rm``/``rmdir`` re-INSERT a fresh + row from the snapshot; ``write_file`` create / ``mkdir`` DELETE the row + that was created; everything else is an in-place restore. * **Connector-owned actions with a declared ``reverse_descriptor``**: invoke the inverse tool through the agent's normal permission stack (NOT bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``. @@ -18,6 +20,11 @@ Operation outcomes mirror the plan: A successful revert appends a NEW row to ``agent_action_log`` with ``reverse_of=`` and the requesting user's ``user_id``, preserving an auditable chain. + +Dispatch must be exact-match (``tool_name == name``), NOT prefix matching. +``"rmdir".startswith("rm")`` would otherwise mis-route directory revert +to the document branch (and ``delete_note`` vs ``delete_folder`` is the +same trap waiting to happen). """ from __future__ import annotations @@ -25,17 +32,31 @@ from __future__ import annotations import logging from dataclasses import dataclass from datetime import UTC, datetime -from typing import Literal +from typing import Any, Literal -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.path_resolver import ( + DOCUMENTS_ROOT, + safe_filename, + safe_folder_segment, +) from app.db import ( AgentActionLog, + Chunk, + Document, DocumentRevision, + DocumentType, + Folder, FolderRevision, NewChatThread, ) +from app.utils.document_converters import ( + embed_texts, + generate_content_hash, + generate_unique_identifier_hash, +) logger = logging.getLogger(__name__) @@ -110,14 +131,244 @@ def can_revert( # --------------------------------------------------------------------------- -# Revert paths +# Helper: reconstruct virtual path from a snapshot # --------------------------------------------------------------------------- +async def _virtual_path_from_snapshot( + session: AsyncSession, + revision: DocumentRevision, +) -> str | None: + """Reconstruct the virtual_path the document was at before mutation. + + Preference order: + 1. ``metadata_before["virtual_path"]`` — written by every snapshot + helper since this PR. + 2. Compose ``"/"`` from + ``folder_id_before`` + ``title_before``. Walks the folder chain via + ``parent_id``. + """ + metadata = revision.metadata_before or {} + candidate = metadata.get("virtual_path") if isinstance(metadata, dict) else None + if isinstance(candidate, str) and candidate.startswith(DOCUMENTS_ROOT): + return candidate + + title = revision.title_before + if not isinstance(title, str) or not title: + return None + + parts: list[str] = [] + cursor: int | None = revision.folder_id_before + visited: set[int] = set() + while cursor is not None and cursor not in visited: + visited.add(cursor) + folder = await session.get(Folder, cursor) + if folder is None: + return None + parts.append(safe_folder_segment(str(folder.name or ""))) + cursor = folder.parent_id + parts.reverse() + + base = f"{DOCUMENTS_ROOT}/" + "/".join(parts) if parts else DOCUMENTS_ROOT + filename = safe_filename(title) + return f"{base}/{filename}" + + +# --------------------------------------------------------------------------- +# Document revision restore (write/edit/move/rm) +# --------------------------------------------------------------------------- + + +def _set_field(target: Any, field: str, value: Any) -> None: + if value is not None: + setattr(target, field, value) + + +async def _restore_in_place_document( + session: AsyncSession, + *, + revision: DocumentRevision, +) -> RevertOutcome: + """Apply an in-place restore to an existing :class:`Document`.""" + if revision.document_id is None: + return RevertOutcome( + status="tool_unavailable", + message=( + "Original document was hard-deleted; in-place restore is not possible." + ), + ) + doc = await session.get(Document, revision.document_id) + if doc is None: + return RevertOutcome( + status="tool_unavailable", + message="Original document has been deleted; revert cannot proceed.", + ) + + _set_field(doc, "content", revision.content_before) + _set_field(doc, "source_markdown", revision.content_before) + _set_field(doc, "title", revision.title_before) + _set_field(doc, "folder_id", revision.folder_id_before) + metadata_before = revision.metadata_before or {} + if isinstance(metadata_before, dict) and metadata_before: + doc.document_metadata = dict(metadata_before) + + if isinstance(revision.content_before, str): + doc.content_hash = generate_content_hash( + revision.content_before, doc.search_space_id + ) + + virtual_path = await _virtual_path_from_snapshot(session, revision) + if virtual_path: + doc.unique_identifier_hash = generate_unique_identifier_hash( + DocumentType.NOTE, + virtual_path, + doc.search_space_id, + ) + + chunks_before = revision.chunks_before + if isinstance(chunks_before, list): + await session.execute(delete(Chunk).where(Chunk.document_id == doc.id)) + chunk_texts = [ + str(c.get("content")) + for c in chunks_before + if isinstance(c, dict) and isinstance(c.get("content"), str) + ] + if chunk_texts: + chunk_embeddings = embed_texts(chunk_texts) + session.add_all( + [ + Chunk(document_id=doc.id, content=text, embedding=embedding) + for text, embedding in zip( + chunk_texts, chunk_embeddings, strict=True + ) + ] + ) + if isinstance(revision.content_before, str): + doc.embedding = embed_texts([revision.content_before])[0] + + doc.updated_at = datetime.now(UTC) + return RevertOutcome(status="ok", message="Document restored from snapshot.") + + +async def _reinsert_document_from_revision( + session: AsyncSession, + *, + revision: DocumentRevision, +) -> RevertOutcome: + """Re-INSERT a deleted :class:`Document` from a snapshot row (``rm`` revert).""" + if not isinstance(revision.title_before, str) or not revision.title_before: + return RevertOutcome( + status="not_reversible", + message="Snapshot lacks title_before; cannot recreate document.", + ) + if not isinstance(revision.content_before, str): + return RevertOutcome( + status="not_reversible", + message="Snapshot lacks content_before; cannot recreate document.", + ) + + virtual_path = await _virtual_path_from_snapshot(session, revision) + if not virtual_path: + return RevertOutcome( + status="not_reversible", + message=( + "Snapshot is missing both metadata_before['virtual_path'] AND " + "a resolvable (folder_id_before, title_before) pair." + ), + ) + + search_space_id = revision.search_space_id + unique_identifier_hash = generate_unique_identifier_hash( + DocumentType.NOTE, + virtual_path, + search_space_id, + ) + collision = await session.execute( + select(Document.id).where( + Document.search_space_id == search_space_id, + Document.unique_identifier_hash == unique_identifier_hash, + ) + ) + if collision.scalar_one_or_none() is not None: + return RevertOutcome( + status="tool_unavailable", + message=( + f"A document already exists at '{virtual_path}'; revert would " + "collide. Move the live doc out of the way first." + ), + ) + + metadata = revision.metadata_before or {} + if not isinstance(metadata, dict): + metadata = {} + metadata = dict(metadata) + metadata["virtual_path"] = virtual_path + + content = revision.content_before + new_doc = Document( + title=revision.title_before, + document_type=DocumentType.NOTE, + document_metadata=metadata, + content=content, + content_hash=generate_content_hash(content, search_space_id), + unique_identifier_hash=unique_identifier_hash, + source_markdown=content, + search_space_id=search_space_id, + folder_id=revision.folder_id_before, + updated_at=datetime.now(UTC), + ) + session.add(new_doc) + await session.flush() + + new_doc.embedding = embed_texts([content])[0] + chunk_texts = [] + chunks_before = revision.chunks_before + if isinstance(chunks_before, list): + chunk_texts = [ + str(c.get("content")) + for c in chunks_before + if isinstance(c, dict) and isinstance(c.get("content"), str) + ] + if chunk_texts: + chunk_embeddings = embed_texts(chunk_texts) + session.add_all( + [ + Chunk(document_id=new_doc.id, content=text, embedding=embedding) + for text, embedding in zip(chunk_texts, chunk_embeddings, strict=True) + ] + ) + + # Repoint the snapshot at the recreated row so a follow-up revert of + # the same row works as expected. + revision.document_id = new_doc.id + return RevertOutcome( + status="ok", + message=f"Re-inserted document '{revision.title_before}' from snapshot.", + ) + + +async def _delete_created_document( + session: AsyncSession, + *, + revision: DocumentRevision, +) -> RevertOutcome: + """Delete the document that ``write_file`` created (``content_before IS NULL``).""" + if revision.document_id is None: + return RevertOutcome( + status="ok", + message="No live row to delete (already removed elsewhere).", + ) + await session.execute(delete(Document).where(Document.id == revision.document_id)) + return RevertOutcome( + status="ok", + message="Deleted the document that was created by this action.", + ) + + async def _restore_document_revision( session: AsyncSession, *, action: AgentActionLog ) -> RevertOutcome: - """Restore the most recent :class:`DocumentRevision` for ``action``.""" + """Dispatch document-level revert based on ``action.tool_name``.""" stmt = ( select(DocumentRevision) .where(DocumentRevision.agent_action_id == action.id) @@ -132,23 +383,111 @@ async def _restore_document_revision( message="No document_revisions row tied to this action.", ) - from app.db import Document # late import to avoid cycles at module load + tool_name = (action.tool_name or "").lower() - doc = await session.get(Document, revision.document_id) - if doc is None: + if tool_name == "rm": + return await _reinsert_document_from_revision(session, revision=revision) + + if tool_name == "write_file" and revision.content_before is None: + return await _delete_created_document(session, revision=revision) + + return await _restore_in_place_document(session, revision=revision) + + +# --------------------------------------------------------------------------- +# Folder revision restore (mkdir/rmdir/rename/move) +# --------------------------------------------------------------------------- + + +async def _restore_in_place_folder( + session: AsyncSession, + *, + revision: FolderRevision, +) -> RevertOutcome: + if revision.folder_id is None: return RevertOutcome( status="tool_unavailable", - message="Original document has been deleted; revert cannot proceed.", + message="Original folder was hard-deleted; in-place restore is impossible.", + ) + folder = await session.get(Folder, revision.folder_id) + if folder is None: + return RevertOutcome( + status="tool_unavailable", + message="Original folder has been deleted; revert cannot proceed.", + ) + _set_field(folder, "name", revision.name_before) + _set_field(folder, "parent_id", revision.parent_id_before) + _set_field(folder, "position", revision.position_before) + folder.updated_at = datetime.now(UTC) + return RevertOutcome(status="ok", message="Folder restored from snapshot.") + + +async def _reinsert_folder_from_revision( + session: AsyncSession, + *, + revision: FolderRevision, +) -> RevertOutcome: + if not isinstance(revision.name_before, str) or not revision.name_before: + return RevertOutcome( + status="not_reversible", + message="Snapshot lacks name_before; cannot recreate folder.", + ) + new_folder = Folder( + name=revision.name_before, + parent_id=revision.parent_id_before, + position=revision.position_before, + search_space_id=revision.search_space_id, + updated_at=datetime.now(UTC), + ) + session.add(new_folder) + await session.flush() + revision.folder_id = new_folder.id + return RevertOutcome( + status="ok", + message=f"Re-inserted folder '{revision.name_before}' from snapshot.", + ) + + +async def _delete_created_folder( + session: AsyncSession, + *, + revision: FolderRevision, +) -> RevertOutcome: + if revision.folder_id is None: + return RevertOutcome( + status="ok", + message="No live folder row to delete (already removed elsewhere).", + ) + folder_id = revision.folder_id + + has_doc = await session.execute( + select(Document.id).where(Document.folder_id == folder_id).limit(1) + ) + if has_doc.scalar_one_or_none() is not None: + return RevertOutcome( + status="tool_unavailable", + message=( + "Folder is no longer empty (documents have been added since " + "mkdir); cannot revert." + ), + ) + has_child = await session.execute( + select(Folder.id).where(Folder.parent_id == folder_id).limit(1) + ) + if has_child.scalar_one_or_none() is not None: + return RevertOutcome( + status="tool_unavailable", + message=( + "Folder is no longer empty (sub-folders have been added " + "since mkdir); cannot revert." + ), ) - if revision.content_before is not None: - doc.content = revision.content_before - if revision.title_before is not None: - doc.title = revision.title_before - if revision.folder_id_before is not None: - doc.folder_id = revision.folder_id_before - doc.updated_at = datetime.now(UTC) - return RevertOutcome(status="ok", message="Document restored from snapshot.") + await session.execute(delete(Folder).where(Folder.id == folder_id)) + return RevertOutcome( + status="ok", + message="Deleted the folder that was created by this action.", + ) async def _restore_folder_revision( @@ -168,41 +507,44 @@ async def _restore_folder_revision( message="No folder_revisions row tied to this action.", ) - from app.db import Folder + tool_name = (action.tool_name or "").lower() - folder = await session.get(Folder, revision.folder_id) - if folder is None: - return RevertOutcome( - status="tool_unavailable", - message="Original folder has been deleted; revert cannot proceed.", - ) + if tool_name == "rmdir": + return await _reinsert_folder_from_revision(session, revision=revision) - if revision.name_before is not None: - folder.name = revision.name_before - if revision.parent_id_before is not None: - folder.parent_id = revision.parent_id_before - if revision.position_before is not None: - folder.position = revision.position_before - folder.updated_at = datetime.now(UTC) - return RevertOutcome(status="ok", message="Folder restored from snapshot.") + if tool_name == "mkdir": + return await _delete_created_folder(session, revision=revision) + + return await _restore_in_place_folder(session, revision=revision) -# Tool-name prefixes that route to KB document / folder revert paths. Kept -# as data so a future PR adding new KB-owned tools doesn't have to touch -# this module's control flow. -_DOC_TOOL_PREFIXES: tuple[str, ...] = ( - "edit_file", - "write_file", - "update_memory", - "create_note", - "update_note", - "delete_note", +# --------------------------------------------------------------------------- +# Dispatch +# --------------------------------------------------------------------------- +# +# Exact-name dispatch: ``tool_name == name``, NOT ``startswith(...)``. +# Prefix-matching mis-routes pairs like ``rm``/``rmdir`` and +# ``delete_note``/``delete_folder``. + +_DOC_TOOLS: frozenset[str] = frozenset( + { + "edit_file", + "write_file", + "move_file", + "rm", + "update_memory", + "create_note", + "update_note", + "delete_note", + } ) -_FOLDER_TOOL_PREFIXES: tuple[str, ...] = ( - "mkdir", - "move_file", - "rename_folder", - "delete_folder", +_FOLDER_TOOLS: frozenset[str] = frozenset( + { + "mkdir", + "rmdir", + "rename_folder", + "delete_folder", + } ) @@ -220,9 +562,9 @@ async def revert_action( """ tool_name = (action.tool_name or "").lower() - if tool_name.startswith(_DOC_TOOL_PREFIXES): + if tool_name in _DOC_TOOLS: outcome = await _restore_document_revision(session, action=action) - elif tool_name.startswith(_FOLDER_TOOL_PREFIXES): + elif tool_name in _FOLDER_TOOLS: outcome = await _restore_folder_revision(session, action=action) elif action.reverse_descriptor: # Connector-owned reversibles run through the normal permission diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index c254e66e2..2f8e33ba9 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -30,6 +30,7 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer +from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import ( AgentConfig, @@ -70,6 +71,91 @@ _background_tasks: set[asyncio.Task] = set() _perf_log = get_perf_logger() +def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: + """Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts. + + Returns a dict with three keys: + + * ``text`` — concatenated string content (empty string if the chunk + contributes none). + * ``reasoning`` — concatenated reasoning content (empty string if the + chunk contributes none). + * ``tool_call_chunks`` — flat list of LangChain ``tool_call_chunk`` + dicts surfaced from either the typed-block list or the + ``tool_call_chunks`` attribute. + + Background + ---------- + ``AIMessageChunk.content`` can be: + + * a ``str`` (most providers), or + * a ``list`` of typed blocks ``{type: 'text' | 'reasoning' | + 'tool_call_chunk' | 'tool_use' | ..., text/content/...}`` for + Anthropic, Bedrock, and several reasoning configurations. + + Reasoning may also live under + ``chunk.additional_kwargs['reasoning_content']`` (some providers + surface it that way instead of as a typed block). Tool-call chunks + may live under ``chunk.tool_call_chunks`` even when ``content`` is a + plain string. + + Earlier versions only handled the ``isinstance(content, str)`` branch + and silently dropped reasoning blocks + tool-call chunks emitted by + LangChain ``AIMessageChunk``s. + """ + out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []} + if chunk is None: + return out + + content = getattr(chunk, "content", None) + if isinstance(content, str): + if content: + out["text"] = content + elif isinstance(content, list): + text_parts: list[str] = [] + reasoning_parts: list[str] = [] + for block in content: + if not isinstance(block, dict): + continue + block_type = block.get("type") + if block_type == "text": + value = block.get("text") or block.get("content") or "" + if isinstance(value, str) and value: + text_parts.append(value) + elif block_type == "reasoning": + value = ( + block.get("reasoning") + or block.get("text") + or block.get("content") + or "" + ) + if isinstance(value, str) and value: + reasoning_parts.append(value) + elif block_type in ("tool_call_chunk", "tool_use"): + out["tool_call_chunks"].append(block) + if text_parts: + out["text"] = "".join(text_parts) + if reasoning_parts: + out["reasoning"] = "".join(reasoning_parts) + + additional = getattr(chunk, "additional_kwargs", None) or {} + if isinstance(additional, dict): + extra_reasoning = additional.get("reasoning_content") + if isinstance(extra_reasoning, str) and extra_reasoning: + existing = out["reasoning"] + out["reasoning"] = ( + (existing + extra_reasoning) if existing else extra_reasoning + ) + + extra_tool_chunks = getattr(chunk, "tool_call_chunks", None) + if isinstance(extra_tool_chunks, list): + for tcc in extra_tool_chunks: + if isinstance(tcc, dict): + out["tool_call_chunks"].append(tcc) + + return out + + def format_mentioned_surfsense_docs_as_context( documents: list[SurfsenseDocsDocument], ) -> str: @@ -266,6 +352,7 @@ async def _stream_agent_events( fallback_commit_search_space_id: int | None = None, fallback_commit_created_by_id: str | None = None, fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, + fallback_commit_thread_id: int | None = None, ) -> AsyncGenerator[str, None]: """Shared async generator that streams and formats astream_events from the agent. @@ -298,6 +385,41 @@ async def _stream_agent_events( active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool called_update_memory: bool = False + # Reasoning-block streaming. We open a reasoning block on the + # first reasoning delta of a step, append deltas as they arrive, and + # close it when text starts (the model has switched to writing its + # answer) or ``on_chat_model_end`` fires for the model node. Reuses + # the same Vercel format-helpers as text-start/delta/end. + current_reasoning_id: str | None = None + + # Streaming-parity v2 feature flag. When OFF we keep the legacy + # shape: str-only content, no reasoning blocks, no + # ``langchainToolCallId`` propagation. The schema migrations + # (135 / 136) ship unconditionally because they're forward-compatible. + parity_v2 = bool(get_flags().enable_stream_parity_v2) + + # Best-effort attach of LangChain ``tool_call_id`` to the synthetic + # ``call_`` card id we already emit. We accumulate + # ``tool_call_chunks`` from ``on_chat_model_stream``, key them by + # name, and pop the next unconsumed entry at ``on_tool_start``. The + # authoritative id is later filled in at ``on_tool_end`` from + # ``ToolMessage.tool_call_id``. + pending_tool_call_chunks: list[dict[str, Any]] = [] + lc_tool_call_id_by_run: dict[str, str] = {} + + # Per-tool-end mutable cache for the LangChain tool_call_id resolved + # at ``on_tool_end``. ``_emit_tool_output`` reads this so every + # ``format_tool_output_available`` call automatically carries the + # authoritative id without duplicating the kwarg at every call site. + current_lc_tool_call_id: dict[str, str | None] = {"value": None} + + def _emit_tool_output(call_id: str, output: Any) -> str: + return streaming_service.format_tool_output_available( + call_id, + output, + langchain_tool_call_id=current_lc_tool_call_id["value"], + ) + def next_thinking_step_id() -> str: nonlocal thinking_step_counter thinking_step_counter += 1 @@ -326,22 +448,61 @@ async def _stream_agent_events( if "surfsense:internal" in event.get("tags", []): continue # Suppress middleware-internal LLM tokens (e.g. KB search classification) chunk = event.get("data", {}).get("chunk") - if chunk and hasattr(chunk, "content"): - content = chunk.content - if content and isinstance(content, str): - if current_text_id is None: - completion_event = complete_current_step() - if completion_event: - yield completion_event - if just_finished_tool: - last_active_step_id = None - last_active_step_title = "" - last_active_step_items = [] - just_finished_tool = False - current_text_id = streaming_service.generate_text_id() - yield streaming_service.format_text_start(current_text_id) - yield streaming_service.format_text_delta(current_text_id, content) - accumulated_text += content + if not chunk: + continue + parts = _extract_chunk_parts(chunk) + + # Accumulate any tool_call_chunks for best-effort + # correlation with ``on_tool_start`` below. We don't emit + # anything here; the matching is done at tool-start time. + if parity_v2 and parts["tool_call_chunks"]: + for tcc in parts["tool_call_chunks"]: + pending_tool_call_chunks.append(tcc) + + reasoning_delta = parts["reasoning"] + text_delta = parts["text"] + + # Reasoning streaming. Open a reasoning block on first + # delta; append every subsequent delta until text begins. + # When text starts we close the reasoning block first so the + # frontend sees the natural hand-off. Gated behind the + # parity-v2 flag so legacy deployments keep today's shape. + if parity_v2 and reasoning_delta: + if current_text_id is not None: + yield streaming_service.format_text_end(current_text_id) + current_text_id = None + if current_reasoning_id is None: + completion_event = complete_current_step() + if completion_event: + yield completion_event + if just_finished_tool: + last_active_step_id = None + last_active_step_title = "" + last_active_step_items = [] + just_finished_tool = False + current_reasoning_id = streaming_service.generate_reasoning_id() + yield streaming_service.format_reasoning_start(current_reasoning_id) + yield streaming_service.format_reasoning_delta( + current_reasoning_id, reasoning_delta + ) + + if text_delta: + if current_reasoning_id is not None: + yield streaming_service.format_reasoning_end(current_reasoning_id) + current_reasoning_id = None + if current_text_id is None: + completion_event = complete_current_step() + if completion_event: + yield completion_event + if just_finished_tool: + last_active_step_id = None + last_active_step_title = "" + last_active_step_items = [] + just_finished_tool = False + current_text_id = streaming_service.generate_text_id() + yield streaming_service.format_text_start(current_text_id) + yield streaming_service.format_text_delta(current_text_id, text_delta) + accumulated_text += text_delta elif event_type == "on_tool_start": active_tool_depth += 1 @@ -581,7 +742,39 @@ async def _stream_agent_events( if run_id else streaming_service.generate_tool_call_id() ) - yield streaming_service.format_tool_input_start(tool_call_id, tool_name) + + # Best-effort attach the LangChain ``tool_call_id``. We + # pop the first chunk in ``pending_tool_call_chunks`` whose + # name matches; if none match (the chunked args may not yet + # carry a ``name`` field, or the model skipped the chunked + # form) we leave ``langchainToolCallId`` unset for now and + # fill it in authoritatively at ``on_tool_end`` from + # ``ToolMessage.tool_call_id``. + langchain_tool_call_id: str | None = None + if parity_v2 and pending_tool_call_chunks: + matched_idx: int | None = None + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("name") == tool_name and tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("id"): + matched_idx = idx + break + if matched_idx is not None: + matched = pending_tool_call_chunks.pop(matched_idx) + candidate = matched.get("id") + if isinstance(candidate, str) and candidate: + langchain_tool_call_id = candidate + if run_id: + lc_tool_call_id_by_run[run_id] = candidate + + yield streaming_service.format_tool_input_start( + tool_call_id, + tool_name, + langchain_tool_call_id=langchain_tool_call_id, + ) # Sanitize tool_input: strip runtime-injected non-serializable # values (e.g. LangChain ToolRuntime) before sending over SSE. if isinstance(tool_input, dict): @@ -598,6 +791,7 @@ async def _stream_agent_events( tool_call_id, tool_name, _safe_input, + langchain_tool_call_id=langchain_tool_call_id, ) elif event_type == "on_tool_end": @@ -639,6 +833,23 @@ async def _stream_agent_events( ) completed_step_ids.add(original_step_id) + # Authoritative LangChain tool_call_id from the returned + # ``ToolMessage``. Falls back to whatever we matched + # at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``) + # if the output isn't a ToolMessage. The value is stored in + # ``current_lc_tool_call_id`` so ``_emit_tool_output`` + # picks it up for every output emit below. Stays None when + # parity_v2 is off so legacy emit paths are untouched. + current_lc_tool_call_id["value"] = None + if parity_v2: + authoritative = getattr(raw_output, "tool_call_id", None) + if isinstance(authoritative, str) and authoritative: + current_lc_tool_call_id["value"] = authoritative + if run_id: + lc_tool_call_id_by_run[run_id] = authoritative + elif run_id and run_id in lc_tool_call_id_by_run: + current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id] + if tool_name == "read_file": yield streaming_service.format_thinking_step( step_id=original_step_id, @@ -938,7 +1149,7 @@ async def _stream_agent_events( last_active_step_items = [] if tool_name == "generate_podcast": - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -963,7 +1174,7 @@ async def _stream_agent_events( "error", ) elif tool_name == "generate_video_presentation": - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -991,7 +1202,7 @@ async def _stream_agent_events( "error", ) elif tool_name == "generate_image": - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -1018,12 +1229,12 @@ async def _stream_agent_events( display_output["content_preview"] = ( content[:500] + "..." if len(content) > 500 else content ) - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, display_output, ) else: - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, {"result": tool_output}, ) @@ -1051,7 +1262,7 @@ async def _stream_agent_events( ) result_text = _tool_output_to_text(tool_output) if _tool_output_has_error(tool_output): - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, { "status": "error", @@ -1060,7 +1271,7 @@ async def _stream_agent_events( }, ) else: - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, { "status": "completed", @@ -1070,7 +1281,7 @@ async def _stream_agent_events( ) elif tool_name == "generate_report": # Stream the full report result so frontend can render the ReportCard - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -1097,7 +1308,7 @@ async def _stream_agent_events( "error", ) elif tool_name == "generate_resume": - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -1148,7 +1359,7 @@ async def _stream_agent_events( "update_confluence_page", "delete_confluence_page", ): - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -1176,7 +1387,7 @@ async def _stream_agent_events( if fpath and fpath not in result.sandbox_files: result.sandbox_files.append(fpath) - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, { "exit_code": exit_code, @@ -1211,12 +1422,12 @@ async def _stream_agent_events( citations[chunk_url]["snippet"] = ( content[:200] + "…" if len(content) > 200 else content ) - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, {"status": "completed", "citations": citations}, ) else: - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, {"status": "completed", "result_length": len(str(tool_output))}, ) @@ -1274,6 +1485,25 @@ async def _stream_agent_events( }, ) + elif event_type == "on_custom_event" and event.get("name") == "action_log": + # Surface a freshly committed AgentActionLog row so the chat + # tool card can render its Revert button immediately. + data = event.get("data", {}) + if data.get("id") is not None: + yield streaming_service.format_data("action-log", data) + + elif ( + event_type == "on_custom_event" + and event.get("name") == "action_log_updated" + ): + # Reversibility flipped in kb_persistence after the SAVEPOINT + # for a destructive op (rm/rmdir/move/edit/write) committed. + # Frontend uses this to flip the card's Revert + # button on without re-fetching the actions list. + data = event.get("data", {}) + if data.get("id") is not None: + yield streaming_service.format_data("action-log-updated", data) + elif event_type in ("on_chain_end", "on_agent_end"): if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) @@ -1291,11 +1521,12 @@ async def _stream_agent_events( # Safety net: if astream_events was cancelled before # KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work - # (dirty_paths / staged_dirs / pending_moves) will still be in the - # checkpointed state. Run the SAME shared commit helper here so the - # turn's writes don't get lost on client disconnect, then push the - # delta back into the graph using `as_node=...` so reducers fire as if - # the after_agent hook produced it. + # (dirty_paths / staged_dirs / pending_moves / pending_deletes / + # pending_dir_deletes) will still be in the checkpointed state. Run + # the SAME shared commit helper here so the turn's writes don't get + # lost on client disconnect, then push the delta back into the graph + # using `as_node=...` so reducers fire as if the after_agent hook + # produced it. if ( fallback_commit_filesystem_mode == FilesystemMode.CLOUD and fallback_commit_search_space_id is not None @@ -1303,6 +1534,8 @@ async def _stream_agent_events( (state_values.get("dirty_paths") or []) or (state_values.get("staged_dirs") or []) or (state_values.get("pending_moves") or []) + or (state_values.get("pending_deletes") or []) + or (state_values.get("pending_dir_deletes") or []) ) ): try: @@ -1311,6 +1544,7 @@ async def _stream_agent_events( search_space_id=fallback_commit_search_space_id, created_by_id=fallback_commit_created_by_id, filesystem_mode=fallback_commit_filesystem_mode, + thread_id=fallback_commit_thread_id, dispatch_events=False, ) if delta: @@ -1726,6 +1960,17 @@ async def stream_new_chat( yield streaming_service.format_message_start() yield streaming_service.format_start_step() + # Surface the per-turn correlation id at the very start of the + # stream so the frontend can stamp it onto the in-flight + # assistant message and replay it via ``appendMessage`` + # for durable storage. Tool/action-log events DO carry it later, + # but pure-text turns never produce action-log events; this + # event guarantees the frontend learns the turn id regardless. + yield streaming_service.format_data( + "turn-info", + {"chat_turn_id": stream_result.turn_id}, + ) + # Initial thinking step - analyzing the request if mentioned_surfsense_docs: initial_title = "Analyzing referenced content" @@ -1876,6 +2121,7 @@ async def stream_new_chat( if filesystem_selection else FilesystemMode.CLOUD ), + fallback_commit_thread_id=chat_id, ): if not _first_event_logged: _perf_log.info( @@ -2308,6 +2554,13 @@ async def stream_resume_chat( yield streaming_service.format_message_start() yield streaming_service.format_start_step() + # Same rationale as ``stream_new_chat``: emit the turn id so + # resumed streams can be persisted with their correlation id + # intact. + yield streaming_service.format_data( + "turn-info", + {"chat_turn_id": stream_result.turn_id}, + ) _t_stream_start = time.perf_counter() _first_event_logged = False @@ -2325,6 +2578,7 @@ async def stream_resume_chat( if filesystem_selection else FilesystemMode.CLOUD ), + fallback_commit_thread_id=chat_id, ): if not _first_event_logged: _perf_log.info( diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py b/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py index aad1524c9..8ef1430a9 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py @@ -15,6 +15,17 @@ from app.agents.new_chat.middleware.action_log import ActionLogMiddleware from app.agents.new_chat.tools.registry import ToolDefinition +@dataclass +class _FakeRuntime: + """Minimal stand-in for ``ToolRuntime`` used in unit tests. + + ``ActionLogMiddleware`` reads ``runtime.config['configurable']['turn_id']`` + to populate the new ``chat_turn_id`` column (see migration 135). + """ + + config: dict[str, Any] | None = None + + @dataclass class _FakeRequest: """Minimal stand-in for ToolCallRequest used in unit tests.""" @@ -120,6 +131,9 @@ class TestActionLogMiddlewarePersistence: "args": {"color": "red", "size": 3}, "id": "tc-abc", }, + runtime=_FakeRuntime( + config={"configurable": {"turn_id": "42:1700000000000"}} + ), ) result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1") handler = AsyncMock(return_value=result_msg) @@ -142,6 +156,32 @@ class TestActionLogMiddlewarePersistence: assert row.error is None assert row.reverse_descriptor is None assert row.reversible is False + # Migration 135: ``turn_id`` is the deprecated alias of ``tool_call_id``; + # ``chat_turn_id`` comes from ``runtime.config['configurable']['turn_id']``. + assert row.tool_call_id == "tc-abc" + assert row.turn_id == "tc-abc" + assert row.chat_turn_id == "42:1700000000000" + + @pytest.mark.asyncio + async def test_chat_turn_id_none_when_runtime_missing( + self, patch_get_flags, fake_session_factory + ) -> None: + """``chat_turn_id`` falls back to NULL when ``runtime.config`` is absent.""" + captured, factory = fake_session_factory + mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {}, "id": "tc-1"}, + runtime=None, + ) + handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc-1")) + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), + ): + await mw.awrap_tool_call(request, handler) + row = captured["rows"][0] + assert row.tool_call_id == "tc-1" + assert row.chat_turn_id is None @pytest.mark.asyncio async def test_writes_row_on_failure_and_reraises( @@ -293,6 +333,76 @@ class TestReverseDescriptor: assert row.reversible is False +class TestActionLogDispatch: + """Verify ``adispatch_custom_event`` fires after commit.""" + + @pytest.mark.asyncio + async def test_dispatches_action_log_event_on_success( + self, patch_get_flags, fake_session_factory + ) -> None: + _captured, factory = fake_session_factory + mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1") + request = _FakeRequest( + tool_call={ + "name": "make_widget", + "args": {"color": "red"}, + "id": "tc-evt", + }, + runtime=_FakeRuntime( + config={"configurable": {"turn_id": "42:1700000000000"}} + ), + ) + result_msg = ToolMessage(content="ok", tool_call_id="tc-evt", id="msg-42") + handler = AsyncMock(return_value=result_msg) + + dispatch_mock = AsyncMock() + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), + patch( + "app.agents.new_chat.middleware.action_log.adispatch_custom_event", + dispatch_mock, + ), + ): + await mw.awrap_tool_call(request, handler) + + dispatch_mock.assert_awaited_once() + call_args = dispatch_mock.await_args + assert call_args is not None + assert call_args.args[0] == "action_log" + payload = call_args.args[1] + assert payload["lc_tool_call_id"] == "tc-evt" + assert payload["chat_turn_id"] == "42:1700000000000" + assert payload["tool_name"] == "make_widget" + assert payload["reversible"] is False + assert payload["reverse_descriptor_present"] is False + assert payload["error"] is False + + @pytest.mark.asyncio + async def test_no_dispatch_when_persistence_fails(self, patch_get_flags) -> None: + """If commit fails the dispatch is suppressed (no row to surface).""" + mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {}, "id": "tc1"} + ) + handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1")) + dispatch_mock = AsyncMock() + + def _exploding_session(): + raise RuntimeError("DB is down") + + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=_exploding_session), + patch( + "app.agents.new_chat.middleware.action_log.adispatch_custom_event", + dispatch_mock, + ), + ): + await mw.awrap_tool_call(request, handler) + dispatch_mock.assert_not_awaited() + + class TestArgsTruncation: @pytest.mark.asyncio async def test_huge_args_payload_is_truncated( diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py b/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py new file mode 100644 index 000000000..653175eab --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py @@ -0,0 +1,122 @@ +"""Tests for the desktop-mode safety ruleset. + +In desktop mode the agent operates against the user's real disk with no +revision history, so destructive filesystem operations must require +explicit approval. These tests pin the set of tools that get the ``ask`` +gate so it cannot silently regress. +""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.middleware.permission import PermissionMiddleware +from app.agents.new_chat.permissions import ( + Rule, + Ruleset, + aggregate_action, + evaluate_many, +) + +pytestmark = pytest.mark.unit + + +# Mirror the ruleset built inside ``chat_deepagent._build_compiled_agent_blocking`` +# when ``filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER``. Keeping a +# copy here means the rule contract has a focused regression test even when +# the larger graph-build helper is hard to instantiate in unit tests. +DESKTOP_SAFETY_RULESET = Ruleset( + rules=[ + Rule(permission="rm", pattern="*", action="ask"), + Rule(permission="rmdir", pattern="*", action="ask"), + Rule(permission="move_file", pattern="*", action="ask"), + Rule(permission="edit_file", pattern="*", action="ask"), + Rule(permission="write_file", pattern="*", action="ask"), + ], + origin="desktop_safety", +) + +SURFSENSE_DEFAULTS = Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", +) + + +def _action_for(tool_name: str, *rulesets: Ruleset) -> str: + rules = evaluate_many(tool_name, [tool_name], *rulesets) + return aggregate_action(rules) + + +class TestDesktopSafetyRulesGateDestructiveOps: + @pytest.mark.parametrize( + "tool_name", + ["rm", "rmdir", "move_file", "edit_file", "write_file"], + ) + def test_destructive_op_resolves_to_ask(self, tool_name: str) -> None: + # surfsense_defaults says "allow */*"; desktop_safety must override + # because it's layered later (last-match-wins). + action = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET) + assert action == "ask", ( + f"{tool_name} must require approval in desktop mode " + f"(no revert path on real disk); got {action!r}" + ) + + @pytest.mark.parametrize( + "tool_name", + ["read_file", "ls", "list_tree", "grep", "glob", "cd", "pwd", "mkdir"], + ) + def test_safe_ops_remain_allowed(self, tool_name: str) -> None: + # Read-only and trivially-reversible tools must NOT get gated — + # otherwise every navigation in desktop mode pops an interrupt. + action = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET) + assert action == "allow", ( + f"{tool_name} should not be gated in desktop mode; got {action!r}" + ) + + +class TestDesktopSafetyOverridesAllowDefault: + def test_layer_order_last_match_wins(self) -> None: + # If desktop_safety is layered BEFORE surfsense_defaults, the allow + # default would win and the safety net would be inert. This test + # protects against accidentally swapping the rulesets in + # ``_build_compiled_agent_blocking``. + action = _action_for("rm", DESKTOP_SAFETY_RULESET, SURFSENSE_DEFAULTS) + # Layered "wrong way" — the broad allow now wins. + assert action == "allow" + + # Correct order: defaults < desktop_safety -> ask wins. + action = _action_for("rm", SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET) + assert action == "ask" + + +class TestPermissionMiddlewareIntegration: + def test_middleware_raises_interrupt_for_rm_in_desktop_mode(self) -> None: + from langchain_core.messages import AIMessage + + from app.agents.new_chat.errors import RejectedError + + mw = PermissionMiddleware(rulesets=[SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET]) + # Stub the interrupt to a "reject" decision so we can assert the + # ask path was taken without spinning up the LangGraph runtime. + mw._raise_interrupt = lambda **kw: {"decision_type": "reject"} # type: ignore[assignment] + + state = { + "messages": [ + AIMessage( + content="", + tool_calls=[ + { + "name": "rm", + "args": {"path": "/Users/me/Documents/important.docx"}, + "id": "tc-rm", + } + ], + ) + ] + } + + class _FakeRuntime: + config: dict = {"configurable": {"thread_id": "test"}} + + with pytest.raises(RejectedError): + mw.after_model(state, _FakeRuntime()) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py b/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py new file mode 100644 index 000000000..0bbdf37bf --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py @@ -0,0 +1,111 @@ +"""Tests for the default auto-approval list in ``hitl.request_approval``. + +These pin the policy that low-stakes connector creation tools (drafts, +new-file creates) skip the HITL interrupt by default. Without this set, +every "draft my newsletter" turn used to fire ~3 interrupts before any +useful work happened. +""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.tools.hitl import ( + DEFAULT_AUTO_APPROVED_TOOLS, + HITLResult, + request_approval, +) + +pytestmark = pytest.mark.unit + + +class TestDefaultAutoApprovedToolsList: + def test_set_contains_expected_creation_tools(self) -> None: + # If anyone changes the policy list, we want a single test to + # update so the contract is explicit. Keep this in sync with + # ``hitl.DEFAULT_AUTO_APPROVED_TOOLS``. + expected = { + "create_gmail_draft", + "update_gmail_draft", + "create_notion_page", + "create_confluence_page", + "create_google_drive_file", + "create_dropbox_file", + "create_onedrive_file", + } + assert expected == DEFAULT_AUTO_APPROVED_TOOLS + + def test_set_is_immutable(self) -> None: + # frozenset prevents accidental at-runtime mutation that would + # silently widen the auto-approval surface. + assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset) + + def test_send_tools_are_not_auto_approved(self) -> None: + # External-broadcast tools must always prompt. + for tool_name in ( + "send_gmail_email", + "send_discord_message", + "send_teams_message", + "delete_notion_page", + "create_calendar_event", + "delete_calendar_event", + ): + assert tool_name not in DEFAULT_AUTO_APPROVED_TOOLS, ( + f"{tool_name} must remain HITL-gated" + ) + + +class TestRequestApprovalAutoBypass: + def test_auto_approved_tool_skips_interrupt(self) -> None: + # No interrupt mock set up — if the function attempted to call + # ``langgraph.types.interrupt`` it would raise GraphInterrupt. + # The fact that we get a clean HITLResult proves the bypass. + result = request_approval( + action_type="gmail_draft_creation", + tool_name="create_gmail_draft", + params={"to": "alice@example.com", "subject": "hi", "body": "hey"}, + ) + assert isinstance(result, HITLResult) + assert result.rejected is False + assert result.decision_type == "auto_approved" + # Original params are preserved untouched (no user edits possible). + assert result.params == { + "to": "alice@example.com", + "subject": "hi", + "body": "hey", + } + + def test_non_listed_tool_still_attempts_interrupt(self) -> None: + # A tool NOT in the default list must reach ``langgraph.interrupt``. + # Outside a runnable context that call raises a RuntimeError — + # which is exactly the signal we want: the bypass did NOT fire. + with pytest.raises(RuntimeError, match="runnable context"): + request_approval( + action_type="gmail_email_send", + tool_name="send_gmail_email", + params={"to": "alice@example.com", "subject": "hi", "body": "hey"}, + ) + + def test_user_trusted_tools_still_take_precedence(self) -> None: + # ``trusted_tools`` (per-connector "always allow" from MCP/UI) + # was checked BEFORE the default list and must keep working + # for tools outside the default list. + result = request_approval( + action_type="mcp_tool_call", + tool_name="my_custom_mcp_tool", + params={"x": 1}, + trusted_tools=["my_custom_mcp_tool"], + ) + assert result.decision_type == "trusted" + assert result.rejected is False + + def test_auto_approved_overrides_no_trusted_tools(self) -> None: + # When trusted_tools is empty and tool is in the default list, + # we should still bypass — proves the order in request_approval. + result = request_approval( + action_type="notion_page_creation", + tool_name="create_notion_page", + params={"title": "Plan"}, + trusted_tools=[], + ) + assert result.decision_type == "auto_approved" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_rm_rmdir_cloud.py b/surfsense_backend/tests/unit/agents/new_chat/test_rm_rmdir_cloud.py new file mode 100644 index 000000000..7cabb6524 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_rm_rmdir_cloud.py @@ -0,0 +1,333 @@ +"""Cloud-mode behavior tests for the new ``rm`` and ``rmdir`` filesystem tools. + +The tools build ``Command(update=...)`` payloads that the persistence +middleware applies at end of turn. These tests stub out the backend and +runtime to assert the staging payload shape: + +* ``rm`` queues into ``pending_deletes`` and tombstones state files. +* ``rm`` rejects directories, ``/documents``, root, and the anonymous doc. +* ``rmdir`` queues into ``pending_dir_deletes`` and rejects non-empty dirs. +* ``rmdir`` un-stages a same-turn ``mkdir`` rather than queuing a delete. +* ``rmdir`` refuses to drop the cwd or any of its ancestors. +* ``KBPostgresBackend`` view-helpers honor staged deletes. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware +from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend + +pytestmark = pytest.mark.unit + + +def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD): + middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) + middleware._filesystem_mode = mode + middleware._custom_tool_descriptions = {} + return middleware + + +def _runtime(state: dict[str, Any] | None = None, *, tool_call_id: str = "tc-abc"): + state = state or {} + state.setdefault("cwd", "/documents") + return SimpleNamespace(state=state, tool_call_id=tool_call_id) + + +class _KBBackendStub(KBPostgresBackend): + """Construct-able subclass of :class:`KBPostgresBackend` for tests. + + We bypass the real ``__init__`` (which expects a runtime + DB session) + and inject just the methods the rm/rmdir tools touch. The class + inheritance keeps ``isinstance(backend, KBPostgresBackend)`` checks + inside the tools happy, which is what gates them from the desktop + code path. + """ + + def __init__(self, *, children=None, file_data=None) -> None: + self.als_info = AsyncMock(return_value=children or []) + self._load_file_data = AsyncMock( + return_value=(file_data, 17) if file_data is not None else None + ) + + +def _make_backend_stub(*, children=None, file_data=None) -> KBPostgresBackend: + return _KBBackendStub(children=children, file_data=file_data) + + +def _bind_backend(middleware, backend): + """Inject a backend resolver onto the middleware test instance.""" + middleware._get_backend = lambda runtime: backend + return backend + + +# --------------------------------------------------------------------------- +# rm +# --------------------------------------------------------------------------- + + +class TestRmStaging: + @pytest.mark.asyncio + async def test_stages_delete_and_tombstones_state(self): + m = _make_middleware() + _bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]})) + runtime = _runtime( + { + "cwd": "/documents", + "files": {"/documents/notes.md": {"content": ["hello"]}}, + "doc_id_by_path": {"/documents/notes.md": 17}, + }, + tool_call_id="tc-1", + ) + + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/notes.md", runtime=runtime) + + assert hasattr(result, "update"), f"expected Command, got {result!r}" + update = result.update + assert update["pending_deletes"] == [ + {"path": "/documents/notes.md", "tool_call_id": "tc-1"} + ] + assert update["files"] == {"/documents/notes.md": None} + assert update["doc_id_by_path"] == {"/documents/notes.md": None} + + @pytest.mark.asyncio + async def test_rejects_documents_root(self): + m = _make_middleware() + runtime = _runtime() + tool = m._create_rm_tool() + result = await tool.coroutine("/documents", runtime=runtime) + assert isinstance(result, str) + assert "refusing to rm" in result + + @pytest.mark.asyncio + async def test_rejects_root(self): + m = _make_middleware() + runtime = _runtime() + tool = m._create_rm_tool() + result = await tool.coroutine("/", runtime=runtime) + assert isinstance(result, str) + assert "refusing to rm" in result + + @pytest.mark.asyncio + async def test_rejects_directory_via_staged_dirs(self): + m = _make_middleware() + runtime = _runtime( + { + "staged_dirs": ["/documents/team-x"], + } + ) + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/team-x", runtime=runtime) + assert isinstance(result, str) + assert "directory" in result.lower() + assert "rmdir" in result + + @pytest.mark.asyncio + async def test_rejects_directory_via_listing(self): + m = _make_middleware() + _bind_backend( + m, + _make_backend_stub( + children=[{"path": "/documents/foo/x.md", "is_dir": False}] + ), + ) + runtime = _runtime() + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/foo", runtime=runtime) + assert isinstance(result, str) + assert "directory" in result.lower() + + @pytest.mark.asyncio + async def test_rejects_anonymous_doc(self): + m = _make_middleware() + runtime = _runtime( + { + "kb_anon_doc": { + "path": "/documents/uploaded.xml", + "title": "uploaded", + "content": "", + "chunks": [], + } + } + ) + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/uploaded.xml", runtime=runtime) + assert isinstance(result, str) + assert "read-only" in result + + @pytest.mark.asyncio + async def test_drops_path_from_dirty_paths(self): + m = _make_middleware() + _bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]})) + runtime = _runtime( + { + "files": {"/documents/notes.md": {"content": ["x"]}}, + "doc_id_by_path": {"/documents/notes.md": 17}, + "dirty_paths": ["/documents/notes.md"], + } + ) + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/notes.md", runtime=runtime) + update = result.update + # First element is _CLEAR sentinel; the rest must NOT contain the + # rm'd path. + dirty = update.get("dirty_paths") or [] + assert "/documents/notes.md" not in dirty[1:] + + +# --------------------------------------------------------------------------- +# rmdir +# --------------------------------------------------------------------------- + + +class TestRmdirStaging: + @pytest.mark.asyncio + async def test_stages_dir_delete_when_empty_and_db_backed(self): + m = _make_middleware() + backend = _bind_backend(m, _make_backend_stub(children=[])) + # Override _load_file_data to return None (folder, not a file) and + # parent listing to claim the folder exists. + backend._load_file_data = AsyncMock(return_value=None) + backend.als_info = AsyncMock( + side_effect=[ + [], # children of /documents/proj + [ + {"path": "/documents/proj", "is_dir": True}, + ], # parent listing + ] + ) + runtime = _runtime( + { + "cwd": "/documents", + }, + tool_call_id="tc-rd", + ) + + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/proj", runtime=runtime) + + assert hasattr(result, "update") + update = result.update + assert update["pending_dir_deletes"] == [ + {"path": "/documents/proj", "tool_call_id": "tc-rd"} + ] + + @pytest.mark.asyncio + async def test_rejects_non_empty(self): + m = _make_middleware() + _bind_backend( + m, + _make_backend_stub( + children=[{"path": "/documents/proj/x.md", "is_dir": False}] + ), + ) + runtime = _runtime() + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/proj", runtime=runtime) + assert isinstance(result, str) + assert "not empty" in result + + @pytest.mark.asyncio + async def test_unstages_same_turn_mkdir(self): + m = _make_middleware() + _bind_backend(m, _make_backend_stub(children=[])) + runtime = _runtime( + { + "cwd": "/documents", + "staged_dirs": ["/documents/scratch"], + }, + tool_call_id="tc-rd", + ) + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/scratch", runtime=runtime) + + assert hasattr(result, "update") + update = result.update + assert "pending_dir_deletes" not in update + # _CLEAR sentinel + remaining items (in this case, none). + staged_after = update["staged_dirs"] + assert staged_after[0] == "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00" + assert "/documents/scratch" not in staged_after[1:] + + @pytest.mark.asyncio + async def test_rejects_root(self): + m = _make_middleware() + runtime = _runtime() + tool = m._create_rmdir_tool() + for victim in ("/", "/documents"): + result = await tool.coroutine(victim, runtime=runtime) + assert isinstance(result, str) + assert "refusing to rmdir" in result + + @pytest.mark.asyncio + async def test_rejects_cwd(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents/proj"}) + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/proj", runtime=runtime) + assert isinstance(result, str) + assert "cwd" in result.lower() + + @pytest.mark.asyncio + async def test_rejects_ancestor_of_cwd(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents/proj/sub"}) + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/proj", runtime=runtime) + assert isinstance(result, str) + assert "cwd" in result.lower() + + @pytest.mark.asyncio + async def test_rejects_files(self): + m = _make_middleware() + _bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]})) + runtime = _runtime() + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/notes.md", runtime=runtime) + assert isinstance(result, str) + assert "is a file" in result + + +# --------------------------------------------------------------------------- +# KBPostgresBackend view filter +# --------------------------------------------------------------------------- + + +class TestKBPostgresBackendDeleteFilter: + """als_info / glob / grep should suppress paths queued for delete.""" + + def _make_backend(self, state: dict[str, Any]) -> KBPostgresBackend: + runtime = SimpleNamespace(state=state) + backend = KBPostgresBackend(search_space_id=1, runtime=runtime) + return backend + + def test_pending_filesystem_view_returns_deleted_paths(self): + backend = self._make_backend( + { + "pending_deletes": [ + {"path": "/documents/x.md", "tool_call_id": "t1"}, + ], + "pending_dir_deletes": [ + {"path": "/documents/d1", "tool_call_id": "t2"}, + ], + } + ) + removed, alias, deleted_dirs = backend._pending_filesystem_view({}) + assert "/documents/x.md" in removed + assert "/documents/d1" in deleted_dirs + assert alias == {} + + def test_dir_suppressed_covers_descendants(self): + backend = self._make_backend({}) + deleted_dirs = {"/documents/d"} + assert backend._is_dir_suppressed("/documents/d", deleted_dirs) + assert backend._is_dir_suppressed("/documents/d/x.md", deleted_dirs) + assert backend._is_dir_suppressed("/documents/d/sub/y.md", deleted_dirs) + assert not backend._is_dir_suppressed("/documents/other.md", deleted_dirs) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py b/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py index 3caeb9a34..185753990 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py @@ -98,10 +98,54 @@ class TestInitialFilesystemState: state = _initial_filesystem_state() assert state["cwd"] == "/documents" assert state["staged_dirs"] == [] + assert state["staged_dir_tool_calls"] == {} assert state["pending_moves"] == [] + assert state["pending_deletes"] == [] + assert state["pending_dir_deletes"] == [] assert state["doc_id_by_path"] == {} assert state["dirty_paths"] == [] + assert state["dirty_path_tool_calls"] == {} assert state["kb_priority"] == [] assert state["kb_matched_chunk_ids"] == {} assert state["kb_anon_doc"] is None assert state["tree_version"] == 0 + + +class TestMultiEditSamePathCoalescing: + """Multi-edit-same-path turns must coalesce into ONE binding record. + + The persistence body uses ``dirty_path_tool_calls[path]`` to find the + tool_call_id that produced the current state on disk. Because + ``dirty_paths`` dedupes via :func:`_add_unique_reducer` the second + edit doesn't append a new path entry — and because + ``_dict_merge_with_tombstones_reducer`` lets the right-hand side + overwrite, the LATEST tool_call_id wins. That's the correct behavior + for snapshotting: revert restores to the pre-mutation state, and + multiple back-to-back edits in one turn coalesce into a single + revisible op (the user sees ONE Revert button per turn-per-path, + not N). + """ + + def test_dirty_paths_dedupes_repeated_writes(self): + # ``_add_unique_reducer`` is applied to ``dirty_paths``. Two writes + # to the same path produce one entry, not two. + first = _add_unique_reducer([], ["/documents/a.md"]) + second = _add_unique_reducer(first, ["/documents/a.md"]) + assert second == ["/documents/a.md"] + + def test_dirty_path_tool_calls_keeps_latest_tool_call_id(self): + # First write tags the path with tcid-1. + merged = _dict_merge_with_tombstones_reducer({}, {"/documents/a.md": "tcid-1"}) + # Second write to the same path tags it with tcid-2 (latest wins). + merged = _dict_merge_with_tombstones_reducer( + merged, {"/documents/a.md": "tcid-2"} + ) + assert merged == {"/documents/a.md": "tcid-2"} + + def test_rm_tombstones_dirty_path_tool_call(self): + # ``rm`` writes ``{path: None}`` into dirty_path_tool_calls to + # prevent a stale binding from leaking past the delete. + merged = _dict_merge_with_tombstones_reducer( + {"/documents/a.md": "tcid-1"}, {"/documents/a.md": None} + ) + assert merged == {} diff --git a/surfsense_backend/tests/unit/db/__init__.py b/surfsense_backend/tests/unit/db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/db/test_relax_revision_fks_migration.py b/surfsense_backend/tests/unit/db/test_relax_revision_fks_migration.py new file mode 100644 index 000000000..82c299488 --- /dev/null +++ b/surfsense_backend/tests/unit/db/test_relax_revision_fks_migration.py @@ -0,0 +1,83 @@ +"""Smoke test for the ``134_relax_revision_fks`` Alembic migration. + +A full apply/rollback test would require a live Postgres; here we verify +the migration module's static contract: + +* The chain wires it as a successor of ``133_drop_documents_content_hash_unique``. +* ``upgrade()`` declares two FK creations with ``ondelete='SET NULL'`` + (one for ``document_revisions.document_id``, one for + ``folder_revisions.folder_id``). +* ``downgrade()`` re-establishes ``ondelete='CASCADE'`` after draining + orphaned revisions. + +If any of these invariants regress the snapshot/revert pipeline silently +loses the ability to undo ``rm`` / ``rmdir`` on environments that ran the +migration "down" or never ran it at all. +""" + +from __future__ import annotations + +import importlib.util +import inspect +from pathlib import Path + +import pytest + +pytestmark = pytest.mark.unit + + +_MIGRATION_PATH = ( + Path(__file__).resolve().parents[3] + / "alembic" + / "versions" + / "134_relax_revision_fks.py" +) + + +def _load_migration(): + """Load the migration module by file path (no package import needed).""" + spec = importlib.util.spec_from_file_location("_migration_134", _MIGRATION_PATH) + assert spec and spec.loader, "could not load migration spec" + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_migration_chain_revision_ids() -> None: + module = _load_migration() + # The migration file uses short numeric revision IDs to match the + # in-tree convention (cf. ``133`` -> ``134``); the ``134_.py`` + # filename is documentation, not the canonical revision string. + assert getattr(module, "revision", None) == "134" + assert getattr(module, "down_revision", None) == "133" + + +def test_migration_exposes_upgrade_and_downgrade() -> None: + module = _load_migration() + upgrade = getattr(module, "upgrade", None) + downgrade = getattr(module, "downgrade", None) + assert callable(upgrade), "upgrade() is required" + assert callable(downgrade), "downgrade() is required" + + +def test_upgrade_creates_set_null_fks_for_both_revision_tables() -> None: + module = _load_migration() + src = inspect.getsource(module.upgrade) + assert "document_revisions" in src + assert "folder_revisions" in src + # Both new FKs MUST be ON DELETE SET NULL — that's the entire point + # of the migration: snapshots must outlive their parent row. + assert src.count('ondelete="SET NULL"') >= 2 + # And the ``document_id`` / ``folder_id`` columns become nullable. + assert "nullable=True" in src + + +def test_downgrade_drains_orphans_then_restores_cascade() -> None: + module = _load_migration() + src = inspect.getsource(module.downgrade) + # Drain orphaned rows BEFORE we can re-impose NOT NULL. + assert "DELETE FROM document_revisions WHERE document_id IS NULL" in src + assert "DELETE FROM folder_revisions WHERE folder_id IS NULL" in src + # Then restore the original CASCADE/NOT NULL contract. + assert src.count('ondelete="CASCADE"') >= 2 + assert "nullable=False" in src diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py b/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py index c2e304399..70430f4ca 100644 --- a/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py @@ -168,6 +168,8 @@ class TestModeSpecificPrompts: "edit_file", "move_file", "mkdir", + "rm", + "rmdir", "list_tree", "grep", ): @@ -182,6 +184,8 @@ class TestModeSpecificPrompts: "edit_file", "move_file", "mkdir", + "rm", + "rmdir", "list_tree", "grep", ): @@ -190,6 +194,18 @@ class TestModeSpecificPrompts: assert "/documents/" not in text, f"{name} mentions cloud namespace" assert "temp_" not in text, f"{name} mentions cloud temp_ semantics" + def test_cloud_descs_include_rm_and_rmdir(self): + descs = _build_tool_descriptions(FilesystemMode.CLOUD) + assert "rm" in descs and "rmdir" in descs + assert "Deletes a single file" in descs["rm"] + assert "Deletes an empty directory" in descs["rmdir"] + assert "rmdir" in descs["rmdir"] and "POSIX" in descs["rmdir"] + + def test_desktop_descs_warn_about_irreversibility(self): + descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER) + assert "NOT reversible" in descs["rm"] + assert "NOT reversible" in descs["rmdir"] + def test_sandbox_addendum_appended_when_available(self): prompt = _build_filesystem_system_prompt( FilesystemMode.CLOUD, sandbox_available=True diff --git a/surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py b/surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py new file mode 100644 index 000000000..feca23d27 --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py @@ -0,0 +1,309 @@ +"""Unit tests for the kb_persistence snapshot helpers. + +The full ``commit_staged_filesystem_state`` body exercises a real session +in integration tests; here we verify the building blocks used by the +snapshot/revert pipeline: + +* ``_find_action_ids_batch`` issues a SINGLE query for N tool_call_ids + (regression guard against the N+1 lookup pattern). +* ``_mark_action_reversible`` is a no-op when ``action_id`` is ``None``. +* ``_doc_revision_payload`` and ``_load_chunks_for_snapshot`` produce the + shape the snapshot helpers consume. + +These tests use ``MagicMock`` / ``AsyncMock`` against a fake session so +the assertions run in milliseconds and don't require Postgres. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.agents.new_chat.middleware import kb_persistence + +pytestmark = pytest.mark.unit + + +class _FakeResult: + def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None: + self._rows = rows or [] + self._scalar = scalar + + def all(self) -> list[Any]: + return list(self._rows) + + def scalar_one_or_none(self) -> Any: + return self._scalar + + +class _FakeSession: + def __init__(self) -> None: + self.execute = AsyncMock() + + +@pytest.mark.asyncio +async def test_find_action_ids_batch_issues_single_query() -> None: + """The lookup MUST be a single ``IN (...)`` SELECT, not N selects.""" + session = _FakeSession() + session.execute.return_value = _FakeResult( + rows=[ + MagicMock(id=11, tool_call_id="tc-a"), + MagicMock(id=22, tool_call_id="tc-b"), + MagicMock(id=33, tool_call_id="tc-c"), + ] + ) + + mapping = await kb_persistence._find_action_ids_batch( + session, # type: ignore[arg-type] + thread_id=1, + tool_call_ids={"tc-a", "tc-b", "tc-c"}, + ) + + assert mapping == {"tc-a": 11, "tc-b": 22, "tc-c": 33} + assert session.execute.await_count == 1, ( + "Snapshot binding must batch into ONE query; got " + f"{session.execute.await_count} (regression: N+1 lookup pattern)." + ) + + +@pytest.mark.asyncio +async def test_find_action_ids_batch_short_circuits_when_thread_id_missing() -> None: + session = _FakeSession() + mapping = await kb_persistence._find_action_ids_batch( + session, # type: ignore[arg-type] + thread_id=None, + tool_call_ids={"tc-a"}, + ) + assert mapping == {} + assert session.execute.await_count == 0 + + +@pytest.mark.asyncio +async def test_find_action_ids_batch_short_circuits_when_no_calls() -> None: + session = _FakeSession() + mapping = await kb_persistence._find_action_ids_batch( + session, # type: ignore[arg-type] + thread_id=42, + tool_call_ids=set(), + ) + assert mapping == {} + assert session.execute.await_count == 0 + + +@pytest.mark.asyncio +async def test_mark_action_reversible_is_noop_for_null_id() -> None: + session = _FakeSession() + await kb_persistence._mark_action_reversible(session, action_id=None) # type: ignore[arg-type] + assert session.execute.await_count == 0 + + +@pytest.mark.asyncio +async def test_mark_action_reversible_runs_update_for_real_id() -> None: + session = _FakeSession() + await kb_persistence._mark_action_reversible(session, action_id=99) # type: ignore[arg-type] + assert session.execute.await_count == 1 + + +def test_doc_revision_payload_captures_metadata_virtual_path() -> None: + """Snapshot helpers must capture ``metadata_before`` for revert reuse.""" + doc = MagicMock() + doc.content = "body" + doc.title = "notes.md" + doc.folder_id = 7 + doc.document_metadata = {"virtual_path": "/documents/team/notes.md"} + + payload = kb_persistence._doc_revision_payload( + doc, chunks_before=[{"content": "x"}] + ) + + assert payload["title_before"] == "notes.md" + assert payload["folder_id_before"] == 7 + assert payload["content_before"] == "body" + assert payload["chunks_before"] == [{"content": "x"}] + assert payload["metadata_before"] == {"virtual_path": "/documents/team/notes.md"} + + +def test_doc_revision_payload_handles_missing_metadata() -> None: + doc = MagicMock() + doc.content = "" + doc.title = "" + doc.folder_id = None + doc.document_metadata = None + payload = kb_persistence._doc_revision_payload(doc) + assert payload["metadata_before"] is None + + +@pytest.mark.asyncio +async def test_load_chunks_for_snapshot_returns_content_only() -> None: + """Snapshot chunks intentionally omit embeddings (regenerated on revert).""" + session = _FakeSession() + session.execute.return_value = _FakeResult( + rows=[ + MagicMock(content="alpha"), + MagicMock(content="beta"), + ] + ) + chunks = await kb_persistence._load_chunks_for_snapshot( + session, + doc_id=42, # type: ignore[arg-type] + ) + assert chunks == [{"content": "alpha"}, {"content": "beta"}] + + +# --------------------------------------------------------------------------- +# Deferred reversibility-flip dispatches. +# +# The snapshot helpers used to dispatch ``action_log_updated`` directly +# from inside the SAVEPOINT block. That meant the SSE side-channel +# could tell the UI a row was reversible while the OUTER transaction +# was still pending — and if the outer commit failed, every SAVEPOINT +# rolled back too, leaving the UI in a state inconsistent with +# durable storage. The deferred-dispatch contract fixes that: +# +# • when a ``deferred_dispatches`` list is provided, the helper +# APPENDS the action_id and does NOT dispatch; +# • the caller (``commit_staged_filesystem_state``) flushes the list +# only AFTER ``await session.commit()`` succeeds; on rollback it +# clears the list so nothing is emitted. +# --------------------------------------------------------------------------- + + +class _NestedCtx: + """Async context manager mimicking ``session.begin_nested()``.""" + + async def __aenter__(self) -> _NestedCtx: + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + return False + + +@pytest.mark.asyncio +async def test_pre_write_snapshot_defers_dispatch_when_list_provided( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Helpers MUST queue dispatches when ``deferred_dispatches`` is set.""" + session = MagicMock() + session.begin_nested = MagicMock(return_value=_NestedCtx()) + session.execute = AsyncMock(return_value=_FakeResult(rows=[])) + session.flush = AsyncMock() + + def _add(rev: Any) -> None: + rev.id = 17 + + session.add = MagicMock(side_effect=_add) + + dispatched: list[int] = [] + + async def _fake_dispatch(action_id: int | None) -> None: + if action_id is not None: + dispatched.append(int(action_id)) + + monkeypatch.setattr( + kb_persistence, "_dispatch_reversibility_update", _fake_dispatch + ) + + deferred: list[int] = [] + doc = MagicMock(id=99, document_metadata={"virtual_path": "/documents/x.md"}) + doc.title = "x.md" + doc.folder_id = None + doc.content = "body" + + rev_id = await kb_persistence._snapshot_document_pre_write( + session, # type: ignore[arg-type] + doc=doc, + action_id=42, + search_space_id=1, + turn_id="t-1", + deferred_dispatches=deferred, + ) + + assert rev_id == 17 + # Inline dispatch must NOT have fired; the action_id is queued. + assert dispatched == [] + assert deferred == [42] + + +@pytest.mark.asyncio +async def test_pre_write_snapshot_dispatches_inline_when_list_omitted( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Direct callers (no outer transaction) keep the legacy inline dispatch.""" + session = MagicMock() + session.begin_nested = MagicMock(return_value=_NestedCtx()) + session.execute = AsyncMock(return_value=_FakeResult(rows=[])) + session.flush = AsyncMock() + + def _add(rev: Any) -> None: + rev.id = 7 + + session.add = MagicMock(side_effect=_add) + + dispatched: list[int] = [] + + async def _fake_dispatch(action_id: int | None) -> None: + if action_id is not None: + dispatched.append(int(action_id)) + + monkeypatch.setattr( + kb_persistence, "_dispatch_reversibility_update", _fake_dispatch + ) + + doc = MagicMock(id=11, document_metadata={"virtual_path": "/documents/y.md"}) + doc.title = "y.md" + doc.folder_id = None + doc.content = "body" + + await kb_persistence._snapshot_document_pre_write( + session, # type: ignore[arg-type] + doc=doc, + action_id=88, + search_space_id=1, + turn_id="t-1", + # No deferred_dispatches arg — fall back to inline dispatch. + ) + + assert dispatched == [88] + + +@pytest.mark.asyncio +async def test_pre_mkdir_snapshot_defers_dispatch_when_list_provided( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Folder mkdir snapshots honour the same deferred-dispatch contract.""" + session = MagicMock() + session.begin_nested = MagicMock(return_value=_NestedCtx()) + session.execute = AsyncMock() # _mark_action_reversible calls execute + session.flush = AsyncMock() + + def _add(rev: Any) -> None: + rev.id = 3 + + session.add = MagicMock(side_effect=_add) + + dispatched: list[int] = [] + + async def _fake_dispatch(action_id: int | None) -> None: + if action_id is not None: + dispatched.append(int(action_id)) + + monkeypatch.setattr( + kb_persistence, "_dispatch_reversibility_update", _fake_dispatch + ) + + deferred: list[int] = [] + folder = MagicMock(id=2, name="f", parent_id=None, position="a0") + + await kb_persistence._snapshot_folder_pre_mkdir( + session, # type: ignore[arg-type] + folder=folder, + action_id=55, + search_space_id=1, + turn_id="t-1", + deferred_dispatches=deferred, + ) + + assert dispatched == [] + assert deferred == [55] diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_tree.py b/surfsense_backend/tests/unit/middleware/test_knowledge_tree.py new file mode 100644 index 000000000..caaec3114 --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_knowledge_tree.py @@ -0,0 +1,139 @@ +"""Unit tests for ``KnowledgeTreeMiddleware`` rendering. + +The empty-folder marker is critical UX: without it, the LLM cannot +distinguish a leaf folder containing one document from a leaf folder +that has no descendants at all, and ends up firing ``rmdir`` on +non-empty folders. These tests pin the rendering contract so that +contract cannot silently regress. +""" + +from __future__ import annotations + +from app.agents.new_chat.middleware.knowledge_tree import KnowledgeTreeMiddleware +from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT + + +def _compute(folder_paths: list[str], doc_paths: list[str]) -> set[str]: + return KnowledgeTreeMiddleware._compute_non_empty_folders(folder_paths, doc_paths) + + +class TestComputeNonEmptyFolders: + def test_folder_with_direct_document_is_non_empty(self): + folder_paths = [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"] + doc_paths = [ + f"{DOCUMENTS_ROOT}/Travel/Boarding Pass/southwest.pdf.xml", + ] + non_empty = _compute(folder_paths, doc_paths) + assert f"{DOCUMENTS_ROOT}/Travel/Boarding Pass" in non_empty + + def test_truly_empty_leaf_folder_is_not_non_empty(self): + folder_paths = [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"] + doc_paths: list[str] = [] + assert _compute(folder_paths, doc_paths) == set() + + def test_documents_propagate_up_to_all_ancestors(self): + folder_paths = [ + f"{DOCUMENTS_ROOT}/A", + f"{DOCUMENTS_ROOT}/A/B", + f"{DOCUMENTS_ROOT}/A/B/C", + ] + doc_paths = [f"{DOCUMENTS_ROOT}/A/B/C/file.xml"] + non_empty = _compute(folder_paths, doc_paths) + assert non_empty == { + f"{DOCUMENTS_ROOT}/A", + f"{DOCUMENTS_ROOT}/A/B", + f"{DOCUMENTS_ROOT}/A/B/C", + } + + def test_chain_with_subfolders_marks_only_leaf_empty(self): + # POSIX-like semantic: a folder is "empty" only if it has no + # immediate children (docs OR sub-folders). The model needs this + # because parallel ``rmdir`` calls all see the same starting state, + # so trying to rmdir a parent before its children is never safe. + folder_paths = [ + f"{DOCUMENTS_ROOT}/X", + f"{DOCUMENTS_ROOT}/X/Y", + f"{DOCUMENTS_ROOT}/X/Y/Z", + ] + non_empty = _compute(folder_paths, []) + # Only ``X/Y/Z`` (the leaf) is empty. ``X`` and ``X/Y`` each have a + # sub-folder child, so they are non-empty and should NOT carry the + # ``(empty)`` marker. + assert non_empty == {f"{DOCUMENTS_ROOT}/X", f"{DOCUMENTS_ROOT}/X/Y"} + + def test_sibling_with_doc_does_not_mark_other_sibling_non_empty(self): + # Mirrors a real DB layout where every intermediate folder is + # materialized in the ``folders`` table. + folder_paths = [ + f"{DOCUMENTS_ROOT}/Travel", + f"{DOCUMENTS_ROOT}/Travel/Boarding Pass", + f"{DOCUMENTS_ROOT}/Travel/Notes", + ] + doc_paths = [f"{DOCUMENTS_ROOT}/Travel/Notes/itinerary.xml"] + non_empty = _compute(folder_paths, doc_paths) + # ``Travel`` is non-empty because it has children, ``Notes`` is non-empty + # because of the doc, but ``Boarding Pass`` (sibling leaf) is empty. + assert f"{DOCUMENTS_ROOT}/Travel" in non_empty + assert f"{DOCUMENTS_ROOT}/Travel/Notes" in non_empty + assert f"{DOCUMENTS_ROOT}/Travel/Boarding Pass" not in non_empty + + +class TestFormatTreeRendering: + """Integration check: empty leaf gets ``(empty)`` marker; non-empty doesn't.""" + + def _render( + self, + folder_paths: list[str], + doc_specs: list[dict], + ) -> str: + from app.agents.new_chat.path_resolver import PathIndex + + index = PathIndex( + folder_paths={i + 1: p for i, p in enumerate(folder_paths)}, + ) + + class _Row: + def __init__(self, **kw): + self.__dict__.update(kw) + + docs = [_Row(**spec) for spec in doc_specs] + + mw = KnowledgeTreeMiddleware( + search_space_id=1, + filesystem_mode=None, # type: ignore[arg-type] + ) + return mw._format_tree(index, docs) + + def test_renders_empty_marker_only_for_truly_empty_folders(self): + # Reproduces the failure scenario from the bug report: + # ``Boarding Pass`` is empty (its only doc was just deleted), while + # ``Tax Returns`` still has ``federal.pdf``. All intermediate + # folders are present in the index, mirroring the real DB layout. + folder_paths = [ + "/documents/File Upload", + "/documents/File Upload/2026-04-08", + "/documents/File Upload/2026-04-08/Travel", + "/documents/File Upload/2026-04-08/Travel/Boarding Pass", + "/documents/File Upload/2026-04-15", + "/documents/File Upload/2026-04-15/Finance", + "/documents/File Upload/2026-04-15/Finance/Tax Returns", + ] + tax_returns_folder_id = ( + folder_paths.index("/documents/File Upload/2026-04-15/Finance/Tax Returns") + + 1 + ) + rendered = self._render( + folder_paths=folder_paths, + doc_specs=[ + { + "id": 100, + "title": "federal.pdf", + "folder_id": tax_returns_folder_id, + }, + ], + ) + assert "Boarding Pass/ (empty)" in rendered + assert "Tax Returns/ (empty)" not in rendered + # Intermediate ancestors of the doc must NOT be marked empty. + assert "Finance/ (empty)" not in rendered + assert "2026-04-15/ (empty)" not in rendered diff --git a/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py b/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py index 7dfc68402..6e81ecf8e 100644 --- a/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py +++ b/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py @@ -69,3 +69,74 @@ def test_local_backend_write_rejects_missing_parent_directory(tmp_path: Path): assert write.error is not None assert "parent directory" in write.error assert not (tmp_path / "tempoo").exists() + + +def test_local_backend_delete_file_success(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "delete-me.md").write_text("bye") + + res = backend.delete_file("/delete-me.md") + assert res.error is None + assert res.path == "/delete-me.md" + assert not (tmp_path / "delete-me.md").exists() + + +def test_local_backend_delete_file_rejects_directory(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "subdir").mkdir() + + res = backend.delete_file("/subdir") + assert res.error is not None + assert "directory" in res.error + assert (tmp_path / "subdir").exists() + + +def test_local_backend_delete_file_missing_returns_error(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + + res = backend.delete_file("/nope.md") + assert res.error is not None + assert "not found" in res.error + + +def test_local_backend_rmdir_success(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "empty").mkdir() + + res = backend.rmdir("/empty") + assert res.error is None + assert res.path == "/empty" + assert not (tmp_path / "empty").exists() + + +def test_local_backend_rmdir_rejects_non_empty(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "withkid").mkdir() + (tmp_path / "withkid" / "child.md").write_text("x") + + res = backend.rmdir("/withkid") + assert res.error is not None + assert "not empty" in res.error + assert (tmp_path / "withkid" / "child.md").exists() + + +def test_local_backend_rmdir_rejects_file(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "f.md").write_text("x") + + res = backend.rmdir("/f.md") + assert res.error is not None + assert "not a directory" in res.error + + +def test_local_backend_rmdir_rejects_root(tmp_path: Path): + """``rmdir /`` MUST fail. The exact error wording comes from + ``_resolve_virtual`` (root resolves to outside the sandbox); what + matters is that the call returns an error and does NOT delete the + sandbox root on disk.""" + backend = LocalFolderBackend(str(tmp_path)) + + res = backend.rmdir("/") + assert res.error is not None + assert "Invalid path" in res.error or "root" in res.error + assert tmp_path.exists() diff --git a/surfsense_backend/tests/unit/routes/__init__.py b/surfsense_backend/tests/unit/routes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/routes/test_regenerate_from_message_id.py b/surfsense_backend/tests/unit/routes/test_regenerate_from_message_id.py new file mode 100644 index 000000000..709014d55 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_regenerate_from_message_id.py @@ -0,0 +1,143 @@ +"""Unit tests for the edit-from-arbitrary-position helpers inside ``new_chat_routes``. + +The regenerate route's edit-from-position path introduces: +* ``_find_pre_turn_checkpoint_id`` — walks LangGraph checkpoint tuples + newest-first and picks the first one whose ``metadata["turn_id"]`` + differs from the edited turn. That checkpoint is the rewind target + (state immediately before the edited turn started). +* ``RegenerateRequest`` accepts ``from_message_id`` + ``revert_actions`` + with a validator that prevents callers from requesting a revert pass + without specifying which turn to roll back. + +These are pure-Python helpers that don't need a live DB, so we exercise +them with a small ``CheckpointTuple``-shaped namespace and direct +schema instantiation. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from app.routes.new_chat_routes import _find_pre_turn_checkpoint_id +from app.schemas.new_chat import RegenerateRequest + + +def _cp(checkpoint_id: str, turn_id: str | None) -> SimpleNamespace: + """Build a fake ``CheckpointTuple`` with the metadata shape we read.""" + return SimpleNamespace( + config={"configurable": {"checkpoint_id": checkpoint_id}}, + metadata={"turn_id": turn_id} if turn_id is not None else {}, + ) + + +class TestFindPreTurnCheckpointId: + def test_returns_last_pre_turn_checkpoint_when_editing_latest_turn(self) -> None: + # Newest-first: T2 is the most-recent turn. The latest non-T2 + # checkpoint (cp2) is the rewind target — state immediately + # before T2 began. + tuples = [ + _cp("cp4", "T2"), + _cp("cp3", "T2"), + _cp("cp2", "T1"), + _cp("cp1", "T1"), + ] + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2" + + def test_returns_pre_turn_checkpoint_when_later_turns_exist(self) -> None: + # Regression for the bug where walking newest-first returned the + # FIRST cp with ``turn_id != target`` — which is one of the + # later-turn checkpoints, NOT the pre-turn boundary. Editing + # T2 must rewind to the latest T1 checkpoint (cp2), not to the + # latest T3 checkpoint (cp6). + tuples = [ + _cp("cp6", "T3"), + _cp("cp5", "T3"), + _cp("cp4", "T2"), + _cp("cp3", "T2"), + _cp("cp2", "T1"), + _cp("cp1", "T1"), + ] + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2" + + def test_returns_none_when_editing_first_turn(self) -> None: + # No pre-turn boundary exists; caller is expected to fall back + # to the oldest checkpoint or special-case "first turn of the + # thread". + tuples = [ + _cp("cp4", "T2"), + _cp("cp3", "T2"), + _cp("cp2", "T1"), + _cp("cp1", "T1"), + ] + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T1") is None + + def test_returns_none_when_only_edited_turn_present(self) -> None: + tuples = [_cp("cp2", "T2"), _cp("cp1", "T2")] + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") is None + + def test_returns_none_for_empty_history(self) -> None: + assert _find_pre_turn_checkpoint_id([], turn_id="T1") is None + + def test_legacy_checkpoints_without_turn_id_count_as_pre_turn(self) -> None: + # Checkpoints written before migration 136 have no + # ``metadata.turn_id``. They should be eligible rewind targets + # — they came before the + # edited turn began. + tuples = [ + _cp("cp3", "T2"), + SimpleNamespace( + config={"configurable": {"checkpoint_id": "cp2"}}, + metadata=None, + ), + _cp("cp1", "T1"), + ] + # Walking oldest-first: cp1(T1) tracked, cp2(legacy/None) tracked, + # then cp3(T2) crosses the boundary -> return cp2. + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2" + + def test_skips_checkpoint_missing_checkpoint_id_in_config(self) -> None: + # If a checkpoint tuple's ``config["configurable"]`` is missing + # the ``checkpoint_id`` key (corrupt / partial), we keep the + # last known good target instead of crashing. + broken = SimpleNamespace( + config={"configurable": {}}, metadata={"turn_id": "T1"} + ) + tuples = [ + _cp("cp3", "T2"), + broken, + _cp("cp1", "T1"), + ] + # cp1(T1) tracked, broken skipped, cp3(T2) -> return cp1. + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp1" + + +class TestRegenerateRequestValidation: + def test_revert_actions_requires_from_message_id(self) -> None: + with pytest.raises(Exception) as exc: + RegenerateRequest( + search_space_id=1, + user_query="hi", + revert_actions=True, + ) + msg = str(exc.value).lower() + assert "from_message_id" in msg + + def test_from_message_id_without_revert_is_allowed(self) -> None: + req = RegenerateRequest( + search_space_id=1, + user_query="hi", + from_message_id=42, + ) + assert req.from_message_id == 42 + assert req.revert_actions is False + + def test_revert_actions_with_from_message_id_passes(self) -> None: + req = RegenerateRequest( + search_space_id=1, + user_query="hi", + from_message_id=42, + revert_actions=True, + ) + assert req.revert_actions is True diff --git a/surfsense_backend/tests/unit/routes/test_revert_turn_route.py b/surfsense_backend/tests/unit/routes/test_revert_turn_route.py new file mode 100644 index 000000000..1e1cbffb3 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_revert_turn_route.py @@ -0,0 +1,530 @@ +"""Unit tests for ``POST /threads/{id}/revert-turn/{chat_turn_id}``. + +The per-turn batch revert route walks rows in reverse ``created_at`` +order, reverts each independently, and returns a per-action result +list. Partial success is normal — the response status +is ``"partial"`` whenever any row could not be reverted, but we never +collapse the whole batch into a 4xx. + +These tests stub ``load_thread`` / ``revert_action`` and feed a fake +session, so they exercise the route's dispatch logic without a real DB. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.routes import agent_revert_route +from app.services.revert_service import RevertOutcome + + +@dataclass +class _FakeAction: + id: int + tool_name: str + user_id: str | None = "u1" + reverse_of: int | None = None + error: dict | None = None + + +@dataclass +class _FakeUser: + id: str = "u1" + + +@dataclass +class _ScalarResult: + rows: list[Any] + + def first(self) -> Any: + return self.rows[0] if self.rows else None + + def all(self) -> list[Any]: + return list(self.rows) + + +@dataclass +class _Result: + rows: list[Any] = field(default_factory=list) + + def scalars(self) -> _ScalarResult: + return _ScalarResult(self.rows) + + def all(self) -> list[Any]: + # ``_was_already_reverted_batch`` calls ``.all()`` directly on + # the row-tuple result (no ``.scalars()`` indirection). The + # rows queued for that helper are list[(revert_id, original_id)]. + return list(self.rows) + + +class _FakeNestedCtx: + """Async context manager that mimics ``session.begin_nested()``. + + The route raises a sentinel exception inside this block to roll back + bad rows. We just pass the exception through. + """ + + async def __aenter__(self) -> _FakeNestedCtx: + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + # Returning False (or None) propagates the exception; the route + # catches its own sentinel above this layer. + return False + + +class _FakeSession: + """Minimal AsyncSession stand-in for the revert-turn route. + + Holds a queue of result objects; each ``execute(...)`` pops the next + one. The route calls ``execute`` exactly once per query so this maps + cleanly onto the assertion order of the test. + """ + + def __init__(self) -> None: + self._results: list[_Result] = [] + self.committed = False + self.rolled_back = False + # Count execute() calls to assert "no N+1 reverts". + self.execute_call_count = 0 + + def queue(self, *results: _Result) -> None: + self._results.extend(results) + + async def execute(self, _stmt: Any) -> _Result: + self.execute_call_count += 1 + if not self._results: + return _Result(rows=[]) + return self._results.pop(0) + + def begin_nested(self) -> _FakeNestedCtx: + return _FakeNestedCtx() + + async def commit(self) -> None: + self.committed = True + + async def rollback(self) -> None: + self.rolled_back = True + + +def _enabled_flags() -> AgentFeatureFlags: + return AgentFeatureFlags( + disable_new_agent_stack=False, + enable_action_log=True, + enable_revert_route=True, + ) + + +@pytest.fixture +def patch_get_flags(): + def _patch(flags: AgentFeatureFlags): + return patch( + "app.routes.agent_revert_route.get_flags", + return_value=flags, + ) + + return _patch + + +class TestFlagGuard: + @pytest.mark.asyncio + async def test_returns_503_when_revert_route_disabled( + self, patch_get_flags + ) -> None: + flags = AgentFeatureFlags( + disable_new_agent_stack=False, + enable_action_log=True, + enable_revert_route=False, + ) + session = _FakeSession() + with patch_get_flags(flags), pytest.raises(Exception) as exc: + await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="42:1700000000000", + session=session, + user=_FakeUser(), + ) + assert getattr(exc.value, "status_code", None) == 503 + + +class TestRevertTurnDispatch: + @pytest.mark.asyncio + async def test_empty_turn_returns_ok_with_no_rows(self, patch_get_flags) -> None: + session = _FakeSession() + session.queue(_Result(rows=[])) # rows query returns nothing + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-empty", + session=session, + user=_FakeUser(), + ) + assert response.status == "ok" + assert response.total == 0 + assert response.results == [] + assert session.committed is True + + @pytest.mark.asyncio + async def test_walks_rows_in_reverse_and_reverts_each( + self, patch_get_flags + ) -> None: + rows = [ + _FakeAction(id=10, tool_name="rm"), + _FakeAction(id=9, tool_name="write_file"), + _FakeAction(id=8, tool_name="mkdir"), + ] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Single batched ``_was_already_reverted_batch`` probe replaces + # the previous N per-row SELECTs. + session.queue(_Result(rows=[])) + + async def _fake_revert(_session, *, action, requester_user_id): + return RevertOutcome( + status="ok", + message=f"reverted-{action.id}", + new_action_id=100 + action.id, + ) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert) + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-3", + session=session, + user=_FakeUser(), + ) + + assert response.status == "ok" + assert response.total == 3 + assert response.reverted == 3 + assert [r.action_id for r in response.results] == [10, 9, 8] + assert all(r.status == "reverted" for r in response.results) + assert response.results[0].new_action_id == 110 + # Only TWO ``execute`` calls regardless of the row count: one + # for the rows query, one for the batched + # ``_was_already_reverted_batch`` probe. Regression guard + # against re-introducing the per-row N+1 lookup. + assert session.execute_call_count == 2, ( + "revert-turn loop must batch idempotency probes; got " + f"{session.execute_call_count} execute() calls (expected 2)." + ) + + @pytest.mark.asyncio + async def test_already_reverted_rows_are_marked_idempotent( + self, patch_get_flags + ) -> None: + rows = [_FakeAction(id=5, tool_name="edit_file")] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Batch probe returns ``[(revert_id, original_id)]``. + session.queue(_Result(rows=[(42, 5)])) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert, + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-i", + session=session, + user=_FakeUser(), + ) + assert response.status == "ok" + assert response.already_reverted == 1 + assert response.results[0].status == "already_reverted" + assert response.results[0].new_action_id == 42 + revert.assert_not_called() + + @pytest.mark.asyncio + async def test_revert_action_skips_existing_revert_rows( + self, patch_get_flags + ) -> None: + rows = [_FakeAction(id=99, tool_name="_revert:edit_file", reverse_of=42)] + session = _FakeSession() + session.queue(_Result(rows=rows)) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert, + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-rev", + session=session, + user=_FakeUser(), + ) + assert response.status == "ok" + assert response.results[0].status == "skipped" + revert.assert_not_called() + + @pytest.mark.asyncio + async def test_partial_success_when_some_rows_not_reversible( + self, patch_get_flags + ) -> None: + rows = [ + _FakeAction(id=2, tool_name="send_email"), + _FakeAction(id=1, tool_name="edit_file"), + ] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Single batched idempotency probe. + session.queue(_Result(rows=[])) + + async def _fake_revert(_session, *, action, requester_user_id): + if action.tool_name == "send_email": + return RevertOutcome( + status="not_reversible", + message="connector revert not yet implemented", + ) + return RevertOutcome(status="ok", message="ok", new_action_id=500) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert) + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-mix", + session=session, + user=_FakeUser(), + ) + assert response.status == "partial" + assert response.reverted == 1 + assert response.not_reversible == 1 + statuses = sorted(r.status for r in response.results) + assert statuses == ["not_reversible", "reverted"] + + @pytest.mark.asyncio + async def test_unexpected_exception_marks_row_failed_not_batch( + self, patch_get_flags + ) -> None: + rows = [ + _FakeAction(id=20, tool_name="edit_file"), + _FakeAction(id=21, tool_name="edit_file"), + ] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Single batched idempotency probe. + session.queue(_Result(rows=[])) + + async def _fake_revert(_session, *, action, requester_user_id): + if action.id == 20: + raise RuntimeError("disk on fire") + return RevertOutcome(status="ok", message="ok", new_action_id=999) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert) + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-fail", + session=session, + user=_FakeUser(), + ) + assert response.status == "partial" + assert response.failed == 1 + assert response.reverted == 1 + bad = next(r for r in response.results if r.action_id == 20) + assert bad.status == "failed" + assert "disk on fire" in (bad.error or "") + good = next(r for r in response.results if r.action_id == 21) + assert good.status == "reverted" + + @pytest.mark.asyncio + async def test_permission_denied_when_other_user_owns_action( + self, patch_get_flags + ) -> None: + rows = [_FakeAction(id=7, tool_name="edit_file", user_id="someone-else")] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Batch idempotency probe (no prior reverts). + session.queue(_Result(rows=[])) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert, + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-perm", + session=session, + user=_FakeUser(id="not-owner"), + ) + assert response.status == "partial" + assert response.results[0].status == "permission_denied" + # ``permission_denied`` has its own dedicated counter so the + # response invariant ``total == sum(counters)`` always holds + # without overloading ``not_reversible`` (which historically + # absorbed this case and confused frontend toasts). + assert response.permission_denied == 1 + assert response.not_reversible == 0 + revert.assert_not_called() + + @pytest.mark.asyncio + async def test_counter_invariant_holds_across_mixed_outcomes( + self, patch_get_flags + ) -> None: + """Every row is accounted for in EXACTLY ONE counter. + + Mixes one of every supported outcome (reverted, already_reverted, + not_reversible, permission_denied, failed, skipped) and asserts + that the sum of counters equals ``response.total``. + """ + rows = [ + _FakeAction(id=10, tool_name="edit_file"), # ok + _FakeAction(id=9, tool_name="edit_file"), # already_reverted + _FakeAction(id=8, tool_name="send_email"), # not_reversible + _FakeAction(id=7, tool_name="rm", user_id="other"), # permission_denied + _FakeAction(id=6, tool_name="edit_file"), # failed + _FakeAction(id=5, tool_name="_revert:edit_file", reverse_of=99), # skipped + ] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Single batched probe; only id=9 has a prior revert. + # Schema: list[(revert_id, original_id)]. + session.queue(_Result(rows=[(42, 9)])) + + async def _fake_revert(_session, *, action, requester_user_id): + if action.id == 10: + return RevertOutcome(status="ok", message="ok", new_action_id=500) + if action.id == 8: + return RevertOutcome( + status="not_reversible", + message="connector revert not yet implemented", + ) + if action.id == 6: + raise RuntimeError("boom") + raise AssertionError(f"unexpected revert call for {action.id}") + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, + "revert_action", + AsyncMock(side_effect=_fake_revert), + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-mixed-all", + session=session, + user=_FakeUser(), # only id=7 has a different user_id + ) + + assert response.total == len(rows) == 6 + bucket_sum = ( + response.reverted + + response.already_reverted + + response.not_reversible + + response.permission_denied + + response.failed + + response.skipped + ) + assert bucket_sum == response.total, ( + "Counter invariant broken: total " + f"({response.total}) != sum of counters ({bucket_sum}). " + f"Counters: reverted={response.reverted}, " + f"already_reverted={response.already_reverted}, " + f"not_reversible={response.not_reversible}, " + f"permission_denied={response.permission_denied}, " + f"failed={response.failed}, skipped={response.skipped}" + ) + assert response.reverted == 1 + assert response.already_reverted == 1 + assert response.not_reversible == 1 + assert response.permission_denied == 1 + assert response.failed == 1 + assert response.skipped == 1 + + @pytest.mark.asyncio + async def test_integrity_error_translates_to_already_reverted( + self, patch_get_flags + ) -> None: + """The partial unique index on ``reverse_of`` raises + ``IntegrityError`` when a concurrent revert wins the race against + the pre-flight ``_was_already_reverted`` SELECT. The route MUST + recover by re-querying for the winning revert id and returning + ``status="already_reverted"`` (not ``"failed"``) so racing + clients see consistent idempotent semantics. + """ + from sqlalchemy.exc import IntegrityError + + rows = [_FakeAction(id=33, tool_name="edit_file")] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Batch pre-flight probe: nothing yet (we'll race). + session.queue(_Result(rows=[])) + # Post-IntegrityError fallback uses the SCALAR + # ``_was_already_reverted`` (single-id lookup) so it pulls + # ``[777]`` via ``.scalars().first()``. + session.queue(_Result(rows=[777])) + + async def _racing_revert(_session, *, action, requester_user_id): + raise IntegrityError("INSERT", {}, Exception("dup reverse_of")) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, + "revert_action", + AsyncMock(side_effect=_racing_revert), + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-race", + session=session, + user=_FakeUser(), + ) + + assert response.failed == 0, ( + "IntegrityError must NOT surface as a failed row; the unique " + "index is the durable expression of idempotency." + ) + assert response.already_reverted == 1 + assert response.results[0].status == "already_reverted" + assert response.results[0].new_action_id == 777 diff --git a/surfsense_backend/tests/unit/services/test_revert_filesystem_tools.py b/surfsense_backend/tests/unit/services/test_revert_filesystem_tools.py new file mode 100644 index 000000000..95314741a --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_revert_filesystem_tools.py @@ -0,0 +1,370 @@ +"""Unit tests for the filesystem-tool branches of ``revert_service``. + +Covers: + +* Exact-name dispatch — ``rmdir`` does NOT mis-route to the document + branch (``"rmdir".startswith("rm")`` would mis-route under the legacy + prefix-based dispatch). +* ``rm`` revert re-INSERTs a fresh document from the snapshot, including + re-creating chunks. Falls back to ``(folder_id_before, title_before)`` + when ``metadata_before["virtual_path"]`` is missing. +* ``write_file`` create-revert (``content_before IS NULL``) DELETEs the + document. +* ``rmdir`` revert re-INSERTs a fresh folder from the snapshot. +* ``mkdir`` revert DELETEs the empty folder; reports ``tool_unavailable`` + when the folder gained children. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import numpy as np +import pytest + +from app.services import revert_service + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def _stub_embeddings(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + revert_service, + "embed_texts", + lambda texts: [np.zeros(8, dtype=np.float32) for _ in texts], + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _FakeResult: + def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None: + self._rows = rows or [] + self._scalar = scalar + + def all(self) -> list[Any]: + return list(self._rows) + + def scalar_one_or_none(self) -> Any: + return self._scalar + + def scalars(self) -> Any: + return _FakeScalarsProxy(self._rows) + + +class _FakeScalarsProxy: + def __init__(self, rows: list[Any]) -> None: + self._rows = rows + + def first(self) -> Any: + return self._rows[0] if self._rows else None + + +class _FakeSession: + def __init__(self) -> None: + self.execute = AsyncMock() + self.added: list[Any] = [] + self.deleted: list[Any] = [] + self.flush = AsyncMock() + # session.get(Model, pk) lookup + self.get = AsyncMock(return_value=None) + + async def _flush_assigning_ids() -> None: + for obj in self.added: + if getattr(obj, "id", None) is None: + obj.id = 999 + + self.flush.side_effect = _flush_assigning_ids + + def add(self, obj: Any) -> None: + self.added.append(obj) + + def add_all(self, objs: list[Any]) -> None: + self.added.extend(objs) + + +def _action(*, tool_name: str, action_id: int = 7): + return MagicMock( + id=action_id, + tool_name=tool_name, + thread_id=1, + search_space_id=2, + user_id="user-1", + reverse_descriptor=None, + ) + + +def _doc_revision( + *, + document_id: int | None = None, + content_before: str | None = "old content", + title_before: str | None = "notes.md", + folder_id_before: int | None = 5, + chunks_before: list[dict[str, str]] | None = None, + metadata_before: dict[str, str] | None = None, +): + revision = MagicMock() + revision.id = 100 + revision.document_id = document_id + revision.search_space_id = 2 + revision.content_before = content_before + revision.title_before = title_before + revision.folder_id_before = folder_id_before + revision.chunks_before = chunks_before or [] + revision.metadata_before = metadata_before + return revision + + +def _folder_revision( + *, + folder_id: int | None = None, + name_before: str | None = "team", + parent_id_before: int | None = None, + position_before: str | None = "a0", +): + revision = MagicMock() + revision.id = 200 + revision.folder_id = folder_id + revision.search_space_id = 2 + revision.name_before = name_before + revision.parent_id_before = parent_id_before + revision.position_before = position_before + return revision + + +# --------------------------------------------------------------------------- +# Exact-name dispatch regression guards +# --------------------------------------------------------------------------- + + +class TestExactDispatch: + """Regression: ``rmdir`` MUST NOT route to the document branch.""" + + @pytest.mark.asyncio + async def test_rmdir_does_not_misroute_to_document(self) -> None: + # If dispatch used `startswith("rm")` we'd hit the document branch + # here. With exact-name lookup `rmdir` lands in `_FOLDER_TOOLS`. + session = _FakeSession() + action = _action(tool_name="rmdir") + # No folder revisions exist for this action. + session.execute.return_value = _FakeResult(rows=[]) + outcome = await revert_service.revert_action( + session, # type: ignore[arg-type] + action=action, + requester_user_id="user-1", + ) + assert outcome.status == "not_reversible" + assert "folder_revisions" in outcome.message + + def test_dispatch_sets_split_doc_and_folder(self) -> None: + # Static guards on the dispatch tables themselves so a future + # refactor doesn't accidentally reintroduce the prefix bug. + assert "rm" in revert_service._DOC_TOOLS + assert "rmdir" in revert_service._FOLDER_TOOLS + assert "rmdir" not in revert_service._DOC_TOOLS + assert "rm" not in revert_service._FOLDER_TOOLS + # ``move_file`` lives only in document tools (it's a doc rename). + assert "move_file" in revert_service._DOC_TOOLS + assert "move_file" not in revert_service._FOLDER_TOOLS + + +# --------------------------------------------------------------------------- +# rm revert (re-INSERT) +# --------------------------------------------------------------------------- + + +class TestRmRevert: + @pytest.mark.asyncio + async def test_re_inserts_document_with_chunks(self) -> None: + session = _FakeSession() + revision = _doc_revision( + document_id=None, # row was hard-deleted + content_before="hello world", + title_before="x.md", + folder_id_before=None, + chunks_before=[{"content": "alpha"}, {"content": "beta"}], + metadata_before={"virtual_path": "/documents/x.md"}, + ) + # No collision check hit and the resulting query returns nothing. + session.execute.return_value = _FakeResult(scalar=None) + + outcome = await revert_service._reinsert_document_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + + assert outcome.status == "ok" + # New Document + 2 chunks must have been added. + from app.db import Chunk, Document + + added_docs = [obj for obj in session.added if isinstance(obj, Document)] + added_chunks = [obj for obj in session.added if isinstance(obj, Chunk)] + assert len(added_docs) == 1 + assert added_docs[0].title == "x.md" + assert len(added_chunks) == 2 + # Snapshot was repointed at the new doc id so a follow-up revert works. + assert revision.document_id == added_docs[0].id + + @pytest.mark.asyncio + async def test_falls_back_to_folder_id_and_title_for_virtual_path( + self, + ) -> None: + session = _FakeSession() + # Snapshot with NO metadata_before — the fallback path must kick in. + revision = _doc_revision( + document_id=None, + content_before="hello", + title_before="cap.md", + folder_id_before=42, + chunks_before=[], + metadata_before=None, + ) + # session.get(Folder, 42) returns a folder with a name. + folder = MagicMock() + folder.name = "team" + folder.parent_id = None + # First .get is for the folder lookup in the path-derivation. + session.get = AsyncMock(return_value=folder) + session.execute.return_value = _FakeResult(scalar=None) + + outcome = await revert_service._reinsert_document_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "ok" + + @pytest.mark.asyncio + async def test_falls_back_to_root_path_when_no_folder( + self, + ) -> None: + """metadata_before is None and folder_id_before is None still + resolves: title fallback yields ``/documents/`` so revert + proceeds at the root of the documents tree.""" + session = _FakeSession() + revision = _doc_revision( + document_id=None, + content_before="hello", + title_before="x.md", + folder_id_before=None, + metadata_before=None, + ) + # No collision in the documents tree at /documents/x.md. + session.execute.return_value = _FakeResult(scalar=None) + outcome = await revert_service._reinsert_document_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "ok" + + @pytest.mark.asyncio + async def test_collision_with_live_doc_returns_tool_unavailable(self) -> None: + session = _FakeSession() + revision = _doc_revision( + document_id=None, + content_before="hi", + title_before="x.md", + folder_id_before=None, + metadata_before={"virtual_path": "/documents/x.md"}, + ) + # SELECT for unique_identifier_hash collision hits an existing row. + session.execute.return_value = _FakeResult(scalar=42) + outcome = await revert_service._reinsert_document_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "tool_unavailable" + assert "collide" in outcome.message + + +# --------------------------------------------------------------------------- +# write_file create revert (DELETE) +# --------------------------------------------------------------------------- + + +class TestWriteFileCreateRevert: + @pytest.mark.asyncio + async def test_deletes_created_doc(self) -> None: + session = _FakeSession() + revision = _doc_revision( + document_id=99, + content_before=None, # marker for "created in this action" + title_before=None, + ) + outcome = await revert_service._delete_created_document( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "ok" + # Exactly one DELETE was issued. + assert session.execute.await_count == 1 + + +# --------------------------------------------------------------------------- +# rmdir revert (re-INSERT folder) +# --------------------------------------------------------------------------- + + +class TestRmdirRevert: + @pytest.mark.asyncio + async def test_re_inserts_folder_from_snapshot(self) -> None: + session = _FakeSession() + revision = _folder_revision( + folder_id=None, + name_before="team", + parent_id_before=None, + position_before="a0", + ) + outcome = await revert_service._reinsert_folder_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + from app.db import Folder + + assert outcome.status == "ok" + added_folders = [obj for obj in session.added if isinstance(obj, Folder)] + assert len(added_folders) == 1 + assert added_folders[0].name == "team" + assert revision.folder_id == added_folders[0].id + + +# --------------------------------------------------------------------------- +# mkdir revert (DELETE folder) +# --------------------------------------------------------------------------- + + +class TestMkdirRevert: + @pytest.mark.asyncio + async def test_deletes_empty_folder(self) -> None: + session = _FakeSession() + revision = _folder_revision(folder_id=42) + # Both the doc-existence check and the child-folder check return None. + session.execute.side_effect = [ + _FakeResult(scalar=None), # docs + _FakeResult(scalar=None), # children + _FakeResult(scalar=None), # delete (no return value) + ] + outcome = await revert_service._delete_created_folder( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "ok" + # 3 executes: docs check, children check, delete. + assert session.execute.await_count == 3 + + @pytest.mark.asyncio + async def test_reports_tool_unavailable_when_folder_has_children(self) -> None: + session = _FakeSession() + revision = _folder_revision(folder_id=42) + # First check (docs) returns "row found". + session.execute.return_value = _FakeResult(scalar=1) + outcome = await revert_service._delete_created_folder( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "tool_unavailable" + assert "no longer empty" in outcome.message diff --git a/surfsense_backend/tests/unit/tasks/__init__.py b/surfsense_backend/tests/unit/tasks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/tasks/chat/__init__.py b/surfsense_backend/tests/unit/tasks/chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py new file mode 100644 index 000000000..7f32bf456 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py @@ -0,0 +1,185 @@ +"""Unit tests for ``stream_new_chat._extract_chunk_parts``. + +Earlier versions only handled ``isinstance(chunk.content, str)`` and +silently dropped every other shape (Anthropic typed-block lists, +Bedrock reasoning blocks, ``additional_kwargs.reasoning_content`` from +a few providers). These regression tests pin those four shapes plus the +defensive cases (``None`` chunk, mixed types, missing fields). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import pytest + +from app.tasks.chat.stream_new_chat import _extract_chunk_parts + + +@dataclass +class _FakeChunk: + """Minimal stand-in for ``AIMessageChunk`` used in unit tests.""" + + content: Any = "" + additional_kwargs: dict[str, Any] = field(default_factory=dict) + tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) + + +class TestStringContent: + def test_plain_string_content_extracts_as_text(self) -> None: + chunk = _FakeChunk(content="hello world") + out = _extract_chunk_parts(chunk) + assert out["text"] == "hello world" + assert out["reasoning"] == "" + assert out["tool_call_chunks"] == [] + + def test_empty_string_content_yields_empty_text(self) -> None: + chunk = _FakeChunk(content="") + out = _extract_chunk_parts(chunk) + assert out["text"] == "" + assert out["reasoning"] == "" + assert out["tool_call_chunks"] == [] + + +class TestListContent: + def test_list_of_text_blocks_concatenates(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "text", "text": "Hello "}, + {"type": "text", "text": "world"}, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "Hello world" + assert out["reasoning"] == "" + + def test_mixed_text_and_reasoning_blocks(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "reasoning", "reasoning": "Let me think... "}, + {"type": "reasoning", "text": "still thinking."}, + {"type": "text", "text": "The answer is 42."}, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "The answer is 42." + assert out["reasoning"] == "Let me think... still thinking." + + def test_tool_call_chunks_in_content_list_extracted(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "text", "text": "Calling tool..."}, + { + "type": "tool_call_chunk", + "id": "call_123", + "name": "make_widget", + "args": '{"color":"red"}', + }, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "Calling tool..." + assert out["reasoning"] == "" + assert len(out["tool_call_chunks"]) == 1 + assert out["tool_call_chunks"][0]["id"] == "call_123" + assert out["tool_call_chunks"][0]["name"] == "make_widget" + + def test_tool_use_blocks_also_extracted(self) -> None: + """Some providers (Anthropic) emit ``type='tool_use'`` instead.""" + chunk = _FakeChunk( + content=[ + { + "type": "tool_use", + "id": "call_xyz", + "name": "search", + }, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["tool_call_chunks"] == [ + {"type": "tool_use", "id": "call_xyz", "name": "search"} + ] + + def test_unknown_block_types_are_ignored(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "image_url", "url": "https://example.com/x.png"}, + {"type": "text", "text": "ok"}, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "ok" + + def test_blocks_without_text_field_are_ignored(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "text"}, # no text/content key + {"type": "text", "text": "kept"}, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "kept" + + +class TestAdditionalKwargsReasoning: + def test_reasoning_content_in_additional_kwargs(self) -> None: + """Some providers stash reasoning in ``additional_kwargs.reasoning_content``.""" + chunk = _FakeChunk( + content="visible answer", + additional_kwargs={"reasoning_content": "internal monologue"}, + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "visible answer" + assert out["reasoning"] == "internal monologue" + + def test_reasoning_appended_to_typed_block_reasoning(self) -> None: + chunk = _FakeChunk( + content=[{"type": "reasoning", "text": "from blocks. "}], + additional_kwargs={"reasoning_content": "from kwargs."}, + ) + out = _extract_chunk_parts(chunk) + assert out["reasoning"] == "from blocks. from kwargs." + + +class TestToolCallChunksAttribute: + def test_tool_call_chunks_attribute_extracted_alongside_string_content( + self, + ) -> None: + chunk = _FakeChunk( + content="streaming text", + tool_call_chunks=[ + {"name": "save_document", "args": '{"title":"x"}', "id": "tc-9"} + ], + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "streaming text" + assert len(out["tool_call_chunks"]) == 1 + assert out["tool_call_chunks"][0]["id"] == "tc-9" + + def test_attribute_and_typed_block_chunks_both_collected(self) -> None: + chunk = _FakeChunk( + content=[ + { + "type": "tool_call_chunk", + "id": "from-block", + "name": "x", + } + ], + tool_call_chunks=[{"id": "from-attr", "name": "y"}], + ) + out = _extract_chunk_parts(chunk) + ids = [tcc.get("id") for tcc in out["tool_call_chunks"]] + assert ids == ["from-block", "from-attr"] + + +class TestDefensive: + @pytest.mark.parametrize( + "chunk_value", + [None, _FakeChunk(content=None), _FakeChunk(content=42)], + ) + def test_invalid_chunk_returns_empty_parts(self, chunk_value: Any) -> None: + out = _extract_chunk_parts(chunk_value) + assert out["text"] == "" + assert out["reasoning"] == "" + assert out["tool_call_chunks"] == [] diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 7773a438a..c2086e80a 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -14,6 +14,13 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { z } from "zod"; import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms"; +import { + agentActionsByChatTurnIdAtom, + markAgentActionRevertedAtom, + resetAgentActionMapAtom, + updateAgentActionReversibleAtom, + upsertAgentActionAtom, +} from "@/atoms/chat/agent-actions.atom"; import { clearTargetCommentIdAtom, currentThreadAtom, @@ -36,6 +43,11 @@ import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { membersAtom } from "@/atoms/members/members-query.atoms"; import { removeChatTabAtom, updateChatTabTitleAtom } from "@/atoms/tabs/tabs.atom"; import { currentUserAtom } from "@/atoms/user/user-query.atoms"; +import { + EditMessageDialog, + type EditMessageDialogChoice, +} from "@/components/assistant-ui/edit-message-dialog"; +import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator"; import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps"; import { Thread } from "@/components/assistant-ui/thread"; import { @@ -55,14 +67,19 @@ import { setActivePodcastTaskId, } from "@/lib/chat/podcast-state"; import { + addStepSeparator, addToolCall, + appendReasoning, appendText, buildContentForPersistence, buildContentForUI, type ContentPartsState, + endReasoning, FrameBatchedUpdater, + findToolCallIdByLcId, readSSEStream, type ThinkingStepData, + type ToolUIGate, updateThinkingSteps, updateToolCall, } from "@/lib/chat/streaming-state"; @@ -161,44 +178,38 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] { } /** - * Tools that should render custom UI in the chat. + * Every tool call renders a card. The legacy + * ``BASE_TOOLS_WITH_UI`` allowlist used to drop unknown tool calls on the + * floor; we now route everything through ``ToolFallback``. Persisted + * payload size stays bounded because the backend's + * ``format_thinking_step`` summarisation and the + * ``result_length``-only default for unknown tools (see + * ``stream_new_chat.py``) keep the JSON from ballooning. */ -const BASE_TOOLS_WITH_UI = new Set([ - "web_search", - "generate_podcast", - "generate_report", - "generate_resume", - "generate_video_presentation", - "display_image", - "generate_image", - "delete_notion_page", - "create_notion_page", - "update_notion_page", - "create_linear_issue", - "update_linear_issue", - "delete_linear_issue", - "create_google_drive_file", - "delete_google_drive_file", - "create_onedrive_file", - "delete_onedrive_file", - "create_dropbox_file", - "delete_dropbox_file", - "create_calendar_event", - "update_calendar_event", - "delete_calendar_event", - "create_gmail_draft", - "update_gmail_draft", - "send_gmail_email", - "trash_gmail_email", - "create_jira_issue", - "update_jira_issue", - "delete_jira_issue", - "create_confluence_page", - "update_confluence_page", - "delete_confluence_page", - "execute", - // "write_todos", // Disabled for now -]); +const TOOLS_WITH_UI_ALL: ToolUIGate = "all"; + +/** + * When a streamed message is persisted, the backend returns the durable + * ``turn_id`` (``configurable.turn_id`` from the agent run). Merge it + * into the assistant-ui message metadata so the per-turn "Revert turn" + * button can scope to this turn's actions even after a full chat reload. + */ +function mergeChatTurnIdIntoMessage( + msg: ThreadMessageLike, + turnId: string | null | undefined +): ThreadMessageLike { + if (!turnId) return msg; + const existingMeta = (msg.metadata ?? {}) as { custom?: Record<string, unknown> }; + const existingCustom = existingMeta.custom ?? {}; + if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg; + return { + ...msg, + metadata: { + ...existingMeta, + custom: { ...existingCustom, chatTurnId: turnId }, + }, + }; +} export default function NewChatPage() { const params = useParams(); @@ -215,7 +226,7 @@ export default function NewChatPage() { assistantMsgId: string; interruptData: Record<string, unknown>; } | null>(null); - const toolsWithUI = useMemo(() => new Set([...BASE_TOOLS_WITH_UI]), []); + const toolsWithUI = TOOLS_WITH_UI_ALL; // Get disabled tools from the tool toggle UI const disabledTools = useAtomValue(disabledToolsAtom); @@ -235,6 +246,25 @@ export default function NewChatPage() { const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom); const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom); const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom); + // Agent action log SSE side-channel. + const upsertAgentAction = useSetAtom(upsertAgentActionAtom); + const updateAgentActionReversible = useSetAtom(updateAgentActionReversibleAtom); + const markAgentActionReverted = useSetAtom(markAgentActionRevertedAtom); + const resetAgentActionMap = useSetAtom(resetAgentActionMapAtom); + // Chat-turn-keyed action map for the edit-from-position pre-flight + // that decides whether to show the confirmation dialog. + const agentActionsByChatTurnId = useAtomValue(agentActionsByChatTurnIdAtom); + // Edit dialog state. Holds the message id being edited and + // the (already extracted) regenerate args so we can resume the edit + // after the user picks "revert all" / "continue" / "cancel". + const [editDialogState, setEditDialogState] = useState<{ + fromMessageId: number; + userQuery: string | null; + userMessageContent: ThreadMessageLike["content"]; + userImages: NewChatUserImagePayload[]; + downstreamReversibleCount: number; + downstreamTotalCount: number; + } | null>(null); // Get current user for author info in shared chats const { data: currentUser } = useAtomValue(currentUserAtom); @@ -327,6 +357,7 @@ export default function NewChatPage() { clearPlanOwnerRegistry(); closeReportPanel(); closeEditorPanel(); + resetAgentActionMap(); try { if (urlChatId > 0) { @@ -395,6 +426,7 @@ export default function NewChatPage() { removeChatTab, searchSpaceId, tokenUsageStore, + resetAgentActionMap, ]); // Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same) @@ -655,11 +687,14 @@ export default function NewChatPage() { const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, + currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; let wasInterrupted = false; let tokenUsageData: Record<string, unknown> | null = null; + // Captured from ``data-turn-info`` at stream start. + let streamedChatTurnId: string | null = null; // Add placeholder assistant message setMessages((prev) => [ @@ -752,21 +787,52 @@ export default function NewChatPage() { scheduleFlush(); break; + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + break; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + break; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + break; + + case "finish-step": + break; + case "tool-input-start": - addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {}); + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); batcher.flush(); break; case "tool-input-available": { if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} }); + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + langchainToolCallId: parsed.langchainToolCallId, + }); } else { addToolCall( contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, - parsed.input || {} + parsed.input || {}, + false, + parsed.langchainToolCallId ); } batcher.flush(); @@ -774,7 +840,10 @@ export default function NewChatPage() { } case "tool-output-available": { - updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output }); + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); markInterruptsCompleted(contentParts); if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { const idx = toolCallIndices.get(parsed.toolCallId); @@ -880,6 +949,50 @@ export default function NewChatPage() { break; } + case "data-action-log": { + const al = parsed.data; + const matchedToolCallId = al.lc_tool_call_id + ? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) + : null; + upsertAgentAction({ + action: { + id: al.id, + threadId: currentThreadId, + lcToolCallId: al.lc_tool_call_id, + chatTurnId: al.chat_turn_id, + toolName: al.tool_name, + reversible: al.reversible, + reverseDescriptorPresent: al.reverse_descriptor_present, + error: al.error, + revertedByActionId: null, + isRevertAction: false, + createdAt: al.created_at, + }, + toolCallId: matchedToolCallId, + }); + break; + } + + case "data-action-log-updated": { + updateAgentActionReversible({ + id: parsed.data.id, + reversible: parsed.data.reversible, + }); + break; + } + + case "data-turn-info": { + streamedChatTurnId = parsed.data.chat_turn_id || null; + if (streamedChatTurnId) { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m + ) + ); + } + break; + } + case "data-token-usage": tokenUsageData = parsed.data; tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); @@ -900,13 +1013,18 @@ export default function NewChatPage() { role: "assistant", content: finalContent, token_usage: tokenUsageData ?? undefined, + turn_id: streamedChatTurnId, }); // Update message ID from temporary to database ID so comments work immediately const newMsgId = `msg-${savedMessage.id}`; tokenUsageStore.rename(assistantMsgId, newMsgId); setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + prev.map((m) => + m.id === assistantMsgId + ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) + : m + ) ); // Update pending interrupt with the new persisted message ID @@ -929,7 +1047,9 @@ export default function NewChatPage() { const hasContent = contentParts.some( (part) => (part.type === "text" && part.text.length > 0) || - (part.type === "tool-call" && toolsWithUI.has(part.toolName)) + (part.type === "reasoning" && part.text.length > 0) || + (part.type === "tool-call" && + (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) ); if (hasContent && currentThreadId) { const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); @@ -937,12 +1057,17 @@ export default function NewChatPage() { const savedMessage = await appendMessage(currentThreadId, { role: "assistant", content: partialContent, + turn_id: streamedChatTurnId, }); // Update message ID from temporary to database ID const newMsgId = `msg-${savedMessage.id}`; setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + prev.map((m) => + m.id === assistantMsgId + ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) + : m + ) ); } catch (err) { console.error("Failed to persist partial assistant message:", err); @@ -1030,10 +1155,13 @@ export default function NewChatPage() { const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, + currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; let tokenUsageData: Record<string, unknown> | null = null; + // Captured from ``data-turn-info`` at stream start. + let streamedChatTurnId: string | null = null; const existingMsg = messages.find((m) => m.id === assistantMsgId); if (existingMsg && Array.isArray(existingMsg.content)) { @@ -1136,8 +1264,34 @@ export default function NewChatPage() { scheduleFlush(); break; + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + break; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + break; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + break; + + case "finish-step": + break; + case "tool-input-start": - addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {}); + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); batcher.flush(); break; @@ -1145,6 +1299,7 @@ export default function NewChatPage() { if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + langchainToolCallId: parsed.langchainToolCallId, }); } else { addToolCall( @@ -1152,7 +1307,9 @@ export default function NewChatPage() { toolsWithUI, parsed.toolCallId, parsed.toolName, - parsed.input || {} + parsed.input || {}, + false, + parsed.langchainToolCallId ); } batcher.flush(); @@ -1161,6 +1318,7 @@ export default function NewChatPage() { case "tool-output-available": updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, }); markInterruptsCompleted(contentParts); batcher.flush(); @@ -1222,6 +1380,50 @@ export default function NewChatPage() { break; } + case "data-action-log": { + const al = parsed.data; + const matchedToolCallId = al.lc_tool_call_id + ? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) + : null; + upsertAgentAction({ + action: { + id: al.id, + threadId: resumeThreadId, + lcToolCallId: al.lc_tool_call_id, + chatTurnId: al.chat_turn_id, + toolName: al.tool_name, + reversible: al.reversible, + reverseDescriptorPresent: al.reverse_descriptor_present, + error: al.error, + revertedByActionId: null, + isRevertAction: false, + createdAt: al.created_at, + }, + toolCallId: matchedToolCallId, + }); + break; + } + + case "data-action-log-updated": { + updateAgentActionReversible({ + id: parsed.data.id, + reversible: parsed.data.reversible, + }); + break; + } + + case "data-turn-info": { + streamedChatTurnId = parsed.data.chat_turn_id || null; + if (streamedChatTurnId) { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m + ) + ); + } + break; + } + case "data-token-usage": tokenUsageData = parsed.data; tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); @@ -1241,11 +1443,16 @@ export default function NewChatPage() { role: "assistant", content: finalContent, token_usage: tokenUsageData ?? undefined, + turn_id: streamedChatTurnId, }); const newMsgId = `msg-${savedMessage.id}`; tokenUsageStore.rename(assistantMsgId, newMsgId); setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + prev.map((m) => + m.id === assistantMsgId + ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) + : m + ) ); } catch (err) { console.error("Failed to persist resumed assistant message:", err); @@ -1340,6 +1547,12 @@ export default function NewChatPage() { editExtras?: { userMessageContent: ThreadMessageLike["content"]; userImages: NewChatUserImagePayload[]; + }, + editFromPosition?: { + /** Message id (numeric, parsed from ``msg-<n>``) to rewind to. */ + fromMessageId?: number | null; + /** When true, revert reversible downstream actions before stream. */ + revertActions?: boolean; } ) => { if (!threadId) { @@ -1384,9 +1597,20 @@ export default function NewChatPage() { userQueryToDisplay = newUserQuery; } - // Remove the last two messages (user + assistant) from the UI immediately - // The backend will also delete them from the database + // Remove downstream messages from the UI immediately. The + // backend will also delete them from the database. + // + // When an explicit ``fromMessageId`` is passed, slice from + // that message forward; otherwise fall back to the legacy + // "drop the last 2" behaviour. setMessages((prev) => { + if (editFromPosition?.fromMessageId != null) { + const targetId = `msg-${editFromPosition.fromMessageId}`; + const sliceIndex = prev.findIndex((m) => m.id === targetId); + if (sliceIndex >= 0) { + return prev.slice(0, sliceIndex); + } + } if (prev.length >= 2) { return prev.slice(0, -2); } @@ -1406,11 +1630,16 @@ export default function NewChatPage() { const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, + currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; const batcher = new FrameBatchedUpdater(); let tokenUsageData: Record<string, unknown> | null = null; + // Captured from ``data-turn-info`` at stream start; stamped + // onto persisted messages so future edits can locate the + // right LangGraph checkpoint. + let streamedChatTurnId: string | null = null; // Add placeholder messages to UI // Always add back the user message (with new query for edit, or original content for reload) @@ -1449,6 +1678,16 @@ export default function NewChatPage() { if (isEdit) { requestBody.user_images = editExtras?.userImages ?? []; } + // Explicit edit-from-arbitrary-position. Only send + // ``from_message_id`` / ``revert_actions`` when the + // caller asked for them; otherwise the backend keeps the + // legacy "last 2 messages" behaviour for back-compat. + if (editFromPosition?.fromMessageId != null) { + requestBody.from_message_id = editFromPosition.fromMessageId; + if (editFromPosition.revertActions) { + requestBody.revert_actions = true; + } + } const response = await fetch(getRegenerateUrl(threadId), { method: "POST", headers: { @@ -1481,28 +1720,62 @@ export default function NewChatPage() { scheduleFlush(); break; + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + break; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + break; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + break; + + case "finish-step": + break; + case "tool-input-start": - addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {}); + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); batcher.flush(); break; case "tool-input-available": if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} }); + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + langchainToolCallId: parsed.langchainToolCallId, + }); } else { addToolCall( contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, - parsed.input || {} + parsed.input || {}, + false, + parsed.langchainToolCallId ); } batcher.flush(); break; case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output }); + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); markInterruptsCompleted(contentParts); if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { const idx = toolCallIndices.get(parsed.toolCallId); @@ -1528,6 +1801,82 @@ export default function NewChatPage() { break; } + case "data-action-log": { + const al = parsed.data; + const matchedToolCallId = al.lc_tool_call_id + ? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) + : null; + upsertAgentAction({ + action: { + id: al.id, + threadId, + lcToolCallId: al.lc_tool_call_id, + chatTurnId: al.chat_turn_id, + toolName: al.tool_name, + reversible: al.reversible, + reverseDescriptorPresent: al.reverse_descriptor_present, + error: al.error, + revertedByActionId: null, + isRevertAction: false, + createdAt: al.created_at, + }, + toolCallId: matchedToolCallId, + }); + break; + } + + case "data-action-log-updated": { + updateAgentActionReversible({ + id: parsed.data.id, + reversible: parsed.data.reversible, + }); + break; + } + + case "data-turn-info": { + streamedChatTurnId = parsed.data.chat_turn_id || null; + if (streamedChatTurnId) { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m + ) + ); + } + break; + } + + case "data-revert-results": { + const summary = parsed.data; + // failureCount must include every "not undone" bucket + // (not_reversible, permission_denied, failed) so the + // toast's "X could not be rolled back" math matches + // the response invariant ``total === sum(counters)``. + // ``skipped`` rows are batch revert artefacts (revert + // rows themselves) and are not user-facing failures. + const failureCount = + summary.failed + summary.not_reversible + (summary.permission_denied ?? 0); + if (failureCount > 0) { + toast.warning( + `Pre-revert: ${summary.reverted}/${summary.total} undone, ${failureCount} could not be rolled back.` + ); + } else if (summary.reverted > 0) { + toast.success( + summary.reverted === 1 + ? "Reverted 1 downstream action before regenerating." + : `Reverted ${summary.reverted} downstream actions before regenerating.` + ); + } + for (const r of summary.results) { + if (r.status === "reverted" || r.status === "already_reverted") { + markAgentActionReverted({ + id: r.action_id, + newActionId: r.new_action_id ?? null, + }); + } + } + break; + } + case "data-token-usage": tokenUsageData = parsed.data; tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); @@ -1552,12 +1901,17 @@ export default function NewChatPage() { const savedUserMessage = await appendMessage(threadId, { role: "user", content: userContentToPersist, + turn_id: streamedChatTurnId, }); // Update user message ID to database ID const newUserMsgId = `msg-${savedUserMessage.id}`; setMessages((prev) => - prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m)) + prev.map((m) => + m.id === userMsgId + ? mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, savedUserMessage.turn_id) + : m + ) ); // Persist assistant message @@ -1565,12 +1919,17 @@ export default function NewChatPage() { role: "assistant", content: finalContent, token_usage: tokenUsageData ?? undefined, + turn_id: streamedChatTurnId, }); const newMsgId = `msg-${savedMessage.id}`; tokenUsageStore.rename(assistantMsgId, newMsgId); setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + prev.map((m) => + m.id === assistantMsgId + ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) + : m + ) ); trackChatResponseReceived(searchSpaceId, threadId); @@ -1608,7 +1967,14 @@ export default function NewChatPage() { [threadId, searchSpaceId, messages, disabledTools, tokenUsageStore, toolsWithUI] ); - // Handle editing a message - truncates history and regenerates with new query + // Handle editing a message - truncates history and regenerates with new query. + // + // When ``message.sourceId`` is set (the assistant-ui way to say + // "this edit replaces an older message"), we pin + // ``from_message_id`` so the backend rewinds to the right LangGraph + // checkpoint instead of relying on the legacy "last 2 messages" + // rewind. We also count downstream reversible actions and prompt the + // user to revert / continue / cancel before regenerating. const onEdit = useCallback( async (message: AppendMessage) => { const { userQuery, userImages } = extractUserTurnForNewChatApi(message, []); @@ -1619,9 +1985,95 @@ export default function NewChatPage() { } const userMessageContent = message.content as unknown as ThreadMessageLike["content"]; - await handleRegenerate(queryForApi, { userMessageContent, userImages }); + + // ``sourceId`` per @assistant-ui/core's ``AppendMessage`` is + // "the ID of the message that was edited". Parse the numeric + // suffix so we can map it back to a DB row. + const sourceId = (message as { sourceId?: string }).sourceId; + const fromMessageId = + sourceId && /^msg-\d+$/.test(sourceId) + ? Number.parseInt(sourceId.replace(/^msg-/, ""), 10) + : null; + + if (fromMessageId == null) { + // No source id (or non-DB id) — fall back to today's + // last-2 behaviour. The user gets the legacy edit flow. + await handleRegenerate(queryForApi, { userMessageContent, userImages }); + return; + } + + // Pre-flight: count reversible downstream actions so we can + // auto-skip the dialog for harmless edits. + // + // "Downstream" means messages AFTER the edited one. The + // previous slice ``messages.slice(editedIndex)`` included + // the edited message itself in both the total + // count and the reversibility scan (any actions on the + // edited turn would be double-counted). Slice from + // ``editedIndex + 1`` so the dialog text matches reality: + // "N downstream messages will be dropped". + const editedIndex = messages.findIndex((m) => m.id === `msg-${fromMessageId}`); + let downstreamReversibleCount = 0; + let downstreamTotalCount = 0; + if (editedIndex >= 0) { + const downstream = messages.slice(editedIndex + 1); + downstreamTotalCount = downstream.length; + const seenTurns = new Set<string>(); + for (const m of downstream) { + const meta = (m.metadata ?? {}) as { custom?: { chatTurnId?: string } }; + const tid = meta.custom?.chatTurnId; + if (!tid || seenTurns.has(tid)) continue; + seenTurns.add(tid); + const turnActions = agentActionsByChatTurnId.get(tid) ?? []; + for (const a of turnActions) { + if (a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error) { + downstreamReversibleCount += 1; + } + } + } + } + + if (downstreamReversibleCount === 0) { + // Nothing to revert — submit silently. + await handleRegenerate( + queryForApi, + { userMessageContent, userImages }, + { fromMessageId, revertActions: false } + ); + return; + } + + setEditDialogState({ + fromMessageId, + userQuery: queryForApi, + userMessageContent, + userImages, + downstreamReversibleCount, + downstreamTotalCount, + }); }, - [handleRegenerate] + [handleRegenerate, messages, agentActionsByChatTurnId] + ); + + const handleEditDialogChoice = useCallback( + async (choice: EditMessageDialogChoice) => { + const pending = editDialogState; + if (!pending) return; + setEditDialogState(null); + if (choice === "cancel") return; + await handleRegenerate( + pending.userQuery, + { + userMessageContent: pending.userMessageContent, + userImages: pending.userImages, + }, + { + fromMessageId: pending.fromMessageId, + revertActions: choice === "revert", + } + ); + }, + [editDialogState, handleRegenerate] ); // Handle reloading/refreshing the last AI response @@ -1671,6 +2123,7 @@ export default function NewChatPage() { <TokenUsageProvider store={tokenUsageStore}> <AssistantRuntimeProvider runtime={runtime}> <ThinkingStepsDataUI /> + <StepSeparatorDataUI /> <div key={searchSpaceId} className="flex h-full overflow-hidden"> <div className="flex-1 flex flex-col min-w-0 overflow-hidden"> <Thread /> @@ -1679,6 +2132,15 @@ export default function NewChatPage() { <MobileEditorPanel /> <MobileHitlEditPanel /> </div> + <EditMessageDialog + open={editDialogState !== null} + onOpenChange={(open) => { + if (!open) setEditDialogState(null); + }} + downstreamReversibleCount={editDialogState?.downstreamReversibleCount ?? 0} + downstreamTotalCount={editDialogState?.downstreamTotalCount ?? 0} + onChoose={handleEditDialogChoice} + /> </AssistantRuntimeProvider> </TokenUsageProvider> ); diff --git a/surfsense_web/atoms/chat/agent-actions.atom.ts b/surfsense_web/atoms/chat/agent-actions.atom.ts new file mode 100644 index 000000000..7830c8751 --- /dev/null +++ b/surfsense_web/atoms/chat/agent-actions.atom.ts @@ -0,0 +1,194 @@ +"use client"; + +import { atom } from "jotai"; + +/** + * Minimal per-row projection of ``AgentActionLog`` that the tool card + * needs to decide whether to render a Revert button. + * + * Fields are deliberately a subset of the full ``AgentAction`` so the + * SSE side-channel (``data-action-log`` / ``data-action-log-updated``) + * can populate them without depending on the REST endpoint + * ``GET /threads/.../actions`` (which 503s when + * ``SURFSENSE_ENABLE_ACTION_LOG`` is off). + */ +export interface AgentActionLite { + id: number; + threadId: number | null; + lcToolCallId: string | null; + chatTurnId: string | null; + toolName: string; + reversible: boolean; + reverseDescriptorPresent: boolean; + error: boolean; + revertedByActionId: number | null; + isRevertAction: boolean; + createdAt: string | null; +} + +/** + * Map keyed off the LangChain ``tool_call.id`` (mirrors ``ContentPart + * tool-call.langchainToolCallId``). + */ +export const agentActionByLcIdAtom = atom<Map<string, AgentActionLite>>(new Map()); + +/** + * Parallel map keyed off the synthetic chat-card ``toolCallId`` + * (``call_<run-id>``) so ``ToolFallback`` (which only receives the + * synthetic id from assistant-ui) can join its card to the action log. + * + * Both maps are kept in sync by ``upsertAgentActionAtom``. + */ +export const agentActionByToolCallIdAtom = atom<Map<string, AgentActionLite>>(new Map()); + +/** + * Index keyed by ``chat_turn_id`` so the per-turn revert UI can answer + * "how many reversible actions does this assistant turn contain?" in + * O(1). Each entry's array is ordered by insertion (which + * for a single turn matches ``created_at`` because action-log writes + * happen synchronously). + */ +export const agentActionsByChatTurnIdAtom = atom<Map<string, AgentActionLite[]>>(new Map()); + +/** + * Action to upsert one ``AgentActionLite`` row. + * + * ``toolCallId`` is the synthetic card id (``call_<run-id>`` from + * ``stream_new_chat.py``). When provided alongside ``lcToolCallId``, the + * action is indexed under BOTH ids so the tool card can perform the + * lookup without going via the streaming state. + */ +export const upsertAgentActionAtom = atom( + null, + (_get, set, payload: { action: AgentActionLite; toolCallId?: string | null }) => { + const { action, toolCallId } = payload; + const upsertInto = ( + prev: Map<string, AgentActionLite>, + key: string + ): Map<string, AgentActionLite> => { + const next = new Map(prev); + const existing = next.get(key); + next.set(key, { + ...action, + // Preserve the local "reverted" bookkeeping if a reversibility + // flip arrives AFTER the user already reverted via the REST + // route. We never want a stale ``reversible=true`` event to + // resurrect a Reverted card. + revertedByActionId: existing?.revertedByActionId ?? action.revertedByActionId, + isRevertAction: existing?.isRevertAction ?? action.isRevertAction, + }); + return next; + }; + if (action.lcToolCallId) { + set(agentActionByLcIdAtom, (prev) => upsertInto(prev, action.lcToolCallId as string)); + } + if (toolCallId) { + set(agentActionByToolCallIdAtom, (prev) => upsertInto(prev, toolCallId)); + } + if (action.chatTurnId) { + set(agentActionsByChatTurnIdAtom, (prev) => { + const next = new Map(prev); + const turnId = action.chatTurnId as string; + const existing = next.get(turnId) ?? []; + const priorEntry = existing.find((row) => row.id === action.id); + const merged: AgentActionLite = { + ...action, + revertedByActionId: priorEntry?.revertedByActionId ?? action.revertedByActionId, + isRevertAction: priorEntry?.isRevertAction ?? action.isRevertAction, + }; + const others = existing.filter((row) => row.id !== action.id); + next.set(turnId, [...others, merged]); + return next; + }); + } + } +); + +function mutateById( + prev: Map<string, AgentActionLite>, + id: number, + mutator: (entry: AgentActionLite) => AgentActionLite +): Map<string, AgentActionLite> { + let mutated = false; + const next = new Map(prev); + for (const [key, value] of next) { + if (value.id === id) { + next.set(key, mutator(value)); + mutated = true; + } + } + return mutated ? next : prev; +} + +function mutateByIdInTurnIndex( + prev: Map<string, AgentActionLite[]>, + id: number, + mutator: (entry: AgentActionLite) => AgentActionLite +): Map<string, AgentActionLite[]> { + let mutated = false; + const next = new Map(prev); + for (const [key, list] of next) { + let listMutated = false; + const updated = list.map((row) => { + if (row.id === id) { + listMutated = true; + return mutator(row); + } + return row; + }); + if (listMutated) { + next.set(key, updated); + mutated = true; + } + } + return mutated ? next : prev; +} + +/** + * Action to flip an existing entry's ``reversible`` flag, keyed by the + * AgentActionLog row id (the SSE ``data-action-log-updated`` payload + * does NOT carry ``lcToolCallId``). + */ +export const updateAgentActionReversibleAtom = atom( + null, + (_get, set, payload: { id: number; reversible: boolean }) => { + const apply = (entry: AgentActionLite): AgentActionLite => ({ + ...entry, + reversible: payload.reversible, + }); + set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, apply)); + set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, apply)); + set(agentActionsByChatTurnIdAtom, (prev) => mutateByIdInTurnIndex(prev, payload.id, apply)); + } +); + +/** Action to mark an existing entry as reverted (post-revert call). */ +export const markAgentActionRevertedAtom = atom( + null, + (_get, set, payload: { id: number; newActionId: number | null }) => { + const apply = (entry: AgentActionLite): AgentActionLite => ({ + ...entry, + revertedByActionId: payload.newActionId ?? -1, + }); + set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, apply)); + set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, apply)); + set(agentActionsByChatTurnIdAtom, (prev) => mutateByIdInTurnIndex(prev, payload.id, apply)); + } +); + +/** Mark every action in a turn as reverted, given a list of (id, newActionId) pairs. */ +export const markAgentActionsRevertedBatchAtom = atom( + null, + (_get, set, payload: { entries: Array<{ id: number; newActionId: number | null }> }) => { + for (const entry of payload.entries) { + set(markAgentActionRevertedAtom, entry); + } + } +); + +/** Reset all maps (e.g. when the active thread changes). */ +export const resetAgentActionMapAtom = atom(null, (_get, set) => { + set(agentActionByLcIdAtom, new Map()); + set(agentActionByToolCallIdAtom, new Map()); + set(agentActionsByChatTurnIdAtom, new Map()); +}); diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index 6b9c2c87e..bfe0434b4 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -33,6 +33,8 @@ import { useAllCitationMetadata, } from "@/components/assistant-ui/citation-metadata-context"; import { MarkdownText } from "@/components/assistant-ui/markdown-text"; +import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part"; +import { RevertTurnButton } from "@/components/assistant-ui/revert-turn-button"; import { useTokenUsage } from "@/components/assistant-ui/token-usage-context"; import { ToolFallback } from "@/components/assistant-ui/tool-fallback"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; @@ -491,6 +493,7 @@ const AssistantMessageInner: FC = () => { <MessagePrimitive.Parts components={{ Text: MarkdownText, + Reasoning: ReasoningMessagePart, tools: { by_name: { generate_report: GenerateReportToolUI, @@ -699,6 +702,13 @@ const AssistantActionBar: FC = () => { const isLast = useAuiState((s) => s.message.isLast); const aui = useAui(); const api = useElectronAPI(); + // Surface the persisted ``chat_turn_id`` so the per-turn revert + // affordance can scope to just this message's actions. Streamed + // turns get their id once the assistant message is hydrated/finalised. + const chatTurnId = useAuiState(({ message }) => { + const meta = message?.metadata as { custom?: { chatTurnId?: string | null } } | undefined; + return meta?.custom?.chatTurnId ?? null; + }); const isQuickAssist = !!api?.replaceText && IS_QUICK_ASSIST_WINDOW; @@ -743,6 +753,9 @@ const AssistantActionBar: FC = () => { </TooltipIconButton> )} <MessageInfoDropdown /> + <div className="ml-auto"> + <RevertTurnButton chatTurnId={chatTurnId} /> + </div> </ActionBarPrimitive.Root> ); }; diff --git a/surfsense_web/components/assistant-ui/edit-message-dialog.tsx b/surfsense_web/components/assistant-ui/edit-message-dialog.tsx new file mode 100644 index 000000000..807f16fe7 --- /dev/null +++ b/surfsense_web/components/assistant-ui/edit-message-dialog.tsx @@ -0,0 +1,106 @@ +"use client"; + +/** + * Confirmation dialog shown when the user edits a message that has + * reversible downstream actions. Three buttons: + * + * • "Revert all & resubmit" — POST regenerate with revert_actions=true + * • "Continue without revert" — POST regenerate with revert_actions=false + * • "Cancel" — abort the edit entirely + * + * The dialog is auto-skipped when zero reversible downstream actions + * exist (the caller checks first via ``downstreamReversibleCount``). + */ + +import { useEffect, useRef, useState } from "react"; +import { + AlertDialog, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; + +export type EditMessageDialogChoice = "revert" | "continue" | "cancel"; + +export interface EditMessageDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + downstreamReversibleCount: number; + downstreamTotalCount: number; + onChoose: (choice: EditMessageDialogChoice) => void | Promise<void>; +} + +export function EditMessageDialog({ + open, + onOpenChange, + downstreamReversibleCount, + downstreamTotalCount, + onChoose, +}: EditMessageDialogProps) { + const [busy, setBusy] = useState<EditMessageDialogChoice | null>(null); + + // The parent's ``handleEditDialogChoice`` calls + // ``setEditDialogState(null)`` BEFORE awaiting ``handleRegenerate``. + // That collapses the dialog (Radix unmounts it) while ``onChoose`` + // is still awaiting the long-running stream. Without this guard, + // the ``finally { setBusy(null) }`` below ran after unmount and + // produced a "state update on unmounted component" dev warning. + const mountedRef = useRef(true); + useEffect(() => { + mountedRef.current = true; + return () => { + mountedRef.current = false; + }; + }, []); + + const handle = async (choice: EditMessageDialogChoice) => { + setBusy(choice); + try { + await onChoose(choice); + } finally { + if (mountedRef.current) { + setBusy(null); + } + } + }; + + return ( + <AlertDialog open={open} onOpenChange={onOpenChange}> + <AlertDialogContent> + <AlertDialogHeader> + <AlertDialogTitle>Edit this message?</AlertDialogTitle> + <AlertDialogDescription> + This edit drops {downstreamTotalCount} downstream message + {downstreamTotalCount === 1 ? "" : "s"} from the thread. {downstreamReversibleCount}{" "} + action + {downstreamReversibleCount === 1 ? "" : "s"} (e.g. file writes, connector changes) can + be rolled back. Pick how to handle them before regenerating. + </AlertDialogDescription> + </AlertDialogHeader> + + <div className="grid gap-2"> + <Button variant="default" disabled={busy !== null} onClick={() => handle("revert")}> + {busy === "revert" + ? "Reverting & resubmitting…" + : `Revert ${downstreamReversibleCount} action${ + downstreamReversibleCount === 1 ? "" : "s" + } & resubmit`} + </Button> + <Button variant="outline" disabled={busy !== null} onClick={() => handle("continue")}> + {busy === "continue" ? "Resubmitting…" : "Continue without reverting"} + </Button> + </div> + + <AlertDialogFooter className="sm:justify-start"> + <AlertDialogCancel disabled={busy !== null} onClick={() => handle("cancel")}> + Cancel + </AlertDialogCancel> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + ); +} diff --git a/surfsense_web/components/assistant-ui/reasoning-message-part.tsx b/surfsense_web/components/assistant-ui/reasoning-message-part.tsx new file mode 100644 index 000000000..70636eab8 --- /dev/null +++ b/surfsense_web/components/assistant-ui/reasoning-message-part.tsx @@ -0,0 +1,81 @@ +"use client"; + +import type { ReasoningMessagePartComponent } from "@assistant-ui/react"; +import { ChevronRightIcon } from "lucide-react"; +import { useEffect, useMemo, useState } from "react"; +import { TextShimmerLoader } from "@/components/prompt-kit/loader"; +import { cn } from "@/lib/utils"; + +/** + * Renders the structured `reasoning` part emitted by the backend's + * stream-parity v2 path (A1). + * + * Behaviour mirrors the existing `ThinkingStepsDisplay`: + * - collapsed by default; + * - auto-expanded while the part is still `running`; + * - auto-collapsed once status flips to `complete`. + * + * The component is registered via the `Reasoning` slot on + * `MessagePrimitive.Parts` in `assistant-message.tsx` so it lives at the + * exact ordinal position of the reasoning block in the message content + * array (i.e. above the assistant text that follows it). + */ +export const ReasoningMessagePart: ReasoningMessagePartComponent = ({ text, status }) => { + const isRunning = status?.type === "running"; + const [isOpen, setIsOpen] = useState(() => isRunning); + + useEffect(() => { + if (isRunning) { + setIsOpen(true); + } else if (status?.type === "complete") { + setIsOpen(false); + } + }, [isRunning, status?.type]); + + const headerLabel = useMemo(() => { + if (isRunning) return "Thinking"; + if (status?.type === "incomplete") return "Thinking interrupted"; + return "Thought"; + }, [isRunning, status?.type]); + + if (!text || text.length === 0) { + if (!isRunning) return null; + } + + return ( + <div className="mx-auto w-full max-w-(--thread-max-width) px-2 py-2"> + <div className="rounded-lg"> + <button + type="button" + onClick={() => setIsOpen((prev) => !prev)} + className={cn( + "flex w-full items-center gap-1.5 text-left text-sm transition-colors", + "text-muted-foreground hover:text-foreground" + )} + > + {isRunning ? ( + <TextShimmerLoader text={headerLabel} size="sm" /> + ) : ( + <span>{headerLabel}</span> + )} + <ChevronRightIcon + className={cn("size-4 transition-transform duration-200", isOpen && "rotate-90")} + /> + </button> + + <div + className={cn( + "grid transition-[grid-template-rows] duration-300 ease-out", + isOpen ? "grid-rows-[1fr]" : "grid-rows-[0fr]" + )} + > + <div className="overflow-hidden"> + <div className="mt-2 border-l border-muted-foreground/30 pl-3 text-sm leading-relaxed text-muted-foreground whitespace-pre-wrap wrap-break-word"> + {text} + </div> + </div> + </div> + </div> + </div> + ); +}; diff --git a/surfsense_web/components/assistant-ui/revert-turn-button.tsx b/surfsense_web/components/assistant-ui/revert-turn-button.tsx new file mode 100644 index 000000000..9c349738f --- /dev/null +++ b/surfsense_web/components/assistant-ui/revert-turn-button.tsx @@ -0,0 +1,232 @@ +"use client"; + +/** + * "Revert turn" button rendered at the bottom of every completed + * assistant turn that has at least one reversible action. + * + * The button reads the action map keyed by ``chat_turn_id`` from the + * SSE side-channel (``data-action-log`` events). It shows a confirmation + * dialog summarising "N reversible / M total" and, on confirm, calls + * ``POST /threads/{id}/revert-turn/{chat_turn_id}``. + * + * The route returns a per-action result list and never collapses the + * batch into a 4xx — so we render any failed/not_reversible rows inline + * with their messages. + */ + +import { useAtomValue, useSetAtom } from "jotai"; +import { selectAtom } from "jotai/utils"; +import { CheckIcon, RotateCcw, XCircleIcon } from "lucide-react"; +import { useMemo, useState } from "react"; +import { toast } from "sonner"; +import { + type AgentActionLite, + agentActionsByChatTurnIdAtom, + markAgentActionsRevertedBatchAtom, +} from "@/atoms/chat/agent-actions.atom"; +import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; +import { + agentActionsApiService, + type RevertTurnActionResult, +} from "@/lib/apis/agent-actions-api.service"; +import { AppError } from "@/lib/error"; +import { cn } from "@/lib/utils"; + +interface RevertTurnButtonProps { + chatTurnId: string | null | undefined; +} + +function formatToolName(name: string): string { + return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); +} + +// Empty-array sentinel so the per-turn ``selectAtom`` slice returns a +// stable reference when the turn has no recorded actions yet. Without +// this every render allocates a fresh ``[]`` and Jotai's +// equality check would re-render the button on unrelated turn updates. +const EMPTY_ACTIONS: readonly AgentActionLite[] = Object.freeze([]); + +export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) { + const session = useAtomValue(chatSessionStateAtom); + const markRevertedBatch = useSetAtom(markAgentActionsRevertedBatchAtom); + const [isReverting, setIsReverting] = useState(false); + const [confirmOpen, setConfirmOpen] = useState(false); + const [resultsOpen, setResultsOpen] = useState(false); + const [results, setResults] = useState<RevertTurnActionResult[]>([]); + + // Subscribe ONLY to the slice of the global action map that belongs + // to ``chatTurnId``. Previously the button read the whole + // ``agentActionsByChatTurnIdAtom``, which meant every action + // upsert (one per tool call) re-rendered every Revert button on + // the page. With ``selectAtom`` we re-render only when our turn's + // list reference changes — and the upsert/mark atoms produce a + // fresh list reference for the affected turn only. + const sliceAtom = useMemo( + () => + selectAtom( + agentActionsByChatTurnIdAtom, + (turnIndex) => (chatTurnId ? turnIndex.get(chatTurnId) : undefined) ?? EMPTY_ACTIONS + ), + [chatTurnId] + ); + const actions = useAtomValue(sliceAtom); + + const reversibleCount = useMemo( + () => + actions.filter( + (a) => a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error + ).length, + [actions] + ); + const totalCount = useMemo(() => actions.filter((a) => !a.isRevertAction).length, [actions]); + + if (!chatTurnId) return null; + if (reversibleCount === 0) return null; + const threadId = session?.threadId; + if (!threadId) return null; + + const handleRevertTurn = async () => { + setIsReverting(true); + try { + const response = await agentActionsApiService.revertTurn(threadId, chatTurnId); + setResults(response.results); + const revertedEntries = response.results + .filter((r) => r.status === "reverted" || r.status === "already_reverted") + .map((r) => ({ id: r.action_id, newActionId: r.new_action_id ?? null })); + if (revertedEntries.length > 0) { + markRevertedBatch({ entries: revertedEntries }); + } + if (response.status === "ok") { + toast.success( + response.reverted === 1 ? "Reverted 1 action." : `Reverted ${response.reverted} actions.` + ); + } else { + // Every "not undone" bucket counts as a failure for the + // user-facing summary. ``skipped`` rows are batch + // artefacts (revert rows themselves) and intentionally + // excluded from the failure tally. + const failureCount = + response.failed + response.not_reversible + (response.permission_denied ?? 0); + toast.warning( + `Reverted ${response.reverted} of ${response.total}. ${failureCount} could not be undone.` + ); + setResultsOpen(true); + } + } catch (err) { + if (err instanceof AppError && err.status === 503) { + return; + } + const message = + err instanceof AppError + ? err.message + : err instanceof Error + ? err.message + : "Failed to revert turn."; + toast.error(message); + } finally { + setIsReverting(false); + setConfirmOpen(false); + } + }; + + return ( + <> + <AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}> + <AlertDialogTrigger asChild> + <Button + size="sm" + variant="ghost" + className="text-muted-foreground hover:text-foreground gap-1.5" + onClick={(e) => { + e.stopPropagation(); + setConfirmOpen(true); + }} + > + <RotateCcw className="size-3.5" /> + <span>Revert turn</span> + <span className="text-xs tabular-nums opacity-70"> + {reversibleCount}/{totalCount} + </span> + </Button> + </AlertDialogTrigger> + <AlertDialogContent> + <AlertDialogHeader> + <AlertDialogTitle>Revert this turn?</AlertDialogTitle> + <AlertDialogDescription> + This will undo {reversibleCount} of {totalCount} action + {totalCount === 1 ? "" : "s"} from this turn in reverse order. The chat history and + any read-only actions are preserved. Some rows may not be reversible — partial success + is normal. + </AlertDialogDescription> + </AlertDialogHeader> + <AlertDialogFooter> + <AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel> + <AlertDialogAction + onClick={(e) => { + e.preventDefault(); + handleRevertTurn(); + }} + disabled={isReverting} + > + {isReverting ? "Reverting…" : "Revert turn"} + </AlertDialogAction> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + + <AlertDialog open={resultsOpen} onOpenChange={setResultsOpen}> + <AlertDialogContent> + <AlertDialogHeader> + <AlertDialogTitle>Revert results</AlertDialogTitle> + <AlertDialogDescription> + Some actions could not be reverted. Review per-row outcomes below. + </AlertDialogDescription> + </AlertDialogHeader> + <ul className="max-h-72 overflow-y-auto space-y-2 text-sm"> + {results.map((r) => ( + <RevertResultRow key={r.action_id} result={r} /> + ))} + </ul> + <AlertDialogFooter> + <AlertDialogAction onClick={() => setResultsOpen(false)}>Close</AlertDialogAction> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + </> + ); +} + +function RevertResultRow({ result }: { result: RevertTurnActionResult }) { + const isOk = result.status === "reverted" || result.status === "already_reverted"; + const Icon = isOk ? CheckIcon : XCircleIcon; + return ( + <li className="flex items-start gap-2 rounded-md border bg-muted/30 px-3 py-2"> + <Icon + className={cn("size-4 mt-0.5 shrink-0", isOk ? "text-emerald-500" : "text-destructive")} + /> + <div className="min-w-0 flex-1"> + <p className="font-medium truncate"> + {formatToolName(result.tool_name)}{" "} + <span className="ml-1 text-xs text-muted-foreground"> + {result.status.replace(/_/g, " ")} + </span> + </p> + {(result.message || result.error) && ( + <p className="text-xs text-muted-foreground mt-0.5">{result.error ?? result.message}</p> + )} + </div> + </li> + ); +} diff --git a/surfsense_web/components/assistant-ui/step-separator.tsx b/surfsense_web/components/assistant-ui/step-separator.tsx new file mode 100644 index 000000000..f59130661 --- /dev/null +++ b/surfsense_web/components/assistant-ui/step-separator.tsx @@ -0,0 +1,27 @@ +"use client"; + +import { makeAssistantDataUI } from "@assistant-ui/react"; + +/** + * Renders a thin horizontal divider between model steps within a single + * assistant turn. The data part is pushed by `addStepSeparator` in + * `streaming-state.ts` whenever a `start-step` SSE event arrives after + * the message already has non-step content. + * + * Today the backend emits one `start-step` / `finish-step` pair per turn, + * so most messages won't contain a separator. The renderer is wired up so + * the planned per-model-step refactor (A2 follow-up) can light up without + * touching the persistence path. + */ +function StepSeparatorDataRenderer() { + return ( + <div className="mx-auto my-3 w-full max-w-(--thread-max-width) px-2"> + <div className="border-t border-border/60" /> + </div> + ); +} + +export const StepSeparatorDataUI = makeAssistantDataUI({ + name: "step-separator", + render: StepSeparatorDataRenderer, +}); diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index 112f3e1d8..70eab9ffc 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -1,12 +1,33 @@ import type { ToolCallMessagePartComponent } from "@assistant-ui/react"; -import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XCircleIcon } from "lucide-react"; +import { useAtomValue, useSetAtom } from "jotai"; +import { CheckIcon, ChevronDownIcon, ChevronUpIcon, RotateCcw, XCircleIcon } from "lucide-react"; import { useMemo, useState } from "react"; +import { toast } from "sonner"; +import { + agentActionByToolCallIdAtom, + markAgentActionRevertedAtom, +} from "@/atoms/chat/agent-actions.atom"; +import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; import { DoomLoopApprovalToolUI, isDoomLoopInterrupt, } from "@/components/tool-ui/doom-loop-approval"; import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; import { getToolIcon } from "@/contracts/enums/toolIcons"; +import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; +import { AppError } from "@/lib/error"; import { isInterruptResult } from "@/lib/hitl"; import { cn } from "@/lib/utils"; @@ -14,7 +35,99 @@ function formatToolName(name: string): string { return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); } +/** + * Inline Revert button rendered on a tool card when the matching + * ``AgentActionLog`` row is reversible and hasn't been reverted yet. + * Reads from the SSE side-channel atom keyed by the synthetic + * ``toolCallId`` so it lights up even when ``GET /threads/.../actions`` + * is gated behind ``SURFSENSE_ENABLE_ACTION_LOG=False`` (503). + */ +function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { + const session = useAtomValue(chatSessionStateAtom); + const actionMap = useAtomValue(agentActionByToolCallIdAtom); + const markReverted = useSetAtom(markAgentActionRevertedAtom); + const action = actionMap.get(toolCallId); + const [isReverting, setIsReverting] = useState(false); + const [confirmOpen, setConfirmOpen] = useState(false); + + if (!action) return null; + if (!action.reversible) return null; + if (action.revertedByActionId !== null) return null; + if (action.isRevertAction) return null; + if (action.error) return null; + const threadId = session?.threadId; + if (!threadId) return null; + + const handleRevert = async () => { + setIsReverting(true); + try { + const response = await agentActionsApiService.revert(threadId, action.id); + markReverted({ id: action.id, newActionId: response.new_action_id ?? null }); + toast.success(response.message || "Action reverted."); + } catch (err) { + // 503 means revert is gated off on this deployment — hide the + // button silently rather than nagging the user. Any other error + // is surfaced as a toast so the operator can investigate. + if (err instanceof AppError && err.status === 503) { + return; + } + const message = + err instanceof AppError + ? err.message + : err instanceof Error + ? err.message + : "Failed to revert action."; + toast.error(message); + } finally { + setIsReverting(false); + setConfirmOpen(false); + } + }; + + return ( + <AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}> + <AlertDialogTrigger asChild> + <Button + size="sm" + variant="outline" + className="gap-1.5" + onClick={(e) => { + e.stopPropagation(); + setConfirmOpen(true); + }} + > + <RotateCcw className="size-3.5" /> + Revert + </Button> + </AlertDialogTrigger> + <AlertDialogContent> + <AlertDialogHeader> + <AlertDialogTitle>Revert this action?</AlertDialogTitle> + <AlertDialogDescription> + This will undo <span className="font-medium">{formatToolName(action.toolName)}</span>{" "} + and append a new audit entry. Chat history is preserved — only the tool's effects on + your knowledge base or connectors will be reversed where possible. + </AlertDialogDescription> + </AlertDialogHeader> + <AlertDialogFooter> + <AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel> + <AlertDialogAction + onClick={(e) => { + e.preventDefault(); + handleRevert(); + }} + disabled={isReverting} + > + {isReverting ? "Reverting…" : "Revert"} + </AlertDialogAction> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + ); +} + const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ + toolCallId, toolName, argsText, result, @@ -145,6 +258,9 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ </div> </> )} + <div className="flex justify-end"> + <ToolCardRevertButton toolCallId={toolCallId} /> + </div> </div> </> )} diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index deac1fd00..bfdd613e2 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -9,6 +9,7 @@ import { import { Turnstile, type TurnstileInstance } from "@marsidev/react-turnstile"; import { ShieldCheck } from "lucide-react"; import { useCallback, useEffect, useRef, useState } from "react"; +import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator"; import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps"; import { createTokenUsageStore, @@ -17,10 +18,13 @@ import { } from "@/components/assistant-ui/token-usage-context"; import { useAnonymousMode } from "@/contexts/anonymous-mode"; import { + addStepSeparator, addToolCall, + appendReasoning, appendText, buildContentForUI, type ContentPartsState, + endReasoning, FrameBatchedUpdater, readSSEStream, type ThinkingStepData, @@ -32,7 +36,9 @@ import { trackAnonymousChatMessageSent } from "@/lib/posthog/events"; import { FreeModelSelector } from "./free-model-selector"; import { FreeThread } from "./free-thread"; -const TOOLS_WITH_UI = new Set(["web_search", "document_qna"]); +// Render all tool calls via ToolFallback; backend keeps persisted +// payloads bounded by summarising / truncating outputs. +const TOOLS_WITH_UI = "all" as const; const TURNSTILE_SITE_KEY = process.env.NEXT_PUBLIC_TURNSTILE_SITE_KEY ?? ""; /** Try to parse a CAPTCHA_REQUIRED or CAPTCHA_INVALID code from a non-ok response. */ @@ -125,6 +131,7 @@ export function FreeChatPage() { const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, + currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; const { toolCallIndices } = contentPartsState; @@ -148,28 +155,62 @@ export function FreeChatPage() { scheduleFlush(); break; + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + break; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + break; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + break; + + case "finish-step": + break; + case "tool-input-start": - addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {}); + addToolCall( + contentPartsState, + TOOLS_WITH_UI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); batcher.flush(); break; case "tool-input-available": if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} }); + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + langchainToolCallId: parsed.langchainToolCallId, + }); } else { addToolCall( contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, - parsed.input || {} + parsed.input || {}, + false, + parsed.langchainToolCallId ); } batcher.flush(); break; case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output }); + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); batcher.flush(); break; @@ -369,6 +410,7 @@ export function FreeChatPage() { <TokenUsageProvider store={tokenUsageStore}> <AssistantRuntimeProvider runtime={runtime}> <ThinkingStepsDataUI /> + <StepSeparatorDataUI /> <div className="flex h-full flex-col overflow-hidden"> <div className="flex h-14 shrink-0 items-center justify-between border-b border-border/40 px-4"> <FreeModelSelector /> diff --git a/surfsense_web/components/public-chat/public-chat-view.tsx b/surfsense_web/components/public-chat/public-chat-view.tsx index f8dd6db5a..e47ba9bf1 100644 --- a/surfsense_web/components/public-chat/public-chat-view.tsx +++ b/surfsense_web/components/public-chat/public-chat-view.tsx @@ -1,6 +1,7 @@ "use client"; import { AssistantRuntimeProvider } from "@assistant-ui/react"; +import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator"; import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps"; import { Navbar } from "@/components/homepage/navbar"; import { ReportPanel } from "@/components/report-panel/report-panel"; @@ -41,6 +42,7 @@ export function PublicChatView({ shareToken }: PublicChatViewProps) { <Navbar scrolledBgClassName={navbarScrolledBg} /> <AssistantRuntimeProvider runtime={runtime}> <ThinkingStepsDataUI /> + <StepSeparatorDataUI /> <div className="flex h-screen pt-16 overflow-hidden"> <div className="flex-1 flex flex-col min-w-0 overflow-hidden"> <PublicThread footer={<PublicChatFooter shareToken={shareToken} />} /> diff --git a/surfsense_web/components/public-chat/public-thread.tsx b/surfsense_web/components/public-chat/public-thread.tsx index 627baf831..22e914988 100644 --- a/surfsense_web/components/public-chat/public-thread.tsx +++ b/surfsense_web/components/public-chat/public-thread.tsx @@ -13,6 +13,7 @@ import Image from "next/image"; import { type FC, type ReactNode, useState } from "react"; import { CitationMetadataProvider } from "@/components/assistant-ui/citation-metadata-context"; import { MarkdownText } from "@/components/assistant-ui/markdown-text"; +import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part"; import { ToolFallback } from "@/components/assistant-ui/tool-fallback"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { GenerateImageToolUI } from "@/components/tool-ui/generate-image"; @@ -157,6 +158,7 @@ const PublicAssistantMessage: FC = () => { <MessagePrimitive.Parts components={{ Text: MarkdownText, + Reasoning: ReasoningMessagePart, tools: { by_name: { generate_podcast: GeneratePodcastToolUI, diff --git a/surfsense_web/contracts/enums/toolIcons.tsx b/surfsense_web/contracts/enums/toolIcons.tsx index bc63bc1b0..1aab08096 100644 --- a/surfsense_web/contracts/enums/toolIcons.tsx +++ b/surfsense_web/contracts/enums/toolIcons.tsx @@ -1,27 +1,112 @@ import { BookOpen, Brain, + Calendar, + Check, + FileEdit, + FilePlus, FileText, FileUser, + FileX, Film, + FolderPlus, + FolderTree, + FolderX, Globe, ImageIcon, + ListTodo, type LucideIcon, + Mail, + MessagesSquare, + Move, + Plus, Podcast, ScanLine, + Search, + Send, + Trash2, Wrench, } from "lucide-react"; +/** + * Every tool now renders a card via ``ToolFallback``. The icon map is + * keyed on the canonical backend tool name (registered in + * ``surfsense_backend/app/agents/new_chat/tools/registry.py``); unknown + * names fall back to the generic ``Wrench`` icon so the card still + * communicates "this is a tool call". + */ const TOOL_ICONS: Record<string, LucideIcon> = { + // Generators generate_podcast: Podcast, generate_video_presentation: Film, generate_report: FileText, generate_resume: FileUser, generate_image: ImageIcon, + display_image: ImageIcon, + // Web / search scrape_webpage: ScanLine, web_search: Globe, search_surfsense_docs: BookOpen, + // Memory update_memory: Brain, + // Filesystem (built-in deepagent + middleware) + read_file: FileText, + write_file: FilePlus, + edit_file: FileEdit, + move_file: Move, + rm: FileX, + rmdir: FolderX, + mkdir: FolderPlus, + ls: FolderTree, + write_todos: ListTodo, + // Calendar + search_calendar_events: Search, + create_calendar_event: Calendar, + update_calendar_event: Calendar, + delete_calendar_event: Calendar, + // Gmail + search_gmail: Search, + read_gmail_email: Mail, + create_gmail_draft: Mail, + update_gmail_draft: FileEdit, + send_gmail_email: Send, + trash_gmail_email: Trash2, + // Notion / Confluence pages + create_notion_page: FilePlus, + update_notion_page: FileEdit, + delete_notion_page: FileX, + create_confluence_page: FilePlus, + update_confluence_page: FileEdit, + delete_confluence_page: FileX, + // Linear / Jira issues + create_linear_issue: Plus, + update_linear_issue: FileEdit, + delete_linear_issue: Trash2, + create_jira_issue: Plus, + update_jira_issue: FileEdit, + delete_jira_issue: Trash2, + // Drive-like file connectors + create_google_drive_file: FilePlus, + delete_google_drive_file: FileX, + create_dropbox_file: FilePlus, + delete_dropbox_file: FileX, + create_onedrive_file: FilePlus, + delete_onedrive_file: FileX, + // Chat connectors + list_discord_channels: MessagesSquare, + read_discord_messages: MessagesSquare, + send_discord_message: Send, + list_teams_channels: MessagesSquare, + read_teams_messages: MessagesSquare, + send_teams_message: Send, + // Luma + list_luma_events: Calendar, + read_luma_event: Calendar, + create_luma_event: Calendar, + // Misc + get_connected_accounts: Check, + execute: Wrench, + execute_code: Wrench, }; export function getToolIcon(name: string): LucideIcon { diff --git a/surfsense_web/lib/apis/agent-actions-api.service.ts b/surfsense_web/lib/apis/agent-actions-api.service.ts index 007bb131e..6634a11f7 100644 --- a/surfsense_web/lib/apis/agent-actions-api.service.ts +++ b/surfsense_web/lib/apis/agent-actions-api.service.ts @@ -15,6 +15,12 @@ const AgentActionReadSchema = z.object({ reverse_of: z.number().nullable(), reverted_by_action_id: z.number().nullable(), is_revert_action: z.boolean(), + // Correlation ids added in migration 135. The LangChain + // ``tool_call_id`` joins this row to the chat tool card via the + // ``data-action-log.lc_tool_call_id`` SSE event, and + // ``chat_turn_id`` keys the per-turn revert endpoint. + tool_call_id: z.string().nullable().optional(), + chat_turn_id: z.string().nullable().optional(), created_at: z.string(), }); @@ -38,6 +44,48 @@ const RevertResponseSchema = z.object({ export type RevertResponse = z.infer<typeof RevertResponseSchema>; +// Per-turn batch revert. The route never returns whole-batch 4xx; +// partial success is the common case and surfaced as +// ``status === "partial"`` with a per-action result list. +const RevertTurnActionResultSchema = z.object({ + action_id: z.number(), + tool_name: z.string(), + status: z.enum([ + "reverted", + "already_reverted", + "not_reversible", + "permission_denied", + "failed", + "skipped", + ]), + message: z.string().nullable().optional(), + new_action_id: z.number().nullable().optional(), + error: z.string().nullable().optional(), +}); + +export type RevertTurnActionResult = z.infer<typeof RevertTurnActionResultSchema>; + +const RevertTurnResponseSchema = z.object({ + status: z.enum(["ok", "partial"]), + chat_turn_id: z.string(), + total: z.number(), + reverted: z.number(), + already_reverted: z.number(), + not_reversible: z.number(), + // ``permission_denied`` and ``skipped`` are first-class counters so + // ``total === reverted + already_reverted + + // not_reversible + permission_denied + failed + skipped`` always + // holds. ``.default(0)`` keeps the schema backwards-compatible + // with older deployments that haven't shipped the response model + // update yet. + permission_denied: z.number().default(0), + failed: z.number(), + skipped: z.number().default(0), + results: z.array(RevertTurnActionResultSchema), +}); + +export type RevertTurnResponse = z.infer<typeof RevertTurnResponseSchema>; + class AgentActionsApiService { listForThread = async ( threadId: number, @@ -59,6 +107,14 @@ class AgentActionsApiService { { body: {} } ); }; + + revertTurn = async (threadId: number, chatTurnId: string): Promise<RevertTurnResponse> => { + return baseApiService.post( + `/api/v1/threads/${threadId}/revert-turn/${encodeURIComponent(chatTurnId)}`, + RevertTurnResponseSchema, + { body: {} } + ); + }; } export const agentActionsApiService = new AgentActionsApiService(); diff --git a/surfsense_web/lib/chat/message-utils.ts b/surfsense_web/lib/chat/message-utils.ts index 2d1a6976f..004542489 100644 --- a/surfsense_web/lib/chat/message-utils.ts +++ b/surfsense_web/lib/chat/message-utils.ts @@ -40,7 +40,7 @@ export function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike { } const metadata = - msg.author_id || msg.token_usage + msg.author_id || msg.token_usage || msg.turn_id ? { custom: { ...(msg.author_id && { @@ -50,6 +50,10 @@ export function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike { }, }), ...(msg.token_usage && { usage: msg.token_usage }), + // Surface ``chat_turn_id`` so the assistant message + // footer can scope its "Revert turn" button to just + // this turn's actions. Null on legacy rows. + ...(msg.turn_id && { chatTurnId: msg.turn_id }), }, } : undefined; diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index ff8fdfbd4..26fd7b98c 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -9,21 +9,42 @@ export interface ThinkingStepData { export type ContentPart = | { type: "text"; text: string } + | { type: "reasoning"; text: string } | { type: "tool-call"; toolCallId: string; toolName: string; args: Record<string, unknown>; result?: unknown; + /** + * Authoritative LangChain ``tool_call.id`` propagated by the backend + * via ``langchainToolCallId`` on tool-input-start/available and + * tool-output-available events. Used to join a card to the + * matching ``AgentActionLog`` row exposed by + * ``GET /threads/{id}/actions`` and the streamed + * ``data-action-log`` events. + */ + langchainToolCallId?: string; } | { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] }; + } + | { + /** + * Between-step separator. Pushed by `addStepSeparator` when + * a `start-step` SSE event arrives AFTER the message already + * has non-step content. Rendered by `StepSeparatorDataUI` + * (see assistant-ui/step-separator.tsx). + */ + type: "data-step-separator"; + data: { stepIndex: number }; }; export interface ContentPartsState { contentParts: ContentPart[]; currentTextPartIndex: number; + currentReasoningPartIndex: number; toolCallIndices: Map<string, number>; } @@ -74,6 +95,9 @@ export function updateThinkingSteps( if (state.currentTextPartIndex >= 0) { state.currentTextPartIndex += 1; } + if (state.currentReasoningPartIndex >= 0) { + state.currentReasoningPartIndex += 1; + } for (const [id, idx] of state.toolCallIndices) { state.toolCallIndices.set(id, idx + 1); } @@ -131,6 +155,12 @@ export class FrameBatchedUpdater { } export function appendText(state: ContentPartsState, delta: string): void { + // First text delta after a reasoning block: close the reasoning so + // the assistant-ui renderer treats them as separate parts (the + // reasoning block collapses; the answer streams below). + if (state.currentReasoningPartIndex >= 0) { + state.currentReasoningPartIndex = -1; + } if ( state.currentTextPartIndex >= 0 && state.contentParts[state.currentTextPartIndex]?.type === "text" @@ -143,36 +173,129 @@ export function appendText(state: ContentPartsState, delta: string): void { } } +export function appendReasoning(state: ContentPartsState, delta: string): void { + // Symmetric to appendText: open a fresh reasoning block on first + // delta, then accumulate into it. ``endReasoning`` simply closes + // the active block; subsequent reasoning deltas would open a new + // one (matching ``text-start/end`` semantics on the wire). + if (state.currentTextPartIndex >= 0) { + state.currentTextPartIndex = -1; + } + if ( + state.currentReasoningPartIndex >= 0 && + state.contentParts[state.currentReasoningPartIndex]?.type === "reasoning" + ) { + ( + state.contentParts[state.currentReasoningPartIndex] as { + type: "reasoning"; + text: string; + } + ).text += delta; + } else { + state.contentParts.push({ type: "reasoning", text: delta }); + state.currentReasoningPartIndex = state.contentParts.length - 1; + } +} + +export function endReasoning(state: ContentPartsState): void { + state.currentReasoningPartIndex = -1; +} + +export function addStepSeparator(state: ContentPartsState): void { + // Push a divider between consecutive model steps within a single + // assistant turn. We only emit it when the message already has + // non-step content (so the FIRST step of a turn doesn't + // generate a leading separator) and when the previous part isn't + // itself a separator (defensive against duplicate `start-step` + // events). + const hasContent = state.contentParts.some( + (p) => p.type === "text" || p.type === "reasoning" || p.type === "tool-call" + ); + if (!hasContent) return; + const last = state.contentParts[state.contentParts.length - 1]; + if (last && last.type === "data-step-separator") return; + + const stepIndex = state.contentParts.filter((p) => p.type === "data-step-separator").length; + state.contentParts.push({ type: "data-step-separator", data: { stepIndex } }); + state.currentTextPartIndex = -1; + state.currentReasoningPartIndex = -1; +} + +/** + * Allowlist of tool names that should produce a UI tool card. The + * sentinel ``"all"`` matches every tool — we dropped the legacy + * ``BASE_TOOLS_WITH_UI`` gate so that ALL tool calls render via the + * generic ``ToolFallback``. The backend's ``format_thinking_step`` + * summarisation and the defensive ``result_length``-only default for + * unknown tools keep persisted message JSON from ballooning. + */ +export type ToolUIGate = Set<string> | "all"; + +function _toolPasses(gate: ToolUIGate, toolName: string): boolean { + return gate === "all" || gate.has(toolName); +} + export function addToolCall( state: ContentPartsState, - toolsWithUI: Set<string>, + toolsWithUI: ToolUIGate, toolCallId: string, toolName: string, args: Record<string, unknown>, - force = false + force = false, + langchainToolCallId?: string ): void { - if (force || toolsWithUI.has(toolName)) { + if (force || _toolPasses(toolsWithUI, toolName)) { state.contentParts.push({ type: "tool-call", toolCallId, toolName, args, + ...(langchainToolCallId ? { langchainToolCallId } : {}), }); state.toolCallIndices.set(toolCallId, state.contentParts.length - 1); state.currentTextPartIndex = -1; + state.currentReasoningPartIndex = -1; } } +/** + * Reverse-lookup helper used by the SSE ``data-action-log`` handler: + * given the LangChain ``tool_call.id`` (set on the content part as + * ``langchainToolCallId``), return the synthetic ``toolCallId`` that + * the chat tool card uses (``call_<run-id>``). Returns ``null`` when no + * matching tool card has been seen yet — the action is still recorded + * in the LC-id-keyed atom so the card can pick it up when it eventually + * arrives. + */ +export function findToolCallIdByLcId( + state: ContentPartsState, + lcToolCallId: string +): string | null { + for (const part of state.contentParts) { + if (part.type === "tool-call" && part.langchainToolCallId === lcToolCallId) { + return part.toolCallId; + } + } + return null; +} + export function updateToolCall( state: ContentPartsState, toolCallId: string, - update: { args?: Record<string, unknown>; result?: unknown } + update: { args?: Record<string, unknown>; result?: unknown; langchainToolCallId?: string } ): void { const index = state.toolCallIndices.get(toolCallId); if (index !== undefined && state.contentParts[index]?.type === "tool-call") { const tc = state.contentParts[index] as ContentPart & { type: "tool-call" }; if (update.args) tc.args = update.args; if (update.result !== undefined) tc.result = update.result; + // Only backfill langchainToolCallId if not already set — the + // authoritative ``on_tool_end`` value should override an earlier + // best-effort match, but a NULL late-arriving value should not + // blow away a known good early one. + if (update.langchainToolCallId && !tc.langchainToolCallId) { + tc.langchainToolCallId = update.langchainToolCallId; + } } } @@ -184,13 +307,15 @@ function _hasInterruptResult(part: ContentPart): boolean { export function buildContentForUI( state: ContentPartsState, - toolsWithUI: Set<string> + toolsWithUI: ToolUIGate ): ThreadMessageLike["content"] { const filtered = state.contentParts.filter((part) => { if (part.type === "text") return part.text.length > 0; + if (part.type === "reasoning") return part.text.length > 0; if (part.type === "tool-call") - return toolsWithUI.has(part.toolName) || _hasInterruptResult(part); + return _toolPasses(toolsWithUI, part.toolName) || _hasInterruptResult(part); if (part.type === "data-thinking-steps") return true; + if (part.type === "data-step-separator") return true; return false; }); return filtered.length > 0 @@ -200,20 +325,28 @@ export function buildContentForUI( export function buildContentForPersistence( state: ContentPartsState, - toolsWithUI: Set<string> + toolsWithUI: ToolUIGate ): unknown[] { const parts: unknown[] = []; for (const part of state.contentParts) { if (part.type === "text" && part.text.length > 0) { parts.push(part); + } else if (part.type === "reasoning" && part.text.length > 0) { + // Persist reasoning blocks so a chat reload re-renders the + // collapsed thinking section instead of + // silently dropping it (mirrors the data-thinking-steps + // branch above). + parts.push(part); } else if ( part.type === "tool-call" && - (toolsWithUI.has(part.toolName) || _hasInterruptResult(part)) + (_toolPasses(toolsWithUI, part.toolName) || _hasInterruptResult(part)) ) { parts.push(part); } else if (part.type === "data-thinking-steps") { parts.push(part); + } else if (part.type === "data-step-separator") { + parts.push(part); } } @@ -221,23 +354,122 @@ export function buildContentForPersistence( } export type SSEEvent = - | { type: "text-delta"; delta: string } - | { type: "tool-input-start"; toolCallId: string; toolName: string } + | { type: "start"; messageId?: string } + | { type: "finish" } + | { type: "start-step" } + | { type: "finish-step" } + | { type: "text-start"; id: string } + | { type: "text-delta"; id?: string; delta: string } + | { type: "text-end"; id: string } + | { type: "reasoning-start"; id: string } + | { type: "reasoning-delta"; id?: string; delta: string } + | { type: "reasoning-end"; id: string } + | { + type: "tool-input-start"; + toolCallId: string; + toolName: string; + /** Authoritative LangChain ``tool_call.id``. Optional. */ + langchainToolCallId?: string; + } | { type: "tool-input-available"; toolCallId: string; toolName: string; input: Record<string, unknown>; + langchainToolCallId?: string; } | { type: "tool-output-available"; toolCallId: string; output: Record<string, unknown>; + /** Authoritative LangChain ``tool_call.id`` extracted from + * ``ToolMessage.tool_call_id`` at on_tool_end. Backfills cards + * that didn't get the id at tool-input-start time. */ + langchainToolCallId?: string; } | { type: "data-thinking-step"; data: ThinkingStepData } | { type: "data-thread-title-update"; data: { threadId: number; title: string } } | { type: "data-interrupt-request"; data: Record<string, unknown> } | { type: "data-documents-updated"; data: Record<string, unknown> } + | { + /** + * A freshly committed AgentActionLog row. Frontend stores + * this in a Map keyed off ``lc_tool_call_id`` so the chat + * tool card can light up its Revert button. + */ + type: "data-action-log"; + data: { + id: number; + lc_tool_call_id: string | null; + chat_turn_id: string | null; + tool_name: string; + reversible: boolean; + reverse_descriptor_present: boolean; + created_at: string | null; + error: boolean; + }; + } + | { + /** + * Reversibility flipped (filesystem op SAVEPOINT committed; + * cf. ``kb_persistence._dispatch_reversibility_update``). + */ + type: "data-action-log-updated"; + data: { id: number; reversible: boolean }; + } + | { + /** + * Emitted at the start of every stream so the frontend can + * stamp the per-turn correlation id onto the in-flight + * assistant message and replay it via + * ``appendMessage``. Pure-text turns never produce + * action-log events; this event guarantees the frontend + * always learns the turn id. + */ + type: "data-turn-info"; + data: { chat_turn_id: string }; + } + | { + /** + * Best-effort revert pass that ran BEFORE this regeneration. + * Per-action results are forwarded to the UI so the user + * can see which downstream actions were rolled + * back vs which couldn't be undone. + */ + type: "data-revert-results"; + data: { + status: "ok" | "partial"; + chat_turn_ids: string[]; + total: number; + reverted: number; + already_reverted: number; + not_reversible: number; + /** + * ``permission_denied`` and ``skipped`` are first-class + * counters so the response invariant + * ``total === sum(counters)`` always holds. Optional + * for forward compatibility with older backends; the + * frontend treats missing values as ``0``. + */ + permission_denied?: number; + failed: number; + skipped?: number; + results: Array<{ + action_id: number; + tool_name: string; + status: + | "reverted" + | "already_reverted" + | "not_reversible" + | "permission_denied" + | "failed" + | "skipped"; + message?: string | null; + new_action_id?: number | null; + error?: string | null; + }>; + }; + } | { type: "data-token-usage"; data: { diff --git a/surfsense_web/lib/chat/thread-persistence.ts b/surfsense_web/lib/chat/thread-persistence.ts index b5c5899b4..fc970c26e 100644 --- a/surfsense_web/lib/chat/thread-persistence.ts +++ b/surfsense_web/lib/chat/thread-persistence.ts @@ -46,6 +46,11 @@ export interface MessageRecord { author_display_name?: string | null; author_avatar_url?: string | null; token_usage?: TokenUsageSummary | null; + // Per-turn correlation id from ``configurable.turn_id`` at streaming + // time (added in migration 136). Used by the per-turn revert + // endpoint and edit-from-arbitrary-position. Nullable on legacy + // rows that predate the column. + turn_id?: string | null; } export interface ThreadListResponse { @@ -123,10 +128,20 @@ export async function getThreadMessages(threadId: number): Promise<ThreadHistory /** * Append a message to a thread. + * + * ``turn_id`` is the per-turn correlation id streamed by the backend + * via ``data-turn-info``. Persisting it lets later edits locate the + * matching LangGraph checkpoint without HumanMessage scanning. Older + * callers can still omit it for back-compat. */ export async function appendMessage( threadId: number, - message: { role: "user" | "assistant" | "system"; content: unknown; token_usage?: unknown } + message: { + role: "user" | "assistant" | "system"; + content: unknown; + token_usage?: unknown; + turn_id?: string | null; + } ): Promise<MessageRecord> { return baseApiService.post<MessageRecord>(`/api/v1/threads/${threadId}/messages`, undefined, { body: message, From 9a114a2d45f0341c0edbfeac1dafba858aa7e38e Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Wed, 29 Apr 2026 07:40:11 -0700 Subject: [PATCH 231/299] feat: enhance tool display names for better user experience in chat UI --- .../app/tasks/chat/stream_new_chat.py | 141 +++++++++++++++++- .../agent-action-log/action-log-item.tsx | 8 +- .../assistant-ui/revert-turn-button.tsx | 7 +- .../components/assistant-ui/thread.tsx | 13 +- .../components/assistant-ui/tool-fallback.tsx | 19 +-- .../tool-ui/generic-hitl-approval.tsx | 6 +- surfsense_web/contracts/enums/toolIcons.tsx | 105 +++++++++++++ 7 files changed, 267 insertions(+), 32 deletions(-) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 2f8e33ba9..f7bf75649 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -622,6 +622,95 @@ async def _stream_agent_events( status="in_progress", items=last_active_step_items, ) + elif tool_name == "rm": + rm_path = ( + tool_input.get("path", "") + if isinstance(tool_input, dict) + else str(tool_input) + ) + display_path = rm_path if len(rm_path) <= 80 else "…" + rm_path[-77:] + last_active_step_title = "Deleting file" + last_active_step_items = [display_path] if display_path else [] + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + title="Deleting file", + status="in_progress", + items=last_active_step_items, + ) + elif tool_name == "rmdir": + rmdir_path = ( + tool_input.get("path", "") + if isinstance(tool_input, dict) + else str(tool_input) + ) + display_path = ( + rmdir_path if len(rmdir_path) <= 80 else "…" + rmdir_path[-77:] + ) + last_active_step_title = "Deleting folder" + last_active_step_items = [display_path] if display_path else [] + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + title="Deleting folder", + status="in_progress", + items=last_active_step_items, + ) + elif tool_name == "mkdir": + mkdir_path = ( + tool_input.get("path", "") + if isinstance(tool_input, dict) + else str(tool_input) + ) + display_path = ( + mkdir_path if len(mkdir_path) <= 80 else "…" + mkdir_path[-77:] + ) + last_active_step_title = "Creating folder" + last_active_step_items = [display_path] if display_path else [] + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + title="Creating folder", + status="in_progress", + items=last_active_step_items, + ) + elif tool_name == "move_file": + src = ( + tool_input.get("source_path", "") + if isinstance(tool_input, dict) + else "" + ) + dst = ( + tool_input.get("destination_path", "") + if isinstance(tool_input, dict) + else "" + ) + display_src = src if len(src) <= 60 else "…" + src[-57:] + display_dst = dst if len(dst) <= 60 else "…" + dst[-57:] + last_active_step_title = "Moving file" + last_active_step_items = ( + [f"{display_src} → {display_dst}"] if src or dst else [] + ) + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + title="Moving file", + status="in_progress", + items=last_active_step_items, + ) + elif tool_name == "write_todos": + todos = ( + tool_input.get("todos", []) if isinstance(tool_input, dict) else [] + ) + todo_count = len(todos) if isinstance(todos, list) else 0 + last_active_step_title = "Planning tasks" + last_active_step_items = ( + [f"{todo_count} task{'s' if todo_count != 1 else ''}"] + if todo_count + else [] + ) + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + title="Planning tasks", + status="in_progress", + items=last_active_step_items, + ) elif tool_name == "save_document": doc_title = ( tool_input.get("title", "") @@ -729,7 +818,15 @@ async def _stream_agent_events( items=last_active_step_items, ) else: - last_active_step_title = f"Using {tool_name.replace('_', ' ')}" + # Fallback for tools without a curated thinking-step title + # (typically connector tools, MCP-registered tools, or + # newly added tools that haven't been wired up here yet). + # Render the snake_cased name as a sentence-cased phrase + # so non-technical users see e.g. "Send gmail email" + # rather than the raw identifier "send_gmail_email". + last_active_step_title = ( + tool_name.replace("_", " ").strip().capitalize() or tool_name + ) last_active_step_items = [] yield streaming_service.format_thinking_step( step_id=tool_step_id, @@ -885,6 +982,41 @@ async def _stream_agent_events( status="completed", items=last_active_step_items, ) + elif tool_name == "rm": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Deleting file", + status="completed", + items=last_active_step_items, + ) + elif tool_name == "rmdir": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Deleting folder", + status="completed", + items=last_active_step_items, + ) + elif tool_name == "mkdir": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Creating folder", + status="completed", + items=last_active_step_items, + ) + elif tool_name == "move_file": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Moving file", + status="completed", + items=last_active_step_items, + ) + elif tool_name == "write_todos": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Planning tasks", + status="completed", + items=last_active_step_items, + ) elif tool_name == "save_document": result_str = ( tool_output.get("result", "") @@ -1136,9 +1268,14 @@ async def _stream_agent_events( items=completed_items, ) else: + # Fallback completion title — see the matching in-progress + # branch above for the wording rationale. + fallback_title = ( + tool_name.replace("_", " ").strip().capitalize() or tool_name + ) yield streaming_service.format_thinking_step( step_id=original_step_id, - title=f"Using {tool_name.replace('_', ' ')}", + title=fallback_title, status="completed", items=last_active_step_items, ) diff --git a/surfsense_web/components/agent-action-log/action-log-item.tsx b/surfsense_web/components/agent-action-log/action-log-item.tsx index 425714c1f..673189709 100644 --- a/surfsense_web/components/agent-action-log/action-log-item.tsx +++ b/surfsense_web/components/agent-action-log/action-log-item.tsx @@ -17,16 +17,12 @@ import { import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Separator } from "@/components/ui/separator"; -import { getToolIcon } from "@/contracts/enums/toolIcons"; +import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons"; import { type AgentAction, agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; import { AppError } from "@/lib/error"; import { formatRelativeDate } from "@/lib/format-date"; import { cn } from "@/lib/utils"; -function formatToolName(name: string): string { - return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); -} - interface ActionLogItemProps { action: AgentAction; threadId: number; @@ -43,7 +39,7 @@ export function ActionLogItem({ action, threadId, onRevertSuccess }: ActionLogIt const hasError = action.error !== null && action.error !== undefined; const Icon = getToolIcon(action.tool_name); - const displayName = formatToolName(action.tool_name); + const displayName = getToolDisplayName(action.tool_name); const argsPreview = action.args ? JSON.stringify(action.args, null, 2) : null; const truncatedArgs = diff --git a/surfsense_web/components/assistant-ui/revert-turn-button.tsx b/surfsense_web/components/assistant-ui/revert-turn-button.tsx index 9c349738f..af71299d0 100644 --- a/surfsense_web/components/assistant-ui/revert-turn-button.tsx +++ b/surfsense_web/components/assistant-ui/revert-turn-button.tsx @@ -37,6 +37,7 @@ import { AlertDialogTrigger, } from "@/components/ui/alert-dialog"; import { Button } from "@/components/ui/button"; +import { getToolDisplayName } from "@/contracts/enums/toolIcons"; import { agentActionsApiService, type RevertTurnActionResult, @@ -48,10 +49,6 @@ interface RevertTurnButtonProps { chatTurnId: string | null | undefined; } -function formatToolName(name: string): string { - return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); -} - // Empty-array sentinel so the per-turn ``selectAtom`` slice returns a // stable reference when the turn has no recorded actions yet. Without // this every render allocates a fresh ``[]`` and Jotai's @@ -218,7 +215,7 @@ function RevertResultRow({ result }: { result: RevertTurnActionResult }) { /> <div className="min-w-0 flex-1"> <p className="font-medium truncate"> - {formatToolName(result.tool_name)}{" "} + {getToolDisplayName(result.tool_name)}{" "} <span className="ml-1 text-xs text-muted-foreground"> {result.status.replace(/_/g, " ")} </span> diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index cf99598f1..e58783c87 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -82,6 +82,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import { CONNECTOR_ICON_TO_TYPES, CONNECTOR_TOOL_ICON_PATHS, + getToolDisplayName, getToolIcon, } from "@/contracts/enums/toolIcons"; import type { Document } from "@/contracts/types/document.types"; @@ -1317,12 +1318,14 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false ); }; -/** Convert snake_case tool names to human-readable labels */ +/** + * Friendly tool name for display in the chat UI. Delegates to the + * shared map in ``contracts/enums/toolIcons`` so unix-style identifiers + * (``rm``, ``ls``, ``grep`` …) and snake_cased function names render as + * plain English (e.g. "Delete file", "List files", "Search in files"). + */ function formatToolName(name: string): string { - return name - .split("_") - .map((word) => word.charAt(0).toUpperCase() + word.slice(1)) - .join(" "); + return getToolDisplayName(name); } interface ToolGroup { diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index 70eab9ffc..cc7582695 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -25,16 +25,12 @@ import { AlertDialogTrigger, } from "@/components/ui/alert-dialog"; import { Button } from "@/components/ui/button"; -import { getToolIcon } from "@/contracts/enums/toolIcons"; +import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons"; import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; import { AppError } from "@/lib/error"; import { isInterruptResult } from "@/lib/hitl"; import { cn } from "@/lib/utils"; -function formatToolName(name: string): string { - return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); -} - /** * Inline Revert button rendered on a tool card when the matching * ``AgentActionLog`` row is reversible and hasn't been reverted yet. @@ -104,9 +100,10 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { <AlertDialogHeader> <AlertDialogTitle>Revert this action?</AlertDialogTitle> <AlertDialogDescription> - This will undo <span className="font-medium">{formatToolName(action.toolName)}</span>{" "} - and append a new audit entry. Chat history is preserved — only the tool's effects on - your knowledge base or connectors will be reversed where possible. + This will undo{" "} + <span className="font-medium">{getToolDisplayName(action.toolName)}</span> and add a + new entry to the history. Your chat is preserved — only the changes the agent made to + your knowledge base or connected apps will be rolled back where possible. </AlertDialogDescription> </AlertDialogHeader> <AlertDialogFooter> @@ -164,7 +161,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ : null; const Icon = getToolIcon(toolName); - const displayName = formatToolName(toolName); + const displayName = getToolDisplayName(toolName); return ( <div @@ -215,7 +212,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ ? `Failed: ${displayName}` : displayName} </p> - {isRunning && <p className="text-xs text-muted-foreground mt-0.5">Running...</p>} + {isRunning && <p className="text-xs text-muted-foreground mt-0.5">Working…</p>} {cancelledReason && ( <p className="text-xs text-muted-foreground mt-0.5 truncate">{cancelledReason}</p> )} @@ -241,7 +238,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ <div className="px-5 py-3 space-y-3"> {argsText && ( <div> - <p className="text-xs font-medium text-muted-foreground mb-1">Arguments</p> + <p className="text-xs font-medium text-muted-foreground mb-1">Inputs</p> <pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all"> {argsText} </pre> diff --git a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx index ceb1d0209..a584084ff 100644 --- a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx +++ b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx @@ -8,6 +8,7 @@ import { TextShimmerLoader } from "@/components/prompt-kit/loader"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { Textarea } from "@/components/ui/textarea"; +import { getToolDisplayName } from "@/contracts/enums/toolIcons"; import { useHitlPhase } from "@/hooks/use-hitl-phase"; import { connectorsApiService } from "@/lib/apis/connectors-api.service"; import type { HitlDecision, InterruptResult } from "@/lib/hitl"; @@ -77,7 +78,7 @@ function GenericApprovalCard({ const [editedParams, setEditedParams] = useState<Record<string, unknown>>(args); const [isEditing, setIsEditing] = useState(false); - const displayName = toolName.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); + const displayName = getToolDisplayName(toolName); const mcpServer = interruptData.context?.mcp_server as string | undefined; const toolDescription = interruptData.context?.tool_description as string | undefined; @@ -186,12 +187,11 @@ function GenericApprovalCard({ </> )} - {/* Parameters */} {Object.keys(args).length > 0 && ( <> <div className="mx-5 h-px bg-border/50" /> <div className="px-5 py-4 space-y-2"> - <p className="text-xs font-medium text-muted-foreground">Parameters</p> + <p className="text-xs font-medium text-muted-foreground">Inputs</p> {phase === "pending" && isEditing ? ( <ParamEditor params={editedParams} diff --git a/surfsense_web/contracts/enums/toolIcons.tsx b/surfsense_web/contracts/enums/toolIcons.tsx index 1aab08096..bdb8222cb 100644 --- a/surfsense_web/contracts/enums/toolIcons.tsx +++ b/surfsense_web/contracts/enums/toolIcons.tsx @@ -113,6 +113,111 @@ export function getToolIcon(name: string): LucideIcon { return TOOL_ICONS[name] ?? Wrench; } +/** + * Friendly display names for tools shown in the chat UI. + * + * Most users aren't engineers; they shouldn't see raw unix-style + * identifiers like ``rm`` / ``rmdir`` / ``ls`` / ``grep`` / ``glob`` or + * snake_cased function names. The map below renders each tool with + * plain English wording (verb + object) so non-technical users + * understand what the agent is doing at a glance. + * + * Unmapped tool names fall back to a snake_case-to-Title-Case + * conversion via :func:`getToolDisplayName`. + */ +const TOOL_DISPLAY_NAMES: Record<string, string> = { + // Filesystem / knowledge base + read_file: "Read file", + write_file: "Write file", + edit_file: "Edit file", + move_file: "Move file", + rm: "Delete file", + rmdir: "Delete folder", + mkdir: "Create folder", + ls: "List files", + glob: "Find files", + grep: "Search in files", + write_todos: "Plan tasks", + save_document: "Save document", + // Generators + generate_podcast: "Generate podcast", + generate_video_presentation: "Generate video presentation", + generate_report: "Generate report", + generate_resume: "Generate resume", + generate_image: "Generate image", + display_image: "Show image", + // Web / search + scrape_webpage: "Read webpage", + web_search: "Search the web", + search_surfsense_docs: "Search knowledge base", + // Memory + update_memory: "Update memory", + // Calendar + search_calendar_events: "Search calendar", + create_calendar_event: "Create event", + update_calendar_event: "Update event", + delete_calendar_event: "Delete event", + // Gmail + search_gmail: "Search Gmail", + read_gmail_email: "Read email", + create_gmail_draft: "Draft email", + update_gmail_draft: "Update draft", + send_gmail_email: "Send email", + trash_gmail_email: "Move email to trash", + // Notion + create_notion_page: "Create Notion page", + update_notion_page: "Update Notion page", + delete_notion_page: "Delete Notion page", + // Confluence + create_confluence_page: "Create Confluence page", + update_confluence_page: "Update Confluence page", + delete_confluence_page: "Delete Confluence page", + // Linear + create_linear_issue: "Create Linear issue", + update_linear_issue: "Update Linear issue", + delete_linear_issue: "Delete Linear issue", + // Jira + create_jira_issue: "Create Jira issue", + update_jira_issue: "Update Jira issue", + delete_jira_issue: "Delete Jira issue", + // Drive-like file connectors + create_google_drive_file: "Create Google Drive file", + delete_google_drive_file: "Delete Google Drive file", + create_dropbox_file: "Create Dropbox file", + delete_dropbox_file: "Delete Dropbox file", + create_onedrive_file: "Create OneDrive file", + delete_onedrive_file: "Delete OneDrive file", + // Discord + list_discord_channels: "List Discord channels", + read_discord_messages: "Read Discord messages", + send_discord_message: "Send Discord message", + // Teams + list_teams_channels: "List Teams channels", + read_teams_messages: "Read Teams messages", + send_teams_message: "Send Teams message", + // Luma + list_luma_events: "List Luma events", + read_luma_event: "Read Luma event", + create_luma_event: "Create Luma event", + // Misc + get_connected_accounts: "Check connected accounts", + execute: "Run command", + execute_code: "Run code", +}; + +/** + * Format a tool's canonical (snake_case) name for display in the chat UI. + * + * Looks up :data:`TOOL_DISPLAY_NAMES` first; falls back to a + * snake_case-to-Title-Case rewrite for tools that don't have a curated + * label (e.g. dynamically registered MCP tools). + */ +export function getToolDisplayName(name: string): string { + const friendly = TOOL_DISPLAY_NAMES[name]; + if (friendly) return friendly; + return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); +} + export const CONNECTOR_TOOL_ICON_PATHS: Record<string, { src: string; alt: string }> = { gmail: { src: "/connectors/google-gmail.svg", alt: "Gmail" }, google_calendar: { src: "/connectors/google-calendar.svg", alt: "Google Calendar" }, From c598d7038f4f2766e37b9e9dc3e037b07fc1938b Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 20:17:45 +0530 Subject: [PATCH 232/299] refactor(chat): update premium token error messages for clarity and consistency --- .../app/tasks/chat/stream_new_chat.py | 4 ++-- .../new-chat/[[...chat_id]]/page.tsx | 16 ++++++---------- surfsense_web/components/assistant-ui/thread.tsx | 7 +++---- 3 files changed, 11 insertions(+), 16 deletions(-) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 1a56547ca..233b45396 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1542,7 +1542,7 @@ async def stream_new_chat( llm_config_id, ) yield streaming_service.format_error( - "Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin." + "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." ) yield streaming_service.format_done() return @@ -2263,7 +2263,7 @@ async def stream_resume_chat( llm_config_id, ) yield streaming_service.format_error( - "Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin." + "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." ) yield streaming_service.format_done() return diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index a5461e17f..05621419d 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -201,17 +201,16 @@ const BASE_TOOLS_WITH_UI = new Set([ // "write_todos", // Disabled for now ]); -const PINNED_PREMIUM_QUOTA_MESSAGE = "Premium token quota exceeded for this pinned model."; - function getPinnedPremiumQuotaErrorMessage(error: unknown): string | null { if (!(error instanceof Error)) return null; - if (!error.message.toLowerCase().includes("premium token quota exceeded")) { + const normalized = error.message.toLowerCase(); + if ( + !normalized.includes("premium tokens exhausted") + && !normalized.includes("premium token quota exceeded") + ) { return null; } - if (!error.message.toLowerCase().includes("pinned model")) { - return null; - } - return error.message || PINNED_PREMIUM_QUOTA_MESSAGE; + return error.message; } export default function NewChatPage() { @@ -980,7 +979,6 @@ export default function NewChatPage() { threadId: currentThreadId, message: premiumQuotaAlertMessage, }); - toast.error(PINNED_PREMIUM_QUOTA_MESSAGE); } else { toast.error("Failed to get response. Please try again."); } @@ -1290,7 +1288,6 @@ export default function NewChatPage() { threadId: resumeThreadId, message: premiumQuotaAlertMessage, }); - toast.error(PINNED_PREMIUM_QUOTA_MESSAGE); } else { toast.error("Failed to resume. Please try again."); } @@ -1638,7 +1635,6 @@ export default function NewChatPage() { threadId, message: premiumQuotaAlertMessage, }); - toast.error(PINNED_PREMIUM_QUOTA_MESSAGE); } else { toast.error("Failed to regenerate response. Please try again."); } diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 06f25f5fb..cb063fac3 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -161,16 +161,15 @@ const PremiumQuotaPinnedAlert: FC = () => { if (!alert) return null; return ( - <div className="mx-2 rounded-2xl border border-amber-300/40 bg-amber-500/10 px-4 py-3 text-amber-50 shadow-lg backdrop-blur-sm"> + <div className="mx-0 bg-amber-500/10 px-3 py-2 text-amber-100"> <div className="flex items-start gap-2"> <AlertCircle className="mt-0.5 size-4 shrink-0 text-amber-300" /> <div className="min-w-0 flex-1"> - <p className="text-sm font-medium">Premium quota exhausted</p> - <p className="mt-1 text-xs text-amber-100/90">{alert.message}</p> + <p className="text-sm">{alert.message}</p> </div> <button type="button" - className="inline-flex size-6 items-center justify-center rounded-md text-amber-200 transition-colors hover:bg-amber-200/20 hover:text-amber-50" + className="inline-flex size-6 items-center justify-center text-amber-200 transition-colors hover:text-amber-50" aria-label="Dismiss premium quota alert" onClick={() => clearPremiumAlertForThread(currentThreadId)} > From d66fa1559b3913648e195c379e60b03ff1f00baf Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 20:29:41 +0530 Subject: [PATCH 233/299] feat(chat): implement forced repin to free tier for pinned LLM configurations --- .../app/services/auto_model_pin_service.py | 17 +- .../app/tasks/chat/stream_new_chat.py | 209 ++++++++++++------ .../services/test_auto_model_pin_service.py | 38 ++++ 3 files changed, 200 insertions(+), 64 deletions(-) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index ce417a26d..6bdb60f57 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -84,6 +84,7 @@ async def resolve_or_get_pinned_llm_config_id( search_space_id: int, user_id: str | UUID | None, selected_llm_config_id: int, + force_repin_free: bool = False, ) -> AutoPinResolution: """Resolve Auto (Fastest) to one concrete config id and persist pin metadata. @@ -130,9 +131,12 @@ async def resolve_or_get_pinned_llm_config_id( raise ValueError("No usable global LLM configs are available for Auto mode") candidate_by_id = {int(c["id"]): c for c in candidates} - # Reuse existing valid pin without re-checking current quota (no silent tier switch). + # Reuse existing valid pin without re-checking current quota (no silent tier switch), + # unless the caller explicitly requests a forced repin to free. pinned_id = thread.pinned_llm_config_id if ( + not force_repin_free + and thread.pinned_auto_mode == AUTO_FASTEST_MODE and pinned_id is not None and int(pinned_id) in candidate_by_id @@ -159,7 +163,7 @@ async def resolve_or_get_pinned_llm_config_id( thread.pinned_auto_mode, ) - premium_eligible = await _is_premium_eligible(session, user_id) + premium_eligible = False if force_repin_free else await _is_premium_eligible(session, user_id) if premium_eligible: eligible = candidates else: @@ -179,6 +183,15 @@ async def resolve_or_get_pinned_llm_config_id( thread.pinned_at = datetime.now(UTC) await session.commit() + if force_repin_free: + logger.info( + "auto_pin_forced_free_repin thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s", + thread_id, + search_space_id, + pinned_id, + selected_id, + ) + if pinned_id is None: logger.info( "auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s", diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 233b45396..edc5aa763 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1455,6 +1455,37 @@ async def stream_new_chat( await set_ai_responding(session, chat_id, UUID(user_id)) # Load LLM config - supports both YAML (negative IDs) and database (positive IDs) agent_config: AgentConfig | None = None + requested_llm_config_id = llm_config_id + + async def _load_llm_bundle( + config_id: int, + ) -> tuple[Any, AgentConfig | None, str | None]: + if config_id >= 0: + loaded_agent_config = await load_agent_config( + session=session, + config_id=config_id, + search_space_id=search_space_id, + ) + if not loaded_agent_config: + return ( + None, + None, + f"Failed to load NewLLMConfig with id {config_id}", + ) + return ( + create_chat_litellm_from_agent_config(loaded_agent_config), + loaded_agent_config, + None, + ) + + loaded_llm_config = load_global_llm_config_by_id(config_id) + if not loaded_llm_config: + return None, None, f"Failed to load LLM config with id {config_id}" + return ( + create_chat_litellm_from_config(loaded_llm_config), + AgentConfig.from_yaml_config(loaded_llm_config), + None, + ) _t0 = time.perf_counter() try: @@ -1472,35 +1503,11 @@ async def stream_new_chat( yield streaming_service.format_done() return - if llm_config_id >= 0: - # Positive ID: Load from NewLLMConfig database table - agent_config = await load_agent_config( - session=session, - config_id=llm_config_id, - search_space_id=search_space_id, - ) - if not agent_config: - yield streaming_service.format_error( - f"Failed to load NewLLMConfig with id {llm_config_id}" - ) - yield streaming_service.format_done() - return - - # Create ChatLiteLLM from AgentConfig - llm = create_chat_litellm_from_agent_config(agent_config) - else: - # Negative ID: Load from in-memory global configs (includes dynamic OpenRouter models) - llm_config = load_global_llm_config_by_id(llm_config_id) - if not llm_config: - yield streaming_service.format_error( - f"Failed to load LLM config with id {llm_config_id}" - ) - yield streaming_service.format_done() - return - - # Create ChatLiteLLM from global config dict - llm = create_chat_litellm_from_config(llm_config) - agent_config = AgentConfig.from_yaml_config(llm_config) + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + yield streaming_service.format_error(llm_load_error) + yield streaming_service.format_done() + return _perf_log.info( "[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)", time.perf_counter() - _t0, @@ -1541,11 +1548,43 @@ async def stream_new_chat( user_id, llm_config_id, ) - yield streaming_service.format_error( - "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." - ) - yield streaming_service.format_done() - return + if requested_llm_config_id == 0: + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + force_repin_free=True, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield streaming_service.format_error(str(pin_error)) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + yield streaming_service.format_error(llm_load_error) + yield streaming_service.format_done() + return + _premium_request_id = None + _premium_reserved = 0 + logging.getLogger(__name__).info( + "premium_quota_auto_fallback_to_free thread_id=%s search_space_id=%s user_id=%s fallback_config_id=%s", + chat_id, + search_space_id, + user_id, + llm_config_id, + ) + else: + yield streaming_service.format_error( + "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." + ) + yield streaming_service.format_done() + return if not llm: yield streaming_service.format_error("Failed to create LLM instance") @@ -2183,6 +2222,38 @@ async def stream_resume_chat( await set_ai_responding(session, chat_id, UUID(user_id)) agent_config: AgentConfig | None = None + requested_llm_config_id = llm_config_id + + async def _load_llm_bundle( + config_id: int, + ) -> tuple[Any, AgentConfig | None, str | None]: + if config_id >= 0: + loaded_agent_config = await load_agent_config( + session=session, + config_id=config_id, + search_space_id=search_space_id, + ) + if not loaded_agent_config: + return ( + None, + None, + f"Failed to load NewLLMConfig with id {config_id}", + ) + return ( + create_chat_litellm_from_agent_config(loaded_agent_config), + loaded_agent_config, + None, + ) + + loaded_llm_config = load_global_llm_config_by_id(config_id) + if not loaded_llm_config: + return None, None, f"Failed to load LLM config with id {config_id}" + return ( + create_chat_litellm_from_config(loaded_llm_config), + AgentConfig.from_yaml_config(loaded_llm_config), + None, + ) + _t0 = time.perf_counter() try: llm_config_id = ( @@ -2199,29 +2270,11 @@ async def stream_resume_chat( yield streaming_service.format_done() return - if llm_config_id >= 0: - agent_config = await load_agent_config( - session=session, - config_id=llm_config_id, - search_space_id=search_space_id, - ) - if not agent_config: - yield streaming_service.format_error( - f"Failed to load NewLLMConfig with id {llm_config_id}" - ) - yield streaming_service.format_done() - return - llm = create_chat_litellm_from_agent_config(agent_config) - else: - llm_config = load_global_llm_config_by_id(llm_config_id) - if not llm_config: - yield streaming_service.format_error( - f"Failed to load LLM config with id {llm_config_id}" - ) - yield streaming_service.format_done() - return - llm = create_chat_litellm_from_config(llm_config) - agent_config = AgentConfig.from_yaml_config(llm_config) + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + yield streaming_service.format_error(llm_load_error) + yield streaming_service.format_done() + return _perf_log.info( "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 ) @@ -2262,11 +2315,43 @@ async def stream_resume_chat( user_id, llm_config_id, ) - yield streaming_service.format_error( - "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." - ) - yield streaming_service.format_done() - return + if requested_llm_config_id == 0: + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + force_repin_free=True, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield streaming_service.format_error(str(pin_error)) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + yield streaming_service.format_error(llm_load_error) + yield streaming_service.format_done() + return + _resume_premium_request_id = None + _resume_premium_reserved = 0 + logging.getLogger(__name__).info( + "premium_quota_auto_fallback_to_free thread_id=%s search_space_id=%s user_id=%s fallback_config_id=%s", + chat_id, + search_space_id, + user_id, + llm_config_id, + ) + else: + yield streaming_service.format_error( + "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." + ) + yield streaming_service.format_done() + return if not llm: yield streaming_service.format_error("Failed to create LLM instance") diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index a9853c980..f08e50ba2 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -227,6 +227,44 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): assert result.from_existing_pin is True +@pytest.mark.asyncio +async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + force_repin_free=True, + ) + assert result.resolved_llm_config_id == -2 + assert result.resolved_tier == "free" + assert result.from_existing_pin is False + assert session.thread.pinned_llm_config_id == -2 + + @pytest.mark.asyncio async def test_explicit_user_model_change_clears_pin(monkeypatch): from app.config import config From a68889511569caa045dec9420353b5e8a9a16647 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Wed, 29 Apr 2026 08:03:39 -0700 Subject: [PATCH 234/299] feat: increase recursion limit for chat streaming to enhance tool iteration capabilities --- .../app/tasks/chat/stream_new_chat.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index f7bf75649..1493c4326 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -2090,7 +2090,16 @@ async def stream_new_chat( config = { "configurable": configurable, - "recursion_limit": 80, # Increase from default 25 to allow more tool iterations + # Effectively uncapped, matching the agent-level + # ``with_config`` default in ``chat_deepagent.create_agent`` + # and the unbounded ``while(true)`` loop used by OpenCode's + # ``session/processor.ts``. Real circuit-breakers live in + # middleware: ``DoomLoopMiddleware`` (sliding-window tool + # signature check), plus ``enable_tool_call_limit`` / + # ``enable_model_call_limit`` when those flags are set. The + # original LangGraph default of 25 (and our previous 80 + # bump) hit users on legitimate multi-tool plans. + "recursion_limit": 10_000, } # Start the message stream @@ -2686,7 +2695,11 @@ async def stream_resume_chat( "request_id": request_id or "unknown", "turn_id": stream_result.turn_id, }, - "recursion_limit": 80, + # See ``stream_new_chat`` above for rationale: effectively + # uncapped to mirror the agent default and OpenCode's + # session loop. Doom-loop / call-limit middleware enforce + # the real ceiling. + "recursion_limit": 10_000, } yield streaming_service.format_message_start() From fa6a09197ef51641a649d14b14dd68cc131fddbd Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 20:57:33 +0530 Subject: [PATCH 235/299] feat(chat): enhance error handling for premium quota exhaustion in chat messages --- .../new-chat/[[...chat_id]]/page.tsx | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 05621419d..ed0611ee9 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -201,6 +201,9 @@ const BASE_TOOLS_WITH_UI = new Set([ // "write_todos", // Disabled for now ]); +const PREMIUM_QUOTA_ASSISTANT_MESSAGE = + "I can’t continue with the current premium model because your premium tokens are exhausted. Switch to a free model or buy more tokens to continue."; + function getPinnedPremiumQuotaErrorMessage(error: unknown): string | null { if (!(error instanceof Error)) return null; const normalized = error.message.toLowerCase(); @@ -992,7 +995,9 @@ export default function NewChatPage() { { type: "text", text: - premiumQuotaAlertMessage ?? + (premiumQuotaAlertMessage + ? PREMIUM_QUOTA_ASSISTANT_MESSAGE + : undefined) ?? "Sorry, there was an error. Please try again.", }, ], @@ -1288,6 +1293,16 @@ export default function NewChatPage() { threadId: resumeThreadId, message: premiumQuotaAlertMessage, }); + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { + ...m, + content: [{ type: "text", text: PREMIUM_QUOTA_ASSISTANT_MESSAGE }], + } + : m + ) + ); } else { toast.error("Failed to resume. Please try again."); } @@ -1647,7 +1662,9 @@ export default function NewChatPage() { { type: "text", text: - premiumQuotaAlertMessage ?? + (premiumQuotaAlertMessage + ? PREMIUM_QUOTA_ASSISTANT_MESSAGE + : undefined) ?? "Sorry, there was an error. Please try again.", }, ], From 901de3368402d7545ea2572e617c063b357429a2 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 21:05:21 +0530 Subject: [PATCH 236/299] feat(chat): enhance error formatting to include optional error codes for better frontend handling --- .../app/services/new_streaming_service.py | 10 +- .../app/tasks/chat/stream_new_chat.py | 6 +- .../new-chat/[[...chat_id]]/page.tsx | 137 +++++++++++------- surfsense_web/lib/chat/streaming-state.ts | 2 +- 4 files changed, 97 insertions(+), 58 deletions(-) diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 52a215997..3e24c1376 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -565,20 +565,24 @@ class VercelStreamingService: # Error Part # ========================================================================= - def format_error(self, error_text: str) -> str: + def format_error(self, error_text: str, error_code: str | None = None) -> str: """ Format an error message. Args: error_text: The error message text + error_code: Optional machine-readable error code for frontend branching Returns: str: SSE formatted error part Example output: - data: {"type":"error","errorText":"Something went wrong"} + data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"} """ - return self._format_sse({"type": "error", "errorText": error_text}) + payload: dict[str, str] = {"type": "error", "errorText": error_text} + if error_code: + payload["errorCode"] = error_code + return self._format_sse(payload) # ========================================================================= # Tool Parts diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index edc5aa763..060dd23c6 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1581,7 +1581,8 @@ async def stream_new_chat( ) else: yield streaming_service.format_error( - "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." + "Buy more tokens to continue with this model, or switch to a free model.", + error_code="PREMIUM_QUOTA_EXHAUSTED", ) yield streaming_service.format_done() return @@ -2348,7 +2349,8 @@ async def stream_resume_chat( ) else: yield streaming_service.format_error( - "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." + "Buy more tokens to continue with this model, or switch to a free model.", + error_code="PREMIUM_QUOTA_EXHAUSTED", ) yield streaming_service.format_done() return diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index ed0611ee9..f775e1f06 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -206,10 +206,15 @@ const PREMIUM_QUOTA_ASSISTANT_MESSAGE = function getPinnedPremiumQuotaErrorMessage(error: unknown): string | null { if (!(error instanceof Error)) return null; + const withCode = error as Error & { errorCode?: string }; + if (withCode.errorCode === "PREMIUM_QUOTA_EXHAUSTED") { + return error.message; + } const normalized = error.message.toLowerCase(); if ( !normalized.includes("premium tokens exhausted") && !normalized.includes("premium token quota exceeded") + && !normalized.includes("buy more tokens") ) { return null; } @@ -233,6 +238,50 @@ export default function NewChatPage() { } | null>(null); const toolsWithUI = useMemo(() => new Set([...BASE_TOOLS_WITH_UI]), []); + const persistAssistantErrorMessage = useCallback( + async ({ + threadId, + assistantMsgId, + text, + }: { + threadId: number | null; + assistantMsgId: string; + text: string; + }) => { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { + ...m, + content: [{ type: "text", text }], + } + : m + ) + ); + + if (!threadId) return; + + // Persist only temporary assistant placeholders to avoid duplicate rows + // when the message already has a database-backed ID. + if (!assistantMsgId.startsWith("msg-assistant-")) return; + + try { + const savedMessage = await appendMessage(threadId, { + role: "assistant", + content: [{ type: "text", text }], + }); + const newMsgId = `msg-${savedMessage.id}`; + tokenUsageStore.rename(assistantMsgId, newMsgId); + setMessages((prev) => + prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + ); + } catch (persistErr) { + console.error("Failed to persist assistant error message:", persistErr); + } + }, + [tokenUsageStore] + ); + // Get disabled tools from the tool toggle UI const disabledTools = useAtomValue(disabledToolsAtom); @@ -903,7 +952,9 @@ export default function NewChatPage() { break; case "error": - throw new Error(parsed.errorText || "Server error"); + throw Object.assign(new Error(parsed.errorText || "Server error"), { + errorCode: parsed.errorCode, + }); } } @@ -985,26 +1036,14 @@ export default function NewChatPage() { } else { toast.error("Failed to get response. Please try again."); } - // Update assistant message with error - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? { - ...m, - content: [ - { - type: "text", - text: - (premiumQuotaAlertMessage - ? PREMIUM_QUOTA_ASSISTANT_MESSAGE - : undefined) ?? - "Sorry, there was an error. Please try again.", - }, - ], - } - : m - ) - ); + await persistAssistantErrorMessage({ + threadId: currentThreadId, + assistantMsgId, + text: + (premiumQuotaAlertMessage + ? PREMIUM_QUOTA_ASSISTANT_MESSAGE + : undefined) ?? "Sorry, there was an error. Please try again.", + }); } finally { setIsRunning(false); abortControllerRef.current = null; @@ -1028,6 +1067,7 @@ export default function NewChatPage() { setPendingUserImageUrls, toolsWithUI, setPremiumAlertForThread, + persistAssistantErrorMessage, ] ); @@ -1258,7 +1298,9 @@ export default function NewChatPage() { break; case "error": - throw new Error(parsed.errorText || "Server error"); + throw Object.assign(new Error(parsed.errorText || "Server error"), { + errorCode: parsed.errorCode, + }); } } @@ -1293,19 +1335,17 @@ export default function NewChatPage() { threadId: resumeThreadId, message: premiumQuotaAlertMessage, }); - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? { - ...m, - content: [{ type: "text", text: PREMIUM_QUOTA_ASSISTANT_MESSAGE }], - } - : m - ) - ); } else { toast.error("Failed to resume. Please try again."); } + await persistAssistantErrorMessage({ + threadId: resumeThreadId, + assistantMsgId, + text: + (premiumQuotaAlertMessage + ? PREMIUM_QUOTA_ASSISTANT_MESSAGE + : undefined) ?? "Sorry, there was an error. Please try again.", + }); } finally { setIsRunning(false); abortControllerRef.current = null; @@ -1318,6 +1358,7 @@ export default function NewChatPage() { tokenUsageStore, toolsWithUI, setPremiumAlertForThread, + persistAssistantErrorMessage, ] ); @@ -1589,7 +1630,9 @@ export default function NewChatPage() { break; case "error": - throw new Error(parsed.errorText || "Server error"); + throw Object.assign(new Error(parsed.errorText || "Server error"), { + errorCode: parsed.errorCode, + }); } } @@ -1653,25 +1696,14 @@ export default function NewChatPage() { } else { toast.error("Failed to regenerate response. Please try again."); } - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? { - ...m, - content: [ - { - type: "text", - text: - (premiumQuotaAlertMessage - ? PREMIUM_QUOTA_ASSISTANT_MESSAGE - : undefined) ?? - "Sorry, there was an error. Please try again.", - }, - ], - } - : m - ) - ); + await persistAssistantErrorMessage({ + threadId, + assistantMsgId, + text: + (premiumQuotaAlertMessage + ? PREMIUM_QUOTA_ASSISTANT_MESSAGE + : undefined) ?? "Sorry, there was an error. Please try again.", + }); } finally { setIsRunning(false); abortControllerRef.current = null; @@ -1685,6 +1717,7 @@ export default function NewChatPage() { tokenUsageStore, toolsWithUI, setPremiumAlertForThread, + persistAssistantErrorMessage, ] ); diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index ff8fdfbd4..9f2ac87a5 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -256,7 +256,7 @@ export type SSEEvent = }>; }; } - | { type: "error"; errorText: string }; + | { type: "error"; errorText: string; errorCode?: string }; /** * Async generator that reads an SSE stream and yields parsed JSON objects. From e6db050dfd6ae9d0bdd63597d261caf7151d8720 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 21:58:17 +0530 Subject: [PATCH 237/299] feat(chat): add userId to premium alert handling and improve alert visibility in UI --- surfsense_backend/app/tasks/chat/stream_new_chat.py | 4 ++-- .../new-chat/[[...chat_id]]/page.tsx | 3 +++ surfsense_web/atoms/chat/premium-alert.atom.ts | 12 ++++++++++++ surfsense_web/components/assistant-ui/thread.tsx | 8 ++++---- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 060dd23c6..ecc727b47 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1581,7 +1581,7 @@ async def stream_new_chat( ) else: yield streaming_service.format_error( - "Buy more tokens to continue with this model, or switch to a free model.", + "Buy more tokens to continue with this model, or switch to a free model", error_code="PREMIUM_QUOTA_EXHAUSTED", ) yield streaming_service.format_done() @@ -2349,7 +2349,7 @@ async def stream_resume_chat( ) else: yield streaming_service.format_error( - "Buy more tokens to continue with this model, or switch to a free model.", + "Buy more tokens to continue with this model, or switch to a free model", error_code="PREMIUM_QUOTA_EXHAUSTED", ) yield streaming_service.format_done() diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index f775e1f06..6ec587f91 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -1032,6 +1032,7 @@ export default function NewChatPage() { setPremiumAlertForThread({ threadId: currentThreadId, message: premiumQuotaAlertMessage, + userId: currentUser?.id ?? null, }); } else { toast.error("Failed to get response. Please try again."); @@ -1334,6 +1335,7 @@ export default function NewChatPage() { setPremiumAlertForThread({ threadId: resumeThreadId, message: premiumQuotaAlertMessage, + userId: currentUser?.id ?? null, }); } else { toast.error("Failed to resume. Please try again."); @@ -1692,6 +1694,7 @@ export default function NewChatPage() { setPremiumAlertForThread({ threadId, message: premiumQuotaAlertMessage, + userId: currentUser?.id ?? null, }); } else { toast.error("Failed to regenerate response. Please try again."); diff --git a/surfsense_web/atoms/chat/premium-alert.atom.ts b/surfsense_web/atoms/chat/premium-alert.atom.ts index c0efc174f..1c837dd65 100644 --- a/surfsense_web/atoms/chat/premium-alert.atom.ts +++ b/surfsense_web/atoms/chat/premium-alert.atom.ts @@ -14,13 +14,25 @@ export const setPremiumAlertForThreadAtom = atom( payload: { threadId: number; message: string; + userId?: string | null; } ) => { + const storageKey = `surfsense-premium-alert-seen-v1:${payload.userId ?? "anonymous"}`; + + if (typeof window !== "undefined") { + const hasSeen = localStorage.getItem(storageKey) === "true"; + if (hasSeen) return; + } + const current = get(premiumAlertByThreadAtom); set(premiumAlertByThreadAtom, { ...current, [payload.threadId]: { message: payload.message }, }); + + if (typeof window !== "undefined") { + localStorage.setItem(storageKey, "true"); + } } ); diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index cb063fac3..3095556dc 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -161,15 +161,15 @@ const PremiumQuotaPinnedAlert: FC = () => { if (!alert) return null; return ( - <div className="mx-0 bg-amber-500/10 px-3 py-2 text-amber-100"> - <div className="flex items-start gap-2"> - <AlertCircle className="mt-0.5 size-4 shrink-0 text-amber-300" /> + <div className="mx-0 overflow-hidden rounded-2xl border-input bg-muted px-4 py-4 text-foreground select-none"> + <div className="flex items-center gap-2"> + <AlertCircle className="size-4 shrink-0 text-muted-foreground" /> <div className="min-w-0 flex-1"> <p className="text-sm">{alert.message}</p> </div> <button type="button" - className="inline-flex size-6 items-center justify-center text-amber-200 transition-colors hover:text-amber-50" + className="inline-flex size-6 items-center justify-center text-muted-foreground transition-colors hover:text-foreground" aria-label="Dismiss premium quota alert" onClick={() => clearPremiumAlertForThread(currentThreadId)} > From 222b27183fd9603637df2c31459dc74cc988ade9 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 22:01:28 +0530 Subject: [PATCH 238/299] feat(chat): improve error handling and logging for premium quota exhaustion in chat operations --- .../new-chat/[[...chat_id]]/page.tsx | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 6ec587f91..a2985ab0c 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -1018,8 +1018,12 @@ export default function NewChatPage() { } return; } - console.error("[NewChatPage] Chat error:", error); const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); + if (premiumQuotaAlertMessage) { + console.info("[NewChatPage] Premium quota exhausted:", error); + } else { + console.error("[NewChatPage] Chat error:", error); + } // Track chat error trackChatError( @@ -1329,8 +1333,12 @@ export default function NewChatPage() { if (error instanceof Error && error.name === "AbortError") { return; } - console.error("[NewChatPage] Resume error:", error); const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); + if (premiumQuotaAlertMessage) { + console.info("[NewChatPage] Premium quota exhausted during resume:", error); + } else { + console.error("[NewChatPage] Resume error:", error); + } if (premiumQuotaAlertMessage) { setPremiumAlertForThread({ threadId: resumeThreadId, @@ -1357,6 +1365,7 @@ export default function NewChatPage() { pendingInterrupt, messages, searchSpaceId, + currentUser?.id, tokenUsageStore, toolsWithUI, setPremiumAlertForThread, @@ -1683,8 +1692,12 @@ export default function NewChatPage() { return; } batcher.dispose(); - console.error("[NewChatPage] Regeneration error:", error); const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); + if (premiumQuotaAlertMessage) { + console.info("[NewChatPage] Premium quota exhausted during regeneration:", error); + } else { + console.error("[NewChatPage] Regeneration error:", error); + } trackChatError( searchSpaceId, threadId, @@ -1717,6 +1730,7 @@ export default function NewChatPage() { searchSpaceId, messages, disabledTools, + currentUser?.id, tokenUsageStore, toolsWithUI, setPremiumAlertForThread, From d64543686fe6304f99eac9e62bbb86944895840f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 11:56:41 +0530 Subject: [PATCH 239/299] feat(chat): unify error handling and logging for chat operations, enhancing clarity and consistency in error reporting --- .../app/routes/new_chat_routes.py | 1 + .../app/tasks/chat/stream_new_chat.py | 319 +++++++++++++++--- .../unit/test_stream_new_chat_contract.py | 119 +++++++ .../new-chat/[[...chat_id]]/page.tsx | 240 ++++++------- .../lib/chat/chat-error-classifier.ts | 273 +++++++++++++++ surfsense_web/lib/posthog/events.ts | 50 +++ 6 files changed, 831 insertions(+), 171 deletions(-) create mode 100644 surfsense_web/lib/chat/chat-error-classifier.ts diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index b5560d90d..0189dd139 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1524,6 +1524,7 @@ async def regenerate_response( filesystem_selection=filesystem_selection, request_id=getattr(http_request.state, "request_id", "unknown"), user_image_data_urls=regenerate_image_urls or None, + flow="regenerate", ): yield chunk streaming_completed = True diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index ecc727b47..a0be55c1b 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -19,7 +19,7 @@ import re import time from collections.abc import AsyncGenerator from dataclasses import dataclass, field -from typing import Any +from typing import Any, Literal from uuid import UUID import anyio @@ -253,6 +253,98 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None: ) +def _log_chat_stream_error( + *, + flow: Literal["new", "resume", "regenerate"], + error_kind: str, + error_code: str | None, + severity: Literal["info", "warn", "error"], + is_expected: bool, + request_id: str | None, + thread_id: int | None, + search_space_id: int | None, + user_id: str | None, + message: str, + extra: dict[str, Any] | None = None, +) -> None: + payload: dict[str, Any] = { + "event": "chat_stream_error", + "flow": flow, + "error_kind": error_kind, + "error_code": error_code, + "severity": severity, + "is_expected": is_expected, + "request_id": request_id or "unknown", + "thread_id": thread_id, + "search_space_id": search_space_id, + "user_id": user_id, + "message": message, + } + if extra: + payload.update(extra) + + logger = logging.getLogger(__name__) + rendered = json.dumps(payload, ensure_ascii=False) + if severity == "error": + logger.error("[chat_stream_error] %s", rendered) + elif severity == "warn": + logger.warning("[chat_stream_error] %s", rendered) + else: + logger.info("[chat_stream_error] %s", rendered) + + +def _parse_error_payload(message: str) -> dict[str, Any] | None: + candidates = [message] + first_brace_idx = message.find("{") + if first_brace_idx >= 0: + candidates.append(message[first_brace_idx:]) + + for candidate in candidates: + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + return parsed + except Exception: + continue + return None + + +def _classify_stream_exception( + exc: Exception, + *, + flow_label: str, +) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]: + raw = str(exc) + parsed = _parse_error_payload(raw) + provider_error_type = "" + if parsed: + top_type = parsed.get("type") + if isinstance(top_type, str): + provider_error_type = top_type.lower() + nested = parsed.get("error") + if isinstance(nested, dict): + nested_type = nested.get("type") + if isinstance(nested_type, str): + provider_error_type = nested_type.lower() + + if provider_error_type == "rate_limit_error": + return ( + "rate_limited", + "RATE_LIMITED", + "warn", + True, + "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + ) + + return ( + "server_error", + "SERVER_ERROR", + "error", + False, + f"Error during {flow_label}: {raw}", + ) + + async def _stream_agent_events( agent: Any, config: dict[str, Any], @@ -1397,6 +1489,7 @@ async def stream_new_chat( filesystem_selection: FilesystemSelection | None = None, request_id: str | None = None, user_image_data_urls: list[str] | None = None, + flow: Literal["new", "regenerate"] = "new", ) -> AsyncGenerator[str, None]: """ Stream chat responses from the new SurfSense deep agent. @@ -1448,6 +1541,30 @@ async def stream_new_chat( _premium_reserved = 0 _premium_request_id: str | None = None + def _emit_stream_error( + *, + message: str, + error_kind: str = "server_error", + error_code: str = "SERVER_ERROR", + severity: Literal["info", "warn", "error"] = "error", + is_expected: bool = False, + extra: dict[str, Any] | None = None, + ) -> str: + _log_chat_stream_error( + flow=flow, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=message, + extra=extra, + ) + return streaming_service.format_error(message, error_code=error_code) + session = async_session_maker() try: # Mark AI as responding to this user for live collaboration @@ -1499,13 +1616,21 @@ async def stream_new_chat( ) ).resolved_llm_config_id except ValueError as pin_error: - yield streaming_service.format_error(str(pin_error)) + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) if llm_load_error: - yield streaming_service.format_error(llm_load_error) + yield _emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return _perf_log.info( @@ -1541,13 +1666,6 @@ async def stream_new_chat( ) _premium_reserved = reserve_amount if not quota_result.allowed: - logging.getLogger(__name__).info( - "premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s", - chat_id, - search_space_id, - user_id, - llm_config_id, - ) if requested_llm_config_id == 0: try: llm_config_id = ( @@ -1561,34 +1679,66 @@ async def stream_new_chat( ) ).resolved_llm_config_id except ValueError as pin_error: - yield streaming_service.format_error(str(pin_error)) + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) if llm_load_error: - yield streaming_service.format_error(llm_load_error) + yield _emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return _premium_request_id = None _premium_reserved = 0 - logging.getLogger(__name__).info( - "premium_quota_auto_fallback_to_free thread_id=%s search_space_id=%s user_id=%s fallback_config_id=%s", - chat_id, - search_space_id, - user_id, - llm_config_id, + _log_chat_stream_error( + flow=flow, + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Premium quota exhausted on pinned model; auto-fallback switched to a free model" + ), + extra={ + "fallback_config_id": llm_config_id, + "auto_fallback": True, + }, ) else: - yield streaming_service.format_error( - "Buy more tokens to continue with this model, or switch to a free model", + yield _emit_stream_error( + message=( + "Buy more tokens to continue with this model, or switch to a free model" + ), + error_kind="premium_quota_exhausted", error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + extra={ + "resolved_config_id": llm_config_id, + "auto_fallback": False, + }, ) yield streaming_service.format_done() return if not llm: - yield streaming_service.format_error("Failed to create LLM instance") + yield _emit_stream_error( + message="Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return @@ -2097,12 +2247,25 @@ async def stream_new_chat( # Handle any errors import traceback + ( + error_kind, + error_code, + severity, + is_expected, + user_message, + ) = _classify_stream_exception(e, flow_label="chat") error_message = f"Error during chat: {e!s}" print(f"[stream_new_chat] {error_message}") print(f"[stream_new_chat] Exception type: {type(e).__name__}") print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}") - yield streaming_service.format_error(error_message) + yield _emit_stream_error( + message=user_message, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + ) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -2217,6 +2380,30 @@ async def stream_resume_chat( accumulator = start_turn() + def _emit_stream_error( + *, + message: str, + error_kind: str = "server_error", + error_code: str = "SERVER_ERROR", + severity: Literal["info", "warn", "error"] = "error", + is_expected: bool = False, + extra: dict[str, Any] | None = None, + ) -> str: + _log_chat_stream_error( + flow="resume", + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=message, + extra=extra, + ) + return streaming_service.format_error(message, error_code=error_code) + session = async_session_maker() try: if user_id: @@ -2267,13 +2454,21 @@ async def stream_resume_chat( ) ).resolved_llm_config_id except ValueError as pin_error: - yield streaming_service.format_error(str(pin_error)) + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) if llm_load_error: - yield streaming_service.format_error(llm_load_error) + yield _emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return _perf_log.info( @@ -2309,13 +2504,6 @@ async def stream_resume_chat( ) _resume_premium_reserved = reserve_amount if not quota_result.allowed: - logging.getLogger(__name__).info( - "premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s", - chat_id, - search_space_id, - user_id, - llm_config_id, - ) if requested_llm_config_id == 0: try: llm_config_id = ( @@ -2329,34 +2517,66 @@ async def stream_resume_chat( ) ).resolved_llm_config_id except ValueError as pin_error: - yield streaming_service.format_error(str(pin_error)) + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) if llm_load_error: - yield streaming_service.format_error(llm_load_error) + yield _emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return _resume_premium_request_id = None _resume_premium_reserved = 0 - logging.getLogger(__name__).info( - "premium_quota_auto_fallback_to_free thread_id=%s search_space_id=%s user_id=%s fallback_config_id=%s", - chat_id, - search_space_id, - user_id, - llm_config_id, + _log_chat_stream_error( + flow="resume", + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Premium quota exhausted on pinned model; auto-fallback switched to a free model" + ), + extra={ + "fallback_config_id": llm_config_id, + "auto_fallback": True, + }, ) else: - yield streaming_service.format_error( - "Buy more tokens to continue with this model, or switch to a free model", + yield _emit_stream_error( + message=( + "Buy more tokens to continue with this model, or switch to a free model" + ), + error_kind="premium_quota_exhausted", error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + extra={ + "resolved_config_id": llm_config_id, + "auto_fallback": False, + }, ) yield streaming_service.format_done() return if not llm: - yield streaming_service.format_error("Failed to create LLM instance") + yield _emit_stream_error( + message="Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return @@ -2528,10 +2748,23 @@ async def stream_resume_chat( except Exception as e: import traceback + ( + error_kind, + error_code, + severity, + is_expected, + user_message, + ) = _classify_stream_exception(e, flow_label="resume") error_message = f"Error during resume: {e!s}" print(f"[stream_resume_chat] {error_message}") print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}") - yield streaming_service.format_error(error_message) + yield _emit_stream_error( + message=user_message, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + ) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 034aa484c..1f8168837 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -1,9 +1,18 @@ +import inspect +import json +import logging +from pathlib import Path +import re + import pytest +import app.tasks.chat.stream_new_chat as stream_new_chat_module from app.tasks.chat.stream_new_chat import ( StreamResult, + _classify_stream_exception, _contract_enforcement_active, _evaluate_file_contract_outcome, + _log_chat_stream_error, _tool_output_has_error, ) @@ -45,3 +54,113 @@ def test_contract_enforcement_local_only(): result.filesystem_mode = "cloud" assert not _contract_enforcement_active(result) + + +def _extract_chat_stream_payload(record_message: str) -> dict: + prefix = "[chat_stream_error] " + assert record_message.startswith(prefix) + return json.loads(record_message[len(prefix) :]) + + +def test_unified_chat_stream_error_log_schema(caplog): + with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"): + _log_chat_stream_error( + flow="new", + error_kind="server_error", + error_code="SERVER_ERROR", + severity="warn", + is_expected=False, + request_id="req-123", + thread_id=101, + search_space_id=202, + user_id="user-1", + message="Error during chat: boom", + ) + + record = next(r for r in caplog.records if "[chat_stream_error]" in r.message) + payload = _extract_chat_stream_payload(record.message) + + required_keys = { + "event", + "flow", + "error_kind", + "error_code", + "severity", + "is_expected", + "request_id", + "thread_id", + "search_space_id", + "user_id", + "message", + } + assert required_keys.issubset(payload.keys()) + assert payload["event"] == "chat_stream_error" + assert payload["flow"] == "new" + assert payload["error_code"] == "SERVER_ERROR" + + +def test_premium_quota_uses_unified_chat_stream_log_shape(caplog): + with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"): + _log_chat_stream_error( + flow="resume", + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + request_id="req-premium", + thread_id=303, + search_space_id=404, + user_id="user-2", + message="Buy more tokens to continue with this model, or switch to a free model", + extra={"auto_fallback": False}, + ) + + record = next(r for r in caplog.records if "[chat_stream_error]" in r.message) + payload = _extract_chat_stream_payload(record.message) + assert payload["event"] == "chat_stream_error" + assert payload["error_kind"] == "premium_quota_exhausted" + assert payload["error_code"] == "PREMIUM_QUOTA_EXHAUSTED" + assert payload["flow"] == "resume" + assert payload["is_expected"] is True + assert payload["auto_fallback"] is False + + +def test_stream_error_emission_keeps_machine_error_codes(): + source = inspect.getsource(stream_new_chat_module) + format_error_calls = re.findall(r"format_error\(", source) + emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source)) + + # Both new/resume stream paths now route through local emitters that always + # pass a machine-readable error_code. + assert len(format_error_calls) == 2 + assert { + "PREMIUM_QUOTA_EXHAUSTED", + "SERVER_ERROR", + }.issubset(emitted_error_codes) + assert 'flow: Literal["new", "regenerate"] = "new"' in source + assert "flow=flow" in source + assert 'flow="resume"' in source + + +def test_stream_exception_classifies_rate_limited(): + exc = Exception( + '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}' + ) + kind, code, severity, is_expected, user_message = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "rate_limited" + assert code == "RATE_LIMITED" + assert severity == "warn" + assert is_expected is True + assert "temporarily rate-limited" in user_message + + +def test_premium_classification_is_error_code_driven(): + classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" + source = classifier_path.read_text(encoding="utf-8") + + assert "PREMIUM_KEYWORDS" not in source + assert "RATE_LIMIT_KEYWORDS" not in source + assert "normalized.includes(" not in source + assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index a2985ab0c..ffd58e660 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -49,6 +49,10 @@ import { useMessagesSync } from "@/hooks/use-messages-sync"; import { getAgentFilesystemSelection } from "@/lib/agent-filesystem"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { getBearerToken } from "@/lib/auth-utils"; +import { + classifyChatError, + type ChatFlow, +} from "@/lib/chat/chat-error-classifier"; import { convertToThreadMessage } from "@/lib/chat/message-utils"; import { isPodcastGenerating, @@ -84,7 +88,8 @@ import { import { NotFoundError } from "@/lib/error"; import { trackChatCreated, - trackChatError, + trackChatBlocked, + trackChatErrorDetailed, trackChatMessageSent, trackChatResponseReceived, } from "@/lib/posthog/events"; @@ -201,26 +206,6 @@ const BASE_TOOLS_WITH_UI = new Set([ // "write_todos", // Disabled for now ]); -const PREMIUM_QUOTA_ASSISTANT_MESSAGE = - "I can’t continue with the current premium model because your premium tokens are exhausted. Switch to a free model or buy more tokens to continue."; - -function getPinnedPremiumQuotaErrorMessage(error: unknown): string | null { - if (!(error instanceof Error)) return null; - const withCode = error as Error & { errorCode?: string }; - if (withCode.errorCode === "PREMIUM_QUOTA_EXHAUSTED") { - return error.message; - } - const normalized = error.message.toLowerCase(); - if ( - !normalized.includes("premium tokens exhausted") - && !normalized.includes("premium token quota exceeded") - && !normalized.includes("buy more tokens") - ) { - return null; - } - return error.message; -} - export default function NewChatPage() { const params = useParams(); const queryClient = useQueryClient(); @@ -378,6 +363,81 @@ export default function NewChatPage() { return Number.isNaN(parsed) ? 0 : parsed; }, [params.chat_id]); + const handleChatFailure = useCallback( + async ({ + error, + flow, + threadId, + assistantMsgId, + }: { + error: unknown; + flow: ChatFlow; + threadId: number | null; + assistantMsgId: string; + }) => { + const normalized = classifyChatError({ + error, + flow, + context: { + searchSpaceId, + threadId, + }, + }); + + const logger = + normalized.severity === "error" + ? console.error + : normalized.severity === "warn" + ? console.warn + : console.info; + logger(`[NewChatPage] ${flow} ${normalized.kind}:`, error); + + const telemetryPayload = { + flow, + kind: normalized.kind, + error_code: normalized.errorCode, + severity: normalized.severity, + is_expected: normalized.isExpected, + message: normalized.userMessage, + }; + if (normalized.telemetryEvent === "chat_blocked") { + trackChatBlocked(searchSpaceId, threadId, telemetryPayload); + } else { + trackChatErrorDetailed(searchSpaceId, threadId, telemetryPayload); + } + + if (normalized.channel === "silent") { + return; + } + + if (normalized.channel === "pinned_inline") { + if (threadId) { + setPremiumAlertForThread({ + threadId, + message: normalized.userMessage, + userId: currentUser?.id ?? null, + }); + } + if (normalized.assistantMessage) { + await persistAssistantErrorMessage({ + threadId, + assistantMsgId, + text: normalized.assistantMessage, + }); + } + return; + } + + toast.error(normalized.userMessage); + }, + [ + currentUser?.id, + persistAssistantErrorMessage, + searchSpaceId, + setPremiumAlertForThread, + ] + ); + // Initialize thread and load messages // For new chats (no urlChatId), we use lazy creation - thread is created on first message const initializeThread = useCallback(async () => { @@ -1018,36 +1078,11 @@ export default function NewChatPage() { } return; } - const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); - if (premiumQuotaAlertMessage) { - console.info("[NewChatPage] Premium quota exhausted:", error); - } else { - console.error("[NewChatPage] Chat error:", error); - } - - // Track chat error - trackChatError( - searchSpaceId, - currentThreadId, - error instanceof Error ? error.message : "Unknown error" - ); - - if (premiumQuotaAlertMessage) { - setPremiumAlertForThread({ - threadId: currentThreadId, - message: premiumQuotaAlertMessage, - userId: currentUser?.id ?? null, - }); - } else { - toast.error("Failed to get response. Please try again."); - } - await persistAssistantErrorMessage({ + await handleChatFailure({ + error, + flow: "new", threadId: currentThreadId, assistantMsgId, - text: - (premiumQuotaAlertMessage - ? PREMIUM_QUOTA_ASSISTANT_MESSAGE - : undefined) ?? "Sorry, there was an error. Please try again.", }); } finally { setIsRunning(false); @@ -1071,8 +1106,7 @@ export default function NewChatPage() { pendingUserImageUrls, setPendingUserImageUrls, toolsWithUI, - setPremiumAlertForThread, - persistAssistantErrorMessage, + handleChatFailure, ] ); @@ -1333,28 +1367,11 @@ export default function NewChatPage() { if (error instanceof Error && error.name === "AbortError") { return; } - const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); - if (premiumQuotaAlertMessage) { - console.info("[NewChatPage] Premium quota exhausted during resume:", error); - } else { - console.error("[NewChatPage] Resume error:", error); - } - if (premiumQuotaAlertMessage) { - setPremiumAlertForThread({ - threadId: resumeThreadId, - message: premiumQuotaAlertMessage, - userId: currentUser?.id ?? null, - }); - } else { - toast.error("Failed to resume. Please try again."); - } - await persistAssistantErrorMessage({ + await handleChatFailure({ + error, + flow: "resume", threadId: resumeThreadId, assistantMsgId, - text: - (premiumQuotaAlertMessage - ? PREMIUM_QUOTA_ASSISTANT_MESSAGE - : undefined) ?? "Sorry, there was an error. Please try again.", }); } finally { setIsRunning(false); @@ -1365,11 +1382,9 @@ export default function NewChatPage() { pendingInterrupt, messages, searchSpaceId, - currentUser?.id, tokenUsageStore, toolsWithUI, - setPremiumAlertForThread, - persistAssistantErrorMessage, + handleChatFailure, ] ); @@ -1491,15 +1506,6 @@ export default function NewChatPage() { userQueryToDisplay = newUserQuery; } - // Remove the last two messages (user + assistant) from the UI immediately - // The backend will also delete them from the database - setMessages((prev) => { - if (prev.length >= 2) { - return prev.slice(0, -2); - } - return prev; - }); - // Start streaming setIsRunning(true); const controller = new AbortController(); @@ -1530,19 +1536,9 @@ export default function NewChatPage() { createdAt: new Date(), metadata: isEdit ? undefined : originalUserMessageMetadata, }; - setMessages((prev) => [...prev, userMessage]); - - // Add placeholder assistant message - setMessages((prev) => [ - ...prev, - { - id: assistantMsgId, - role: "assistant", - content: [{ type: "text", text: "" }], - createdAt: new Date(), - }, - ]); - + const userContentToPersist = isEdit + ? (editExtras?.userMessageContent ?? [{ type: "text", text: newUserQuery ?? "" }]) + : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }]; try { const selection = await getAgentFilesystemSelection(searchSpaceId); const requestBody: Record<string, unknown> = { @@ -1570,6 +1566,22 @@ export default function NewChatPage() { throw new Error(`Backend error: ${response.status}`); } + // Only switch UI to regenerated placeholder messages after the backend accepts + // regenerate. This avoids local message loss when regenerate fails early (e.g. 400). + setMessages((prev) => { + const base = prev.length >= 2 ? prev.slice(0, -2) : prev; + return [ + ...base, + userMessage, + { + id: assistantMsgId, + role: "assistant", + content: [{ type: "text", text: "" }], + createdAt: new Date(), + }, + ]; + }); + const flushMessages = () => { setMessages((prev) => prev.map((m) => @@ -1654,10 +1666,6 @@ export default function NewChatPage() { if (contentParts.length > 0) { try { // Persist user message (for both edit and reload modes, since backend deleted it) - const userContentToPersist = isEdit - ? (editExtras?.userMessageContent ?? [{ type: "text", text: newUserQuery ?? "" }]) - : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }]; - const savedUserMessage = await appendMessage(threadId, { role: "user", content: userContentToPersist, @@ -1692,33 +1700,11 @@ export default function NewChatPage() { return; } batcher.dispose(); - const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); - if (premiumQuotaAlertMessage) { - console.info("[NewChatPage] Premium quota exhausted during regeneration:", error); - } else { - console.error("[NewChatPage] Regeneration error:", error); - } - trackChatError( - searchSpaceId, - threadId, - error instanceof Error ? error.message : "Unknown error" - ); - if (premiumQuotaAlertMessage) { - setPremiumAlertForThread({ - threadId, - message: premiumQuotaAlertMessage, - userId: currentUser?.id ?? null, - }); - } else { - toast.error("Failed to regenerate response. Please try again."); - } - await persistAssistantErrorMessage({ + await handleChatFailure({ + error, + flow: "regenerate", threadId, assistantMsgId, - text: - (premiumQuotaAlertMessage - ? PREMIUM_QUOTA_ASSISTANT_MESSAGE - : undefined) ?? "Sorry, there was an error. Please try again.", }); } finally { setIsRunning(false); @@ -1730,11 +1716,9 @@ export default function NewChatPage() { searchSpaceId, messages, disabledTools, - currentUser?.id, tokenUsageStore, toolsWithUI, - setPremiumAlertForThread, - persistAssistantErrorMessage, + handleChatFailure, ] ); diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts new file mode 100644 index 000000000..dc9bb09df --- /dev/null +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -0,0 +1,273 @@ +export type ChatFlow = "new" | "resume" | "regenerate"; + +export type ChatErrorKind = + | "premium_quota_exhausted" + | "auth_expired" + | "rate_limited" + | "network_offline" + | "stream_interrupted" + | "stream_parse_error" + | "tool_execution_error" + | "persist_message_failed" + | "server_error" + | "unknown"; + +export type ChatErrorChannel = "pinned_inline" | "toast" | "silent"; +export type ChatTelemetryEvent = "chat_blocked" | "chat_error"; +export type ChatErrorSeverity = "info" | "warn" | "error"; + +export interface NormalizedChatError { + kind: ChatErrorKind; + channel: ChatErrorChannel; + severity: ChatErrorSeverity; + telemetryEvent: ChatTelemetryEvent; + isExpected: boolean; + userMessage: string; + assistantMessage?: string; + rawMessage?: string; + errorCode?: string; + details?: Record<string, unknown>; +} + +export interface RawChatErrorInput { + error: unknown; + flow: ChatFlow; + context?: { + searchSpaceId?: number; + threadId?: number | null; + }; +} + +export const PREMIUM_QUOTA_ASSISTANT_MESSAGE = + "I can’t continue with the current premium model because your premium tokens are exhausted. Switch to a free model or buy more tokens to continue."; + +function getErrorMessage(error: unknown): string { + if (error instanceof Error) return error.message; + if (typeof error === "string") return error; + try { + return JSON.stringify(error); + } catch { + return "Unknown error"; + } +} + +function getErrorCode(error: unknown, parsedJson: Record<string, unknown> | null): string | undefined { + if (error instanceof Error) { + const withCode = error as Error & { errorCode?: string }; + if (withCode.errorCode) return withCode.errorCode; + } + + if (typeof error === "object" && error !== null) { + const withCode = error as { errorCode?: unknown }; + if (typeof withCode.errorCode === "string" && withCode.errorCode) { + return withCode.errorCode; + } + } + + if (parsedJson) { + const topLevelCode = parsedJson.errorCode; + if (typeof topLevelCode === "string" && topLevelCode) { + return topLevelCode; + } + } + + return undefined; +} + +function parseEmbeddedJson(text: string): Record<string, unknown> | null { + const candidates = [text]; + const firstBraceIdx = text.indexOf("{"); + if (firstBraceIdx >= 0) { + candidates.push(text.slice(firstBraceIdx)); + } + for (const candidate of candidates) { + try { + const parsed = JSON.parse(candidate); + if (typeof parsed === "object" && parsed !== null) { + return parsed as Record<string, unknown>; + } + } catch { + // noop + } + } + return null; +} + +function inferProviderErrorType(parsedJson: Record<string, unknown> | null): string | undefined { + if (!parsedJson) return undefined; + const topLevelType = parsedJson.type; + if (typeof topLevelType === "string" && topLevelType) return topLevelType; + const nestedError = parsedJson.error; + if (typeof nestedError === "object" && nestedError !== null) { + const nestedType = (nestedError as Record<string, unknown>).type; + if (typeof nestedType === "string" && nestedType) return nestedType; + } + return undefined; +} + +export function classifyChatError(input: RawChatErrorInput): NormalizedChatError { + const { error } = input; + const rawMessage = getErrorMessage(error); + const parsedJson = parseEmbeddedJson(rawMessage); + const errorCode = getErrorCode(error, parsedJson); + const providerErrorType = inferProviderErrorType(parsedJson); + const providerTypeNormalized = providerErrorType?.toLowerCase() ?? ""; + const errorName = error instanceof Error ? error.name : undefined; + + if (errorName === "AbortError") { + return { + kind: "stream_interrupted", + channel: "silent", + severity: "info", + telemetryEvent: "chat_error", + isExpected: true, + userMessage: "Request canceled.", + rawMessage, + errorCode, + details: { flow: input.flow }, + }; + } + + if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") { + return { + kind: "premium_quota_exhausted", + channel: "pinned_inline", + severity: "info", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: + "Buy more tokens to continue with this model, or switch to a free model.", + assistantMessage: PREMIUM_QUOTA_ASSISTANT_MESSAGE, + rawMessage, + errorCode: errorCode ?? "PREMIUM_QUOTA_EXHAUSTED", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "AUTH_EXPIRED" || + errorCode === "UNAUTHORIZED" + ) { + return { + kind: "auth_expired", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_error", + isExpected: true, + userMessage: "Your session expired. Please sign in again.", + rawMessage, + errorCode: errorCode ?? "AUTH_EXPIRED", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "RATE_LIMITED" || + providerTypeNormalized === "rate_limit_error" + ) { + return { + kind: "rate_limited", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: + "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + rawMessage, + errorCode: errorCode ?? "RATE_LIMITED", + details: { flow: input.flow, providerErrorType }, + }; + } + + if ( + errorCode === "NETWORK_ERROR" + ) { + return { + kind: "network_offline", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_error", + isExpected: true, + userMessage: "Connection issue detected. Check your internet and try again.", + rawMessage, + errorCode: errorCode ?? "NETWORK_ERROR", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "STREAM_PARSE_ERROR" + ) { + return { + kind: "stream_parse_error", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "We hit a response formatting issue. Please try again.", + rawMessage, + errorCode: errorCode ?? "STREAM_PARSE_ERROR", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "TOOL_EXECUTION_ERROR" + ) { + return { + kind: "tool_execution_error", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "A tool failed while processing your request. Please try again.", + rawMessage, + errorCode: errorCode ?? "TOOL_EXECUTION_ERROR", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "PERSIST_MESSAGE_FAILED" + ) { + return { + kind: "persist_message_failed", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "Response generated, but saving failed. Please retry once.", + rawMessage, + errorCode: errorCode ?? "PERSIST_MESSAGE_FAILED", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "SERVER_ERROR" + ) { + return { + kind: "server_error", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "We couldn’t complete this response right now. Please try again.", + rawMessage, + errorCode: errorCode ?? "SERVER_ERROR", + details: { flow: input.flow, providerErrorType }, + }; + } + + return { + kind: "unknown", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "We couldn’t complete this response right now. Please try again.", + rawMessage, + errorCode, + details: { flow: input.flow, providerErrorType }, + }; +} diff --git a/surfsense_web/lib/posthog/events.ts b/surfsense_web/lib/posthog/events.ts index 34ed3044d..30e58215a 100644 --- a/surfsense_web/lib/posthog/events.ts +++ b/surfsense_web/lib/posthog/events.ts @@ -1,5 +1,6 @@ import posthog from "posthog-js"; import { getConnectorTelemetryMeta } from "@/components/assistant-ui/connector-popup/constants/connector-constants"; +import type { ChatErrorKind, ChatFlow, ChatErrorSeverity } from "@/lib/chat/chat-error-classifier"; /** * PostHog Analytics Event Definitions @@ -139,6 +140,55 @@ export function trackChatError(searchSpaceId: number, chatId: number, error?: st }); } +export interface ChatFailureTelemetry { + flow: ChatFlow; + kind: ChatErrorKind; + error_code?: string; + severity: ChatErrorSeverity; + is_expected: boolean; + message?: string; +} + +export function trackChatBlocked( + searchSpaceId: number, + chatId: number | null, + payload: ChatFailureTelemetry +) { + safeCapture( + "chat_blocked", + compact({ + search_space_id: searchSpaceId, + chat_id: chatId ?? undefined, + flow: payload.flow, + kind: payload.kind, + error_code: payload.error_code, + severity: payload.severity, + is_expected: payload.is_expected, + message: payload.message, + }) + ); +} + +export function trackChatErrorDetailed( + searchSpaceId: number, + chatId: number | null, + payload: ChatFailureTelemetry +) { + safeCapture( + "chat_error", + compact({ + search_space_id: searchSpaceId, + chat_id: chatId ?? undefined, + flow: payload.flow, + kind: payload.kind, + error_code: payload.error_code, + severity: payload.severity, + is_expected: payload.is_expected, + message: payload.message, + }) + ); +} + /** * Track a message sent from the unauthenticated "free" / anonymous chat * flow. This is intentionally a separate event from `chat_message_sent` From fd4d0817d14939f0c2c9421dabc6b83213d7a17f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 12:38:11 +0530 Subject: [PATCH 240/299] feat(chat): implement comprehensive error handling for chat operations, including detailed response parsing and improved user message persistence --- .../new-chat/[[...chat_id]]/page.tsx | 262 ++++++++++++------ 1 file changed, 180 insertions(+), 82 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index ffd58e660..b6afaf131 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -222,6 +222,7 @@ export default function NewChatPage() { interruptData: Record<string, unknown>; } | null>(null); const toolsWithUI = useMemo(() => new Set([...BASE_TOOLS_WITH_UI]), []); + const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom); const persistAssistantErrorMessage = useCallback( async ({ @@ -267,14 +268,107 @@ export default function NewChatPage() { [tokenUsageStore] ); + const persistUserTurn = useCallback( + async ({ + threadId, + userMsgId, + content, + mentionedDocs, + logContext, + }: { + threadId: number | null; + userMsgId: string; + content: unknown; + mentionedDocs?: MentionedDocumentInfo[]; + logContext: string; + }) => { + if (!threadId) return null; + try { + const normalizedContent = Array.isArray(content) + ? ([...content] as unknown[]) + : [content]; + const hasMentionedDocumentsPart = normalizedContent.some((part) => + MentionedDocumentsPartSchema.safeParse(part).success + ); + if (mentionedDocs && mentionedDocs.length > 0 && !hasMentionedDocumentsPart) { + normalizedContent.push({ + type: "mentioned-documents", + documents: mentionedDocs, + }); + } + + const savedUserMessage = await appendMessage(threadId, { + role: "user", + content: normalizedContent as AppendMessage["content"], + }); + const newUserMsgId = `msg-${savedUserMessage.id}`; + setMessages((prev) => + prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m)) + ); + if (mentionedDocs && mentionedDocs.length > 0) { + setMessageDocumentsMap((prev) => { + const { [userMsgId]: _, ...rest } = prev; + return { + ...rest, + [newUserMsgId]: mentionedDocs, + }; + }); + } + return newUserMsgId; + } catch (err) { + console.error(`Failed to persist ${logContext} user message:`, err); + return null; + } + }, + [setMessageDocumentsMap] + ); + + const persistAssistantTurn = useCallback( + async ({ + threadId, + assistantMsgId, + content, + tokenUsage, + logContext, + onRemapped, + }: { + threadId: number | null; + assistantMsgId: string; + content: unknown; + tokenUsage?: Record<string, unknown>; + logContext: string; + onRemapped?: (newMsgId: string) => void; + }) => { + if (!threadId) return null; + try { + const savedMessage = await appendMessage(threadId, { + role: "assistant", + content: content as AppendMessage["content"], + token_usage: tokenUsage, + }); + const newMsgId = `msg-${savedMessage.id}`; + tokenUsageStore.rename(assistantMsgId, newMsgId); + setMessages((prev) => + prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + ); + onRemapped?.(newMsgId); + return newMsgId; + } catch (err) { + console.error(`Failed to persist ${logContext} assistant message:`, err); + return null; + } + }, + [tokenUsageStore] + ); + // Get disabled tools from the tool toggle UI const disabledTools = useAtomValue(disabledToolsAtom); // Get mentioned document IDs from the composer. const mentionedDocumentIds = useAtomValue(mentionedDocumentIdsAtom); const mentionedDocuments = useAtomValue(mentionedDocumentsAtom); + const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); const setMentionedDocuments = useSetAtom(mentionedDocumentsAtom); - const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom); const setCurrentThreadState = useSetAtom(currentThreadAtom); const setPremiumAlertForThread = useSetAtom(setPremiumAlertForThreadAtom); const setTargetCommentId = useSetAtom(setTargetCommentIdAtom); @@ -1023,29 +1117,20 @@ export default function NewChatPage() { // Skip persistence for interrupted messages -- handleResume will persist the final version const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); if (contentParts.length > 0 && !wasInterrupted) { - try { - const savedMessage = await appendMessage(currentThreadId, { - role: "assistant", - content: finalContent, - token_usage: tokenUsageData ?? undefined, - }); - - // Update message ID from temporary to database ID so comments work immediately - const newMsgId = `msg-${savedMessage.id}`; - tokenUsageStore.rename(assistantMsgId, newMsgId); - setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) - ); - - // Update pending interrupt with the new persisted message ID - setPendingInterrupt((prev) => - prev && prev.assistantMsgId === assistantMsgId - ? { ...prev, assistantMsgId: newMsgId } - : prev - ); - } catch (err) { - console.error("Failed to persist assistant message:", err); - } + await persistAssistantTurn({ + threadId: currentThreadId, + assistantMsgId, + content: finalContent, + tokenUsage: tokenUsageData ?? undefined, + logContext: "new chat", + onRemapped: (newMsgId) => { + setPendingInterrupt((prev) => + prev && prev.assistantMsgId === assistantMsgId + ? { ...prev, assistantMsgId: newMsgId } + : prev + ); + }, + }); // Track successful response trackChatResponseReceived(searchSpaceId, currentThreadId); @@ -1061,20 +1146,12 @@ export default function NewChatPage() { ); if (hasContent && currentThreadId) { const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); - try { - const savedMessage = await appendMessage(currentThreadId, { - role: "assistant", - content: partialContent, - }); - - // Update message ID from temporary to database ID - const newMsgId = `msg-${savedMessage.id}`; - setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) - ); - } catch (err) { - console.error("Failed to persist partial assistant message:", err); - } + await persistAssistantTurn({ + threadId: currentThreadId, + assistantMsgId, + content: partialContent, + logContext: "partial new chat", + }); } return; } @@ -1107,6 +1184,7 @@ export default function NewChatPage() { setPendingUserImageUrls, toolsWithUI, handleChatFailure, + persistAssistantTurn, ] ); @@ -1347,20 +1425,13 @@ export default function NewChatPage() { const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); if (contentParts.length > 0) { - try { - const savedMessage = await appendMessage(resumeThreadId, { - role: "assistant", - content: finalContent, - token_usage: tokenUsageData ?? undefined, - }); - const newMsgId = `msg-${savedMessage.id}`; - tokenUsageStore.rename(assistantMsgId, newMsgId); - setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) - ); - } catch (err) { - console.error("Failed to persist resumed assistant message:", err); - } + await persistAssistantTurn({ + threadId: resumeThreadId, + assistantMsgId, + content: finalContent, + tokenUsage: tokenUsageData ?? undefined, + logContext: "resumed chat", + }); } } catch (error) { batcher.dispose(); @@ -1385,6 +1456,7 @@ export default function NewChatPage() { tokenUsageStore, toolsWithUI, handleChatFailure, + persistAssistantTurn, ] ); @@ -1462,6 +1534,7 @@ export default function NewChatPage() { editExtras?: { userMessageContent: ThreadMessageLike["content"]; userImages: NewChatUserImagePayload[]; + sourceUserMessageId?: string; } ) => { if (!threadId) { @@ -1487,11 +1560,13 @@ export default function NewChatPage() { let userQueryToDisplay: string | undefined; let originalUserMessageContent: ThreadMessageLike["content"] | null = null; let originalUserMessageMetadata: ThreadMessageLike["metadata"] | undefined; + let sourceUserMessageId: string | undefined = editExtras?.sourceUserMessageId; if (!isEdit) { // Reload mode - find and preserve the last user message content const lastUserMessage = [...messages].reverse().find((m) => m.role === "user"); if (lastUserMessage) { + sourceUserMessageId = lastUserMessage.id; originalUserMessageContent = lastUserMessage.content; originalUserMessageMetadata = lastUserMessage.metadata; // Extract text for the API request @@ -1524,6 +1599,8 @@ export default function NewChatPage() { const { contentParts, toolCallIndices } = contentPartsState; const batcher = new FrameBatchedUpdater(); let tokenUsageData: Record<string, unknown> | null = null; + let regenerateAccepted = false; + let userPersisted = false; // Add placeholder messages to UI // Always add back the user message (with new query for edit, or original content for reload) @@ -1539,6 +1616,10 @@ export default function NewChatPage() { const userContentToPersist = isEdit ? (editExtras?.userMessageContent ?? [{ type: "text", text: newUserQuery ?? "" }]) : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }]; + const sourceMentionedDocs = + sourceUserMessageId && messageDocumentsMap[sourceUserMessageId] + ? messageDocumentsMap[sourceUserMessageId] + : []; try { const selection = await getAgentFilesystemSelection(searchSpaceId); const requestBody: Record<string, unknown> = { @@ -1565,6 +1646,7 @@ export default function NewChatPage() { if (!response.ok) { throw new Error(`Backend error: ${response.status}`); } + regenerateAccepted = true; // Only switch UI to regenerated placeholder messages after the backend accepts // regenerate. This avoids local message loss when regenerate fails early (e.g. 400). @@ -1581,6 +1663,12 @@ export default function NewChatPage() { }, ]; }); + if (sourceMentionedDocs.length > 0) { + setMessageDocumentsMap((prev) => ({ + ...prev, + [userMsgId]: sourceMentionedDocs, + })); + } const flushMessages = () => { setMessages((prev) => @@ -1664,47 +1752,45 @@ export default function NewChatPage() { // Persist messages after streaming completes const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); if (contentParts.length > 0) { - try { - // Persist user message (for both edit and reload modes, since backend deleted it) - const savedUserMessage = await appendMessage(threadId, { - role: "user", - content: userContentToPersist, - }); + const persistedUserMsgId = await persistUserTurn({ + threadId, + userMsgId, + content: userContentToPersist, + mentionedDocs: sourceMentionedDocs, + logContext: "regenerated", + }); + userPersisted = Boolean(persistedUserMsgId); - // Update user message ID to database ID - const newUserMsgId = `msg-${savedUserMessage.id}`; - setMessages((prev) => - prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m)) - ); + await persistAssistantTurn({ + threadId, + assistantMsgId, + content: finalContent, + tokenUsage: tokenUsageData ?? undefined, + logContext: "regenerated", + }); - // Persist assistant message - const savedMessage = await appendMessage(threadId, { - role: "assistant", - content: finalContent, - token_usage: tokenUsageData ?? undefined, - }); - - const newMsgId = `msg-${savedMessage.id}`; - tokenUsageStore.rename(assistantMsgId, newMsgId); - setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) - ); - - trackChatResponseReceived(searchSpaceId, threadId); - } catch (err) { - console.error("Failed to persist regenerated message:", err); - } + trackChatResponseReceived(searchSpaceId, threadId); } } catch (error) { if (error instanceof Error && error.name === "AbortError") { return; } batcher.dispose(); + if (regenerateAccepted && !userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId, + userMsgId, + content: userContentToPersist, + mentionedDocs: sourceMentionedDocs, + logContext: "regenerated (stream error)", + }); + userPersisted = Boolean(persistedUserMsgId); + } await handleChatFailure({ error, flow: "regenerate", threadId, - assistantMsgId, + assistantMsgId: regenerateAccepted ? assistantMsgId : "no-persist-assistant", }); } finally { setIsRunning(false); @@ -1716,9 +1802,13 @@ export default function NewChatPage() { searchSpaceId, messages, disabledTools, + messageDocumentsMap, + setMessageDocumentsMap, tokenUsageStore, toolsWithUI, handleChatFailure, + persistAssistantTurn, + persistUserTurn, ] ); @@ -1733,7 +1823,15 @@ export default function NewChatPage() { } const userMessageContent = message.content as unknown as ThreadMessageLike["content"]; - await handleRegenerate(queryForApi, { userMessageContent, userImages }); + const sourceUserMessageId = + typeof (message as { id?: unknown }).id === "string" + ? ((message as { id?: string }).id ?? undefined) + : undefined; + await handleRegenerate(queryForApi, { + userMessageContent, + userImages, + sourceUserMessageId, + }); }, [handleRegenerate] ); From 35ea0eae53a24875e368111f294b366f48f2d9fa Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:03:09 +0530 Subject: [PATCH 241/299] feat(chat): enhance error classification and handling for thread busy scenarios, improving user feedback and response management --- .../app/tasks/chat/stream_new_chat.py | 106 +++++---- .../unit/test_stream_new_chat_contract.py | 33 ++- .../new-chat/[[...chat_id]]/page.tsx | 209 +++++++++++++----- .../components/free-chat/anonymous-chat.tsx | 16 +- .../components/free-chat/free-chat-page.tsx | 52 ++++- .../lib/chat/chat-error-classifier.ts | 17 ++ 6 files changed, 322 insertions(+), 111 deletions(-) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index a0be55c1b..d6ca5418c 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -19,6 +19,7 @@ import re import time from collections.abc import AsyncGenerator from dataclasses import dataclass, field +from functools import partial from typing import Any, Literal from uuid import UUID @@ -30,6 +31,7 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer +from app.agents.new_chat.errors import BusyError from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import ( AgentConfig, @@ -315,6 +317,15 @@ def _classify_stream_exception( flow_label: str, ) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]: raw = str(exc) + if isinstance(exc, BusyError) or "Thread is busy with another request" in raw: + return ( + "thread_busy", + "THREAD_BUSY", + "warn", + True, + "Another response is still finishing for this thread. Please try again in a moment.", + ) + parsed = _parse_error_payload(raw) provider_error_type = "" if parsed: @@ -345,6 +356,37 @@ def _classify_stream_exception( ) +def _emit_stream_terminal_error( + *, + streaming_service: VercelStreamingService, + flow: str, + request_id: str | None, + thread_id: int, + search_space_id: int, + user_id: str | None, + message: str, + error_kind: str = "server_error", + error_code: str = "SERVER_ERROR", + severity: Literal["info", "warn", "error"] = "error", + is_expected: bool = False, + extra: dict[str, Any] | None = None, +) -> str: + _log_chat_stream_error( + flow=flow, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + request_id=request_id, + thread_id=thread_id, + search_space_id=search_space_id, + user_id=user_id, + message=message, + extra=extra, + ) + return streaming_service.format_error(message, error_code=error_code) + + async def _stream_agent_events( agent: Any, config: dict[str, Any], @@ -1541,29 +1583,15 @@ async def stream_new_chat( _premium_reserved = 0 _premium_request_id: str | None = None - def _emit_stream_error( - *, - message: str, - error_kind: str = "server_error", - error_code: str = "SERVER_ERROR", - severity: Literal["info", "warn", "error"] = "error", - is_expected: bool = False, - extra: dict[str, Any] | None = None, - ) -> str: - _log_chat_stream_error( - flow=flow, - error_kind=error_kind, - error_code=error_code, - severity=severity, - is_expected=is_expected, - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - message=message, - extra=extra, - ) - return streaming_service.format_error(message, error_code=error_code) + _emit_stream_error = partial( + _emit_stream_terminal_error, + streaming_service=streaming_service, + flow=flow, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + ) session = async_session_maker() try: @@ -2380,29 +2408,15 @@ async def stream_resume_chat( accumulator = start_turn() - def _emit_stream_error( - *, - message: str, - error_kind: str = "server_error", - error_code: str = "SERVER_ERROR", - severity: Literal["info", "warn", "error"] = "error", - is_expected: bool = False, - extra: dict[str, Any] | None = None, - ) -> str: - _log_chat_stream_error( - flow="resume", - error_kind=error_kind, - error_code=error_code, - severity=severity, - is_expected=is_expected, - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - message=message, - extra=extra, - ) - return streaming_service.format_error(message, error_code=error_code) + _emit_stream_error = partial( + _emit_stream_terminal_error, + streaming_service=streaming_service, + flow="resume", + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + ) session = async_session_maker() try: diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 1f8168837..125177084 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -1,12 +1,13 @@ import inspect import json import logging -from pathlib import Path import re +from pathlib import Path import pytest import app.tasks.chat.stream_new_chat as stream_new_chat_module +from app.agents.new_chat.errors import BusyError from app.tasks.chat.stream_new_chat import ( StreamResult, _classify_stream_exception, @@ -130,14 +131,14 @@ def test_stream_error_emission_keeps_machine_error_codes(): format_error_calls = re.findall(r"format_error\(", source) emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source)) - # Both new/resume stream paths now route through local emitters that always - # pass a machine-readable error_code. - assert len(format_error_calls) == 2 + # All stream paths should route through one shared terminal error emitter. + assert len(format_error_calls) == 1 assert { "PREMIUM_QUOTA_EXHAUSTED", "SERVER_ERROR", }.issubset(emitted_error_codes) assert 'flow: Literal["new", "regenerate"] = "new"' in source + assert "_emit_stream_terminal_error" in source assert "flow=flow" in source assert 'flow="resume"' in source @@ -156,6 +157,30 @@ def test_stream_exception_classifies_rate_limited(): assert "temporarily rate-limited" in user_message +def test_stream_exception_classifies_thread_busy(): + exc = BusyError(request_id="thread-123") + kind, code, severity, is_expected, user_message = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "thread_busy" + assert code == "THREAD_BUSY" + assert severity == "warn" + assert is_expected is True + assert "still finishing for this thread" in user_message + + +def test_stream_exception_classifies_thread_busy_from_message(): + exc = Exception("Thread is busy with another request") + kind, code, severity, is_expected, user_message = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "thread_busy" + assert code == "THREAD_BUSY" + assert severity == "warn" + assert is_expected is True + assert "still finishing for this thread" in user_message + + def test_premium_classification_is_error_code_driven(): classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" source = classifier_path.read_text(encoding="utf-8") diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index b6afaf131..70e188612 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -67,6 +67,7 @@ import { type ContentPartsState, FrameBatchedUpdater, readSSEStream, + type SSEEvent, type ThinkingStepData, updateThinkingSteps, updateToolCall, @@ -136,6 +137,75 @@ function markInterruptsCompleted(contentParts: Array<{ type: string; result?: un } } +function toStreamTerminalError( + event: Extract<SSEEvent, { type: "error" }> +): Error & { errorCode?: string } { + return Object.assign(new Error(event.errorText || "Server error"), { + errorCode: event.errorCode, + }); +} + +async function toHttpResponseError(response: Response): Promise<Error & { errorCode?: string }> { + const statusDefaultCode = + response.status === 409 + ? "THREAD_BUSY" + : response.status === 429 + ? "RATE_LIMITED" + : response.status === 401 || response.status === 403 + ? "AUTH_EXPIRED" + : "SERVER_ERROR"; + + let rawBody = ""; + try { + rawBody = await response.text(); + } catch { + // noop + } + + let parsedBody: Record<string, unknown> | null = null; + if (rawBody) { + try { + const parsed = JSON.parse(rawBody); + if (typeof parsed === "object" && parsed !== null) { + parsedBody = parsed as Record<string, unknown>; + } + } catch { + // noop + } + } + + const detail = parsedBody?.detail; + const detailObject = + typeof detail === "object" && detail !== null ? (detail as Record<string, unknown>) : null; + const detailMessage = typeof detail === "string" ? detail : undefined; + const topLevelMessage = + typeof parsedBody?.message === "string" ? (parsedBody.message as string) : undefined; + const detailNestedMessage = + typeof detailObject?.message === "string" ? (detailObject.message as string) : undefined; + + const topLevelCode = + typeof parsedBody?.errorCode === "string" + ? parsedBody.errorCode + : typeof parsedBody?.error_code === "string" + ? parsedBody.error_code + : undefined; + const detailCode = + typeof detailObject?.errorCode === "string" + ? detailObject.errorCode + : typeof detailObject?.error_code === "string" + ? detailObject.error_code + : undefined; + + const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode; + const message = + detailNestedMessage ?? + detailMessage ?? + topLevelMessage ?? + `Backend error: ${response.status}`; + + return Object.assign(new Error(message), { errorCode }); +} + /** * Zod schema for mentioned document info (for type-safe parsing) */ @@ -532,6 +602,43 @@ export default function NewChatPage() { ] ); + const handleStreamTerminalError = useCallback( + async ({ + error, + flow, + threadId, + assistantMsgId, + accepted, + onAbort, + onAcceptedStreamError, + }: { + error: unknown; + flow: ChatFlow; + threadId: number | null; + assistantMsgId: string; + accepted: boolean; + onAbort?: () => Promise<void>; + onAcceptedStreamError?: () => Promise<void>; + }) => { + if (error instanceof Error && error.name === "AbortError") { + await onAbort?.(); + return; + } + + if (accepted) { + await onAcceptedStreamError?.(); + } + + await handleChatFailure({ + error, + flow, + threadId, + assistantMsgId: accepted ? assistantMsgId : "no-persist-assistant", + }); + }, + [handleChatFailure] + ); + // Initialize thread and load messages // For new chats (no urlChatId), we use lazy creation - thread is created on first message const initializeThread = useCallback(async () => { @@ -880,6 +987,7 @@ export default function NewChatPage() { const { contentParts, toolCallIndices } = contentPartsState; let wasInterrupted = false; let tokenUsageData: Record<string, unknown> | null = null; + let newAccepted = false; // Add placeholder assistant message setMessages((prev) => [ @@ -951,8 +1059,9 @@ export default function NewChatPage() { }); if (!response.ok) { - throw new Error(`Backend error: ${response.status}`); + throw await toHttpResponseError(response); } + newAccepted = true; const flushMessages = () => { setMessages((prev) => @@ -1106,9 +1215,7 @@ export default function NewChatPage() { break; case "error": - throw Object.assign(new Error(parsed.errorText || "Server error"), { - errorCode: parsed.errorCode, - }); + throw toStreamTerminalError(parsed); } } @@ -1137,29 +1244,29 @@ export default function NewChatPage() { } } catch (error) { batcher.dispose(); - if (error instanceof Error && error.name === "AbortError") { - // Request was cancelled by user - persist partial response if any content was received - const hasContent = contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "tool-call" && toolsWithUI.has(part.toolName)) - ); - if (hasContent && currentThreadId) { - const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); - await persistAssistantTurn({ - threadId: currentThreadId, - assistantMsgId, - content: partialContent, - logContext: "partial new chat", - }); - } - return; - } - await handleChatFailure({ + await handleStreamTerminalError({ error, flow: "new", threadId: currentThreadId, assistantMsgId, + accepted: newAccepted, + onAbort: async () => { + // Request was cancelled by user - persist partial response if any content was received + const hasContent = contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "tool-call" && toolsWithUI.has(part.toolName)) + ); + if (hasContent && currentThreadId) { + const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); + await persistAssistantTurn({ + threadId: currentThreadId, + assistantMsgId, + content: partialContent, + logContext: "partial new chat", + }); + } + }, }); } finally { setIsRunning(false); @@ -1183,7 +1290,7 @@ export default function NewChatPage() { pendingUserImageUrls, setPendingUserImageUrls, toolsWithUI, - handleChatFailure, + handleStreamTerminalError, persistAssistantTurn, ] ); @@ -1221,6 +1328,7 @@ export default function NewChatPage() { }; const { contentParts, toolCallIndices } = contentPartsState; let tokenUsageData: Record<string, unknown> | null = null; + let resumeAccepted = false; const existingMsg = messages.find((m) => m.id === assistantMsgId); if (existingMsg && Array.isArray(existingMsg.content)) { @@ -1302,8 +1410,9 @@ export default function NewChatPage() { }); if (!response.ok) { - throw new Error(`Backend error: ${response.status}`); + throw await toHttpResponseError(response); } + resumeAccepted = true; const flushMessages = () => { setMessages((prev) => @@ -1415,9 +1524,7 @@ export default function NewChatPage() { break; case "error": - throw Object.assign(new Error(parsed.errorText || "Server error"), { - errorCode: parsed.errorCode, - }); + throw toStreamTerminalError(parsed); } } @@ -1435,14 +1542,12 @@ export default function NewChatPage() { } } catch (error) { batcher.dispose(); - if (error instanceof Error && error.name === "AbortError") { - return; - } - await handleChatFailure({ + await handleStreamTerminalError({ error, flow: "resume", threadId: resumeThreadId, assistantMsgId, + accepted: resumeAccepted, }); } finally { setIsRunning(false); @@ -1455,7 +1560,7 @@ export default function NewChatPage() { searchSpaceId, tokenUsageStore, toolsWithUI, - handleChatFailure, + handleStreamTerminalError, persistAssistantTurn, ] ); @@ -1644,7 +1749,7 @@ export default function NewChatPage() { }); if (!response.ok) { - throw new Error(`Backend error: ${response.status}`); + throw await toHttpResponseError(response); } regenerateAccepted = true; @@ -1741,9 +1846,7 @@ export default function NewChatPage() { break; case "error": - throw Object.assign(new Error(parsed.errorText || "Server error"), { - errorCode: parsed.errorCode, - }); + throw toStreamTerminalError(parsed); } } @@ -1772,25 +1875,25 @@ export default function NewChatPage() { trackChatResponseReceived(searchSpaceId, threadId); } } catch (error) { - if (error instanceof Error && error.name === "AbortError") { - return; - } batcher.dispose(); - if (regenerateAccepted && !userPersisted) { - const persistedUserMsgId = await persistUserTurn({ - threadId, - userMsgId, - content: userContentToPersist, - mentionedDocs: sourceMentionedDocs, - logContext: "regenerated (stream error)", - }); - userPersisted = Boolean(persistedUserMsgId); - } - await handleChatFailure({ + await handleStreamTerminalError({ error, flow: "regenerate", threadId, - assistantMsgId: regenerateAccepted ? assistantMsgId : "no-persist-assistant", + assistantMsgId, + accepted: regenerateAccepted, + onAcceptedStreamError: async () => { + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId, + userMsgId, + content: userContentToPersist, + mentionedDocs: sourceMentionedDocs, + logContext: "regenerated (stream error)", + }); + userPersisted = Boolean(persistedUserMsgId); + } + }, }); } finally { setIsRunning(false); @@ -1806,7 +1909,7 @@ export default function NewChatPage() { setMessageDocumentsMap, tokenUsageStore, toolsWithUI, - handleChatFailure, + handleStreamTerminalError, persistAssistantTurn, persistUserTurn, ] diff --git a/surfsense_web/components/free-chat/anonymous-chat.tsx b/surfsense_web/components/free-chat/anonymous-chat.tsx index b286c5316..3de2ca434 100644 --- a/surfsense_web/components/free-chat/anonymous-chat.tsx +++ b/surfsense_web/components/free-chat/anonymous-chat.tsx @@ -104,7 +104,13 @@ export function AnonymousChat({ model }: AnonymousChatProps) { setMessages((prev) => prev.filter((m) => m.id !== assistantId)); return; } - throw new Error(`Stream error: ${response.status}`); + const body = await response.text().catch(() => ""); + const errorCode = response.status === 409 ? "THREAD_BUSY" : "SERVER_ERROR"; + const message = + errorCode === "THREAD_BUSY" + ? "A previous response is still stopping. Please try again in a moment." + : `Stream error: ${response.status}`; + throw Object.assign(new Error(body || message), { errorCode }); } for await (const event of readSSEStream(response)) { @@ -115,10 +121,12 @@ export function AnonymousChat({ model }: AnonymousChatProps) { prev.map((m) => (m.id === assistantId ? { ...m, content: m.content + event.delta } : m)) ); } else if (event.type === "error") { + const message = + event.errorCode === "THREAD_BUSY" + ? "A previous response is still stopping. Please try again in a moment." + : event.errorText; setMessages((prev) => - prev.map((m) => - m.id === assistantId ? { ...m, content: m.content || event.errorText } : m - ) + prev.map((m) => (m.id === assistantId ? { ...m, content: m.content || message } : m)) ); } else if ("type" in event && event.type === "data-token-usage") { // After streaming completes, refresh quota diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index deac1fd00..dd6693b35 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -48,6 +48,48 @@ function parseCaptchaError(status: number, body: string): string | null { return null; } +function normalizeFreeChatErrorMessage(error: unknown): string { + if (!(error instanceof Error)) return "An unexpected error occurred"; + const code = (error as Error & { errorCode?: string }).errorCode; + if (code === "THREAD_BUSY") { + return "A previous response is still stopping. Please try again in a moment."; + } + return error.message || "An unexpected error occurred"; +} + +function toFreeChatHttpError(status: number, body: string): Error & { errorCode?: string } { + let errorCode: string | undefined; + let message = body || `Server error: ${status}`; + try { + const parsed = JSON.parse(body) as Record<string, unknown>; + const detail = + typeof parsed.detail === "object" && parsed.detail !== null + ? (parsed.detail as Record<string, unknown>) + : null; + errorCode = + (typeof detail?.error_code === "string" ? detail.error_code : undefined) ?? + (typeof detail?.errorCode === "string" ? detail.errorCode : undefined) ?? + (typeof parsed.error_code === "string" ? parsed.error_code : undefined) ?? + (typeof parsed.errorCode === "string" ? parsed.errorCode : undefined); + message = + (typeof detail?.message === "string" ? detail.message : undefined) ?? + (typeof parsed.message === "string" ? parsed.message : undefined) ?? + (typeof parsed.detail === "string" ? parsed.detail : undefined) ?? + message; + } catch { + // non-json response + } + + if (!errorCode) { + if (status === 409) errorCode = "THREAD_BUSY"; + else if (status === 429) errorCode = "RATE_LIMITED"; + else if (status === 401 || status === 403) errorCode = "AUTH_EXPIRED"; + else errorCode = "SERVER_ERROR"; + } + + return Object.assign(new Error(message), { errorCode }); +} + export function FreeChatPage() { const anonMode = useAnonymousMode(); const modelSlug = anonMode.isAnonymous ? anonMode.modelSlug : ""; @@ -117,7 +159,7 @@ export function FreeChatPage() { const body = await response.text().catch(() => ""); const captchaCode = parseCaptchaError(response.status, body); if (captchaCode) return "captcha"; - throw new Error(body || `Server error: ${response.status}`); + throw toFreeChatHttpError(response.status, body); } const currentThinkingSteps = new Map<string, ThinkingStepData>(); @@ -187,7 +229,9 @@ export function FreeChatPage() { break; case "error": - throw new Error(parsed.errorText || "Server error"); + throw Object.assign(new Error(parsed.errorText || "Server error"), { + errorCode: parsed.errorCode, + }); } } batcher.flush(); @@ -277,7 +321,7 @@ export function FreeChatPage() { } catch (error) { if (error instanceof Error && error.name === "AbortError") return; console.error("[FreeChatPage] Chat error:", error); - const errorText = error instanceof Error ? error.message : "An unexpected error occurred"; + const errorText = normalizeFreeChatErrorMessage(error); setMessages((prev) => prev.map((m) => m.id === assistantMsgId @@ -336,7 +380,7 @@ export function FreeChatPage() { } catch (error) { if (error instanceof Error && error.name === "AbortError") return; console.error("[FreeChatPage] Retry error:", error); - const errorText = error instanceof Error ? error.message : "An unexpected error occurred"; + const errorText = normalizeFreeChatErrorMessage(error); setMessages((prev) => prev.map((m) => m.id === assistantMsgId diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index dc9bb09df..4341f7dc5 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -2,6 +2,7 @@ export type ChatFlow = "new" | "resume" | "regenerate"; export type ChatErrorKind = | "premium_quota_exhausted" + | "thread_busy" | "auth_expired" | "rate_limited" | "network_offline" @@ -144,6 +145,22 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } + if ( + errorCode === "THREAD_BUSY" + ) { + return { + kind: "thread_busy", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: "A previous response is still stopping. Please try again in a moment.", + rawMessage, + errorCode: errorCode ?? "THREAD_BUSY", + details: { flow: input.flow }, + }; + } + if ( errorCode === "AUTH_EXPIRED" || errorCode === "UNAUTHORIZED" From f60e742facdd2933d56958fe82de923c7aefab0f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:58:56 +0530 Subject: [PATCH 242/299] feat(chat): implement pre-accept failure handling and unified retry messaging for chat operations, enhancing user experience and error management --- .../unit/test_stream_new_chat_contract.py | 72 ++++++++ .../new-chat/[[...chat_id]]/page.tsx | 173 ++++++++++++++---- .../lib/chat/chat-error-classifier.ts | 24 ++- 3 files changed, 229 insertions(+), 40 deletions(-) diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 125177084..9f4280063 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -189,3 +189,75 @@ def test_premium_classification_is_error_code_driven(): assert "RATE_LIMIT_KEYWORDS" not in source assert "normalized.includes(" not in source assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source + + +def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook(): + page_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + ) + source = page_path.read_text(encoding="utf-8") + + assert "onPreAcceptFailure?: () => Promise<void>;" in source + assert "if (!accepted) {" in source + assert "await onPreAcceptFailure?.();" in source + assert "await onAcceptedStreamError?.();" in source + assert "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));" in source + assert "setMessageDocumentsMap((prev) => {" in source + + +def test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): + user_message_path = ( + Path(__file__).resolve().parents[3] / "surfsense_web/components/assistant-ui/user-message.tsx" + ) + source = user_message_path.read_text(encoding="utf-8") + + assert "Not sent. Edit and retry." not in source + assert "failed_pre_accept" not in source + + +def test_network_send_failures_use_unified_retry_toast_message(): + classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" + classifier_source = classifier_path.read_text(encoding="utf-8") + page_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + ) + page_source = page_path.read_text(encoding="utf-8") + + assert '"send_failed_pre_accept"' in classifier_source + assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source + assert "if (withCode.code) return withCode.code;" in classifier_source + assert 'userMessage: "Message not sent. Please retry."' in classifier_source + assert 'userMessage: "Connection issue. Please try again."' in classifier_source + assert "tagPreAcceptSendFailure(error)" in page_source + assert 'existingCode === "THREAD_BUSY"' in page_source + assert 'existingCode === "AUTH_EXPIRED"' in page_source + assert 'existingCode === "UNAUTHORIZED"' in page_source + assert 'existingCode === "RATE_LIMITED"' in page_source + assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source + assert 'errorCode: "NETWORK_ERROR"' not in page_source + assert "Failed to start chat. Please try again." not in page_source + + +def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows(): + page_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + ) + source = page_path.read_text(encoding="utf-8") + + # Each flow tracks accepted boundary and passes it into shared terminal handling. + assert "let newAccepted = false;" in source + assert "let resumeAccepted = false;" in source + assert "let regenerateAccepted = false;" in source + assert "accepted: newAccepted," in source + assert "accepted: resumeAccepted," in source + assert "accepted: regenerateAccepted," in source + + # Pre-accept abort in resume/regenerate exits without persistence. + assert "if (!resumeAccepted) return;" in source + assert "if (!regenerateAccepted) return;" in source + + # New flow persists only when accepted and not already persisted. + assert "if (newAccepted && !userPersisted) {" in source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 70e188612..80ee9e9cd 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -206,6 +206,26 @@ async function toHttpResponseError(response: Response): Promise<Error & { errorC return Object.assign(new Error(message), { errorCode }); } +function tagPreAcceptSendFailure(error: unknown): unknown { + if (error instanceof Error) { + const withCode = error as Error & { errorCode?: string; code?: string }; + const existingCode = withCode.errorCode ?? withCode.code; + if ( + existingCode === "THREAD_BUSY" || + existingCode === "AUTH_EXPIRED" || + existingCode === "UNAUTHORIZED" || + existingCode === "RATE_LIMITED" + ) { + return Object.assign(error, { errorCode: existingCode }); + } + return Object.assign(error, { errorCode: "SEND_FAILED_PRE_ACCEPT" }); + } + + return Object.assign(new Error("Failed to send message before stream acceptance"), { + errorCode: "SEND_FAILED_PRE_ACCEPT", + }); +} + /** * Zod schema for mentioned document info (for type-safe parsing) */ @@ -610,6 +630,7 @@ export default function NewChatPage() { assistantMsgId, accepted, onAbort, + onPreAcceptFailure, onAcceptedStreamError, }: { error: unknown; @@ -618,6 +639,7 @@ export default function NewChatPage() { assistantMsgId: string; accepted: boolean; onAbort?: () => Promise<void>; + onPreAcceptFailure?: () => Promise<void>; onAcceptedStreamError?: () => Promise<void>; }) => { if (error instanceof Error && error.name === "AbortError") { @@ -625,12 +647,14 @@ export default function NewChatPage() { return; } - if (accepted) { + if (!accepted) { + await onPreAcceptFailure?.(); + } else { await onAcceptedStreamError?.(); } await handleChatFailure({ - error, + error: !accepted ? tagPreAcceptSendFailure(error) : error, flow, threadId, assistantMsgId: accepted ? assistantMsgId : "no-persist-assistant", @@ -863,7 +887,12 @@ export default function NewChatPage() { ); } catch (error) { console.error("[NewChatPage] Failed to create thread:", error); - toast.error("Failed to start chat. Please try again."); + await handleChatFailure({ + error: tagPreAcceptSendFailure(error), + flow: "new", + threadId: currentThreadId, + assistantMsgId: "no-persist-assistant", + }); return; } } @@ -948,27 +977,6 @@ export default function NewChatPage() { }); } - appendMessage(currentThreadId, { - role: "user", - content: persistContent, - }) - .then((savedMessage) => { - const newUserMsgId = `msg-${savedMessage.id}`; - setMessages((prev) => - prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m)) - ); - setMessageDocumentsMap((prev) => { - const docs = prev[userMsgId]; - if (!docs) return prev; - const { [userMsgId]: _, ...rest } = prev; - return { ...rest, [newUserMsgId]: docs }; - }); - if (isNewThread) { - queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); - } - }) - .catch((err) => console.error("Failed to persist user message:", err)); - // Start streaming response setIsRunning(true); const controller = new AbortController(); @@ -988,17 +996,7 @@ export default function NewChatPage() { let wasInterrupted = false; let tokenUsageData: Record<string, unknown> | null = null; let newAccepted = false; - - // Add placeholder assistant message - setMessages((prev) => [ - ...prev, - { - id: assistantMsgId, - role: "assistant", - content: [{ type: "text", text: "" }], - createdAt: new Date(), - }, - ]); + let userPersisted = false; try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; @@ -1062,6 +1060,15 @@ export default function NewChatPage() { throw await toHttpResponseError(response); } newAccepted = true; + setMessages((prev) => [ + ...prev, + { + id: assistantMsgId, + role: "assistant", + content: [{ type: "text", text: "" }], + createdAt: new Date(), + }, + ]); const flushMessages = () => { setMessages((prev) => @@ -1224,6 +1231,20 @@ export default function NewChatPage() { // Skip persistence for interrupted messages -- handleResume will persist the final version const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); if (contentParts.length > 0 && !wasInterrupted) { + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId: currentThreadId, + userMsgId, + content: persistContent, + mentionedDocs: allMentionedDocs, + logContext: "new chat", + }); + userPersisted = Boolean(persistedUserMsgId); + if (userPersisted && isNewThread) { + queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); + } + } + await persistAssistantTurn({ threadId: currentThreadId, assistantMsgId, @@ -1251,6 +1272,20 @@ export default function NewChatPage() { assistantMsgId, accepted: newAccepted, onAbort: async () => { + if (newAccepted && !userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId: currentThreadId, + userMsgId, + content: persistContent, + mentionedDocs: allMentionedDocs, + logContext: "new chat (aborted)", + }); + userPersisted = Boolean(persistedUserMsgId); + if (userPersisted && isNewThread) { + queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); + } + } + // Request was cancelled by user - persist partial response if any content was received const hasContent = contentParts.some( (part) => @@ -1267,6 +1302,29 @@ export default function NewChatPage() { }); } }, + onAcceptedStreamError: async () => { + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId: currentThreadId, + userMsgId, + content: persistContent, + mentionedDocs: allMentionedDocs, + logContext: "new chat (stream error)", + }); + userPersisted = Boolean(persistedUserMsgId); + if (userPersisted && isNewThread) { + queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); + } + } + }, + onPreAcceptFailure: async () => { + setMessages((prev) => prev.filter((m) => m.id !== userMsgId)); + setMessageDocumentsMap((prev) => { + if (!(userMsgId in prev)) return prev; + const { [userMsgId]: _removed, ...rest } = prev; + return rest; + }); + }, }); } finally { setIsRunning(false); @@ -1291,7 +1349,9 @@ export default function NewChatPage() { setPendingUserImageUrls, toolsWithUI, handleStreamTerminalError, + handleChatFailure, persistAssistantTurn, + persistUserTurn, ] ); @@ -1548,6 +1608,22 @@ export default function NewChatPage() { threadId: resumeThreadId, assistantMsgId, accepted: resumeAccepted, + onAbort: async () => { + if (!resumeAccepted) return; + const hasContent = contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "tool-call" && toolsWithUI.has(part.toolName)) + ); + if (!hasContent) return; + const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); + await persistAssistantTurn({ + threadId: resumeThreadId, + assistantMsgId, + content: partialContent, + logContext: "partial resumed chat", + }); + }, }); } finally { setIsRunning(false); @@ -1882,6 +1958,33 @@ export default function NewChatPage() { threadId, assistantMsgId, accepted: regenerateAccepted, + onAbort: async () => { + if (!regenerateAccepted) return; + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId, + userMsgId, + content: userContentToPersist, + mentionedDocs: sourceMentionedDocs, + logContext: "regenerated (aborted)", + }); + userPersisted = Boolean(persistedUserMsgId); + } + const hasContent = contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "tool-call" && toolsWithUI.has(part.toolName)) + ); + if (!hasContent) return; + const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); + await persistAssistantTurn({ + threadId, + assistantMsgId, + content: partialContent, + tokenUsage: tokenUsageData ?? undefined, + logContext: "partial regenerated chat", + }); + }, onAcceptedStreamError: async () => { if (!userPersisted) { const persistedUserMsgId = await persistUserTurn({ diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index 4341f7dc5..57341a4c3 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -3,6 +3,7 @@ export type ChatFlow = "new" | "resume" | "regenerate"; export type ChatErrorKind = | "premium_quota_exhausted" | "thread_busy" + | "send_failed_pre_accept" | "auth_expired" | "rate_limited" | "network_offline" @@ -54,8 +55,9 @@ function getErrorMessage(error: unknown): string { function getErrorCode(error: unknown, parsedJson: Record<string, unknown> | null): string | undefined { if (error instanceof Error) { - const withCode = error as Error & { errorCode?: string }; + const withCode = error as Error & { errorCode?: string; code?: string }; if (withCode.errorCode) return withCode.errorCode; + if (withCode.code) return withCode.code; } if (typeof error === "object" && error !== null) { @@ -161,6 +163,20 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } + if (errorCode === "SEND_FAILED_PRE_ACCEPT") { + return { + kind: "send_failed_pre_accept", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: "Message not sent. Please retry.", + rawMessage, + errorCode: errorCode ?? "SEND_FAILED_PRE_ACCEPT", + details: { flow: input.flow }, + }; + } + if ( errorCode === "AUTH_EXPIRED" || errorCode === "UNAUTHORIZED" @@ -196,16 +212,14 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "NETWORK_ERROR" - ) { + if (errorCode === "NETWORK_ERROR") { return { kind: "network_offline", channel: "toast", severity: "warn", telemetryEvent: "chat_error", isExpected: true, - userMessage: "Connection issue detected. Check your internet and try again.", + userMessage: "Connection issue. Please try again.", rawMessage, errorCode: errorCode ?? "NETWORK_ERROR", details: { flow: input.flow }, From e651c41372b1d0946cd63ff96c224ce6beeb7acc Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Thu, 30 Apr 2026 03:13:58 -0700 Subject: [PATCH 243/299] feat: enhance tool input streaming and agent action handling for improved chat experience --- .../app/services/new_streaming_service.py | 13 +- .../app/tasks/chat/stream_new_chat.py | 256 +++++++-- .../tasks/chat/test_extract_chunk_parts.py | 43 ++ .../tasks/chat/test_tool_input_streaming.py | 527 ++++++++++++++++++ .../new-chat/[[...chat_id]]/page.tsx | 277 +++++---- .../atoms/chat/agent-actions.atom.ts | 194 ------- .../agent-action-log/action-log-sheet.tsx | 33 +- .../assistant-ui/revert-turn-button.tsx | 60 +- .../components/assistant-ui/tool-fallback.tsx | 475 ++++++++++++---- .../components/free-chat/free-chat-page.tsx | 24 +- .../contracts/types/chat-messages.types.ts | 9 +- .../hooks/use-agent-actions-query.ts | 416 ++++++++++++++ surfsense_web/hooks/use-messages-sync.ts | 8 + surfsense_web/lib/chat/streaming-state.ts | 60 +- surfsense_web/zero/schema/chat.ts | 7 + 15 files changed, 1857 insertions(+), 545 deletions(-) create mode 100644 surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py delete mode 100644 surfsense_web/atoms/chat/agent-actions.atom.ts create mode 100644 surfsense_web/hooks/use-agent-actions-query.ts diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 5dbae91c5..3531d37af 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -595,8 +595,17 @@ class VercelStreamingService: Format the start of tool input streaming. Args: - tool_call_id: The unique tool call identifier (synthetic, derived - from LangGraph ``run_id`` so the frontend has a stable card id). + tool_call_id: The unique tool call identifier. May be EITHER the + synthetic ``call_<run_id>`` id derived from LangGraph + ``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2`` + OFF, or the unmatched-fallback path under parity_v2) OR + the authoritative LangChain ``tool_call.id`` (parity_v2 + path: when the provider streams ``tool_call_chunks`` we + register the ``index`` and reuse the lc-id as the card + id so live ``tool-input-delta`` events can be routed + without a downstream join). Either way, the same id is + preserved across ``tool-input-start`` / ``-delta`` / + ``-available`` / ``tool-output-available`` for one call. tool_name: The name of the tool being called. langchain_tool_call_id: Optional authoritative LangChain ``tool_call.id``. When set, surfaces as diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 1493c4326..c94945bb1 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -338,6 +338,42 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None: ) +def _legacy_match_lc_id( + pending_tool_call_chunks: list[dict[str, Any]], + tool_name: str, + run_id: str, + lc_tool_call_id_by_run: dict[str, str], +) -> str | None: + """Best-effort match a buffered ``tool_call_chunk`` to a tool name. + + Pure extract of the legacy in-line match used at ``on_tool_start`` for + parity_v2-OFF and unmatched (chunk path didn't register an index for + this call) tools. Pops the next id-bearing chunk whose ``name`` + matches ``tool_name`` (or any id-bearing chunk as a fallback) and + returns its id. Mutates ``pending_tool_call_chunks`` and + ``lc_tool_call_id_by_run`` in place. + """ + matched_idx: int | None = None + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("name") == tool_name and tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + return None + matched = pending_tool_call_chunks.pop(matched_idx) + candidate = matched.get("id") + if isinstance(candidate, str) and candidate: + if run_id: + lc_tool_call_id_by_run[run_id] = candidate + return candidate + return None + + async def _stream_agent_events( agent: Any, config: dict[str, Any], @@ -403,10 +439,28 @@ async def _stream_agent_events( # ``tool_call_chunks`` from ``on_chat_model_stream``, key them by # name, and pop the next unconsumed entry at ``on_tool_start``. The # authoritative id is later filled in at ``on_tool_end`` from - # ``ToolMessage.tool_call_id``. + # ``ToolMessage.tool_call_id``. Under parity_v2 we ALSO short-circuit + # this list for chunks that already registered into ``index_to_meta`` + # below — so this list is reserved for the parity_v2-OFF / unmatched + # fallback path only and never re-pops a chunk we already streamed. pending_tool_call_chunks: list[dict[str, Any]] = [] lc_tool_call_id_by_run: dict[str, str] = {} + # parity_v2 only: live tool-call argument streaming. ``index_to_meta`` + # is keyed by the chunk's ``index`` field — LangChain + # ``ToolCallChunk``s for the same call share an index but only the + # first chunk carries id+name (subsequent ones are id=None, + # name=None, args="<delta>"). We register an index when both id and + # name are observed on a chunk (per ToolCallChunk semantics they + # arrive together on the first chunk), then route every later chunk + # at that index to the same ``ui_id`` as a ``tool-input-delta``. + # ``ui_tool_call_id_by_run`` maps LangGraph ``run_id`` to the + # ``ui_id`` used for that call's ``tool-input-start`` so the matching + # ``tool-output-available`` (emitted from ``on_tool_end``) lands on + # the same card. + index_to_meta: dict[int, dict[str, str]] = {} + ui_tool_call_id_by_run: dict[str, str] = {} + # Per-tool-end mutable cache for the LangChain tool_call_id resolved # at ``on_tool_end``. ``_emit_tool_output`` reads this so every # ``format_tool_output_available`` call automatically carries the @@ -452,13 +506,6 @@ async def _stream_agent_events( continue parts = _extract_chunk_parts(chunk) - # Accumulate any tool_call_chunks for best-effort - # correlation with ``on_tool_start`` below. We don't emit - # anything here; the matching is done at tool-start time. - if parity_v2 and parts["tool_call_chunks"]: - for tcc in parts["tool_call_chunks"]: - pending_tool_call_chunks.append(tcc) - reasoning_delta = parts["reasoning"] text_delta = parts["text"] @@ -504,6 +551,71 @@ async def _stream_agent_events( yield streaming_service.format_text_delta(current_text_id, text_delta) accumulated_text += text_delta + # Live tool-call argument streaming. Runs AFTER text/reasoning + # processing so chunks containing both stay in their natural + # wire order (text → text-end → tool-input-start). Active + # text/reasoning are closed inside the registration branch + # before ``tool-input-start`` so the frontend sees a clean + # part boundary even when providers interleave. + if parity_v2 and parts["tool_call_chunks"]: + for tcc in parts["tool_call_chunks"]: + idx = tcc.get("index") + + # Register this index when we first see id+name + # TOGETHER. Per LangChain ToolCallChunk semantics the + # first chunk for a tool call carries both fields + # together; later chunks have id=None, name=None and + # only ``args``. Requiring BOTH keeps wire + # ``tool-input-start`` always carrying a real + # toolName (assistant-ui's typed tool-part dispatch + # keys off it). + if idx is not None and idx not in index_to_meta: + lc_id = tcc.get("id") + name = tcc.get("name") + if lc_id and name: + ui_id = lc_id + + # Close active text/reasoning so wire + # ordering stays clean even on providers + # that interleave text and tool-call chunks + # within the same stream window. + if current_text_id is not None: + yield streaming_service.format_text_end(current_text_id) + current_text_id = None + if current_reasoning_id is not None: + yield streaming_service.format_reasoning_end( + current_reasoning_id + ) + current_reasoning_id = None + + index_to_meta[idx] = { + "ui_id": ui_id, + "lc_id": lc_id, + "name": name, + } + yield streaming_service.format_tool_input_start( + ui_id, + name, + langchain_tool_call_id=lc_id, + ) + + # Emit args delta for any chunk at a registered + # index (including idless continuations). Once an + # index is owned by ``index_to_meta`` we DO NOT + # append to ``pending_tool_call_chunks`` — that list + # is reserved for the parity_v2-OFF / unmatched + # fallback path so it never re-pops chunks already + # consumed here (skip-append). + meta = index_to_meta.get(idx) if idx is not None else None + if meta: + args_chunk = tcc.get("args") or "" + if args_chunk: + yield streaming_service.format_tool_input_delta( + meta["ui_id"], args_chunk + ) + else: + pending_tool_call_chunks.append(tcc) + elif event_type == "on_tool_start": active_tool_depth += 1 tool_name = event.get("name", "unknown_tool") @@ -834,44 +946,65 @@ async def _stream_agent_events( status="in_progress", ) - tool_call_id = ( - f"call_{run_id[:32]}" - if run_id - else streaming_service.generate_tool_call_id() - ) - - # Best-effort attach the LangChain ``tool_call_id``. We - # pop the first chunk in ``pending_tool_call_chunks`` whose - # name matches; if none match (the chunked args may not yet - # carry a ``name`` field, or the model skipped the chunked - # form) we leave ``langchainToolCallId`` unset for now and - # fill it in authoritatively at ``on_tool_end`` from - # ``ToolMessage.tool_call_id``. - langchain_tool_call_id: str | None = None - if parity_v2 and pending_tool_call_chunks: - matched_idx: int | None = None - for idx, tcc in enumerate(pending_tool_call_chunks): - if tcc.get("name") == tool_name and tcc.get("id"): - matched_idx = idx + # Resolve the card identity. If the chunk-emission loop + # already registered an ``index`` for this tool call (parity_v2 + # path), reuse the same ui_id so the card sees: + # tool-input-start → deltas… → tool-input-available → + # tool-output-available all keyed by lc_id. Otherwise fall + # back to the synthetic ``call_<run_id>`` id and the legacy + # best-effort match against ``pending_tool_call_chunks``. + matched_meta: dict[str, str] | None = None + if parity_v2: + # FIFO over indices 0,1,2…; first unassigned same-name + # match wins. Handles parallel same-name calls (e.g. two + # write_file calls) deterministically as long as the + # model interleaves on_tool_start in the same order it + # streamed the args. + taken_ui_ids = set(ui_tool_call_id_by_run.values()) + for meta in index_to_meta.values(): + if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids: + matched_meta = meta break - if matched_idx is None: - for idx, tcc in enumerate(pending_tool_call_chunks): - if tcc.get("id"): - matched_idx = idx - break - if matched_idx is not None: - matched = pending_tool_call_chunks.pop(matched_idx) - candidate = matched.get("id") - if isinstance(candidate, str) and candidate: - langchain_tool_call_id = candidate - if run_id: - lc_tool_call_id_by_run[run_id] = candidate - yield streaming_service.format_tool_input_start( - tool_call_id, - tool_name, - langchain_tool_call_id=langchain_tool_call_id, - ) + tool_call_id: str + langchain_tool_call_id: str | None = None + if matched_meta is not None: + tool_call_id = matched_meta["ui_id"] + langchain_tool_call_id = matched_meta["lc_id"] + # ``tool-input-start`` already fired during chunk + # emission — skip the duplicate. No pruning is needed + # because the chunk-emission loop intentionally never + # appends registered-index chunks to + # ``pending_tool_call_chunks`` (skip-append). + if run_id: + lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"] + else: + tool_call_id = ( + f"call_{run_id[:32]}" + if run_id + else streaming_service.generate_tool_call_id() + ) + # Legacy fallback: parity_v2 OFF, or parity_v2 ON but the + # provider didn't stream tool_call_chunks for this call + # (no index registered). Run the existing best-effort + # match BEFORE emitting start so we still attach an + # authoritative ``langchainToolCallId`` when possible. + if parity_v2: + langchain_tool_call_id = _legacy_match_lc_id( + pending_tool_call_chunks, + tool_name, + run_id, + lc_tool_call_id_by_run, + ) + yield streaming_service.format_tool_input_start( + tool_call_id, + tool_name, + langchain_tool_call_id=langchain_tool_call_id, + ) + + if run_id: + ui_tool_call_id_by_run[run_id] = tool_call_id + # Sanitize tool_input: strip runtime-injected non-serializable # values (e.g. LangChain ToolRuntime) before sending over SSE. if isinstance(tool_input, dict): @@ -924,7 +1057,15 @@ async def _stream_agent_events( result.write_succeeded = True result.verification_succeeded = True - tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown" + # Look up the SAME card id used at on_tool_start (either the + # parity_v2 lc-id-derived ui_id or the legacy synthetic + # ``call_<run_id>``) so the output event always lands on the + # same card as start/delta/available. Fallback preserves the + # legacy synthetic shape for parity_v2-OFF / unknown-run paths. + tool_call_id = ui_tool_call_id_by_run.get( + run_id, + f"call_{run_id[:32]}" if run_id else "call_unknown", + ) original_step_id = tool_step_ids.get( run_id, f"{step_prefix}-unknown-{run_id[:8]}" ) @@ -935,17 +1076,22 @@ async def _stream_agent_events( # at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``) # if the output isn't a ToolMessage. The value is stored in # ``current_lc_tool_call_id`` so ``_emit_tool_output`` - # picks it up for every output emit below. Stays None when - # parity_v2 is off so legacy emit paths are untouched. + # picks it up for every output emit below. + # + # Emitted in BOTH parity_v2 and legacy modes: the chat tool + # card needs the LangChain id to match against the + # ``data-action-log`` SSE event (keyed by ``lc_tool_call_id``) + # so the inline Revert button can light up. Reading + # ``raw_output.tool_call_id`` is a cheap, non-mutating attribute + # access that is safe regardless of feature-flag state. current_lc_tool_call_id["value"] = None - if parity_v2: - authoritative = getattr(raw_output, "tool_call_id", None) - if isinstance(authoritative, str) and authoritative: - current_lc_tool_call_id["value"] = authoritative - if run_id: - lc_tool_call_id_by_run[run_id] = authoritative - elif run_id and run_id in lc_tool_call_id_by_run: - current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id] + authoritative = getattr(raw_output, "tool_call_id", None) + if isinstance(authoritative, str) and authoritative: + current_lc_tool_call_id["value"] = authoritative + if run_id: + lc_tool_call_id_by_run[run_id] = authoritative + elif run_id and run_id in lc_tool_call_id_by_run: + current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id] if tool_name == "read_file": yield streaming_service.format_thinking_step( diff --git a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py index 7f32bf456..1263a5fe1 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py @@ -183,3 +183,46 @@ class TestDefensive: assert out["text"] == "" assert out["reasoning"] == "" assert out["tool_call_chunks"] == [] + + +class TestIdlessContinuationChunks: + """Per LangChain ``ToolCallChunk`` semantics, the FIRST chunk for a + tool call carries id+name; later chunks for the same call have + ``id=None, name=None`` and only ``args`` + ``index``. Live tool-call + argument streaming relies on those idless continuation chunks + flowing through ``_extract_chunk_parts`` UNTOUCHED so the upstream + chunk-emission loop can still route them by ``index``. + """ + + def test_idless_continuation_chunk_preserved_verbatim(self) -> None: + chunk = _FakeChunk( + tool_call_chunks=[ + {"id": None, "name": None, "args": '_path":"/x"}', "index": 0} + ] + ) + out = _extract_chunk_parts(chunk) + assert len(out["tool_call_chunks"]) == 1 + tcc = out["tool_call_chunks"][0] + assert tcc.get("id") is None + assert tcc.get("name") is None + assert tcc.get("args") == '_path":"/x"}' + assert tcc.get("index") == 0 + + def test_first_then_idless_sequence_preserves_index(self) -> None: + """Both chunks for the same call share an ``index`` key — the + index-routing loop in ``stream_new_chat`` depends on it.""" + first = _FakeChunk( + tool_call_chunks=[ + {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0} + ] + ) + cont = _FakeChunk( + tool_call_chunks=[ + {"id": None, "name": None, "args": '_path":"/x"}', "index": 0} + ] + ) + out_first = _extract_chunk_parts(first) + out_cont = _extract_chunk_parts(cont) + assert out_first["tool_call_chunks"][0]["index"] == 0 + assert out_cont["tool_call_chunks"][0]["index"] == 0 + assert out_cont["tool_call_chunks"][0].get("id") is None diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py new file mode 100644 index 000000000..9258d5cfe --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py @@ -0,0 +1,527 @@ +"""Unit tests for live tool-call argument streaming. + +Pins the wire format that ``_stream_agent_events`` emits when +``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start`` → +``tool-input-delta``... → ``tool-input-available`` → ``tool-output-available`` +all keyed by the same LangChain ``tool_call.id``. + +Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and +``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to +``_stream_agent_events`` so we exercise them via the public wire output. + +These tests also lock in the legacy / parity_v2-OFF behaviour so the +synthetic ``call_<run_id>`` shape stays stable for older clients. +""" + +from __future__ import annotations + +import json +from collections.abc import AsyncGenerator +from dataclasses import dataclass, field +from typing import Any + +import pytest + +import app.tasks.chat.stream_new_chat as stream_module +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.services.new_streaming_service import VercelStreamingService +from app.tasks.chat.stream_new_chat import ( + StreamResult, + _legacy_match_lc_id, + _stream_agent_events, +) + +pytestmark = pytest.mark.unit + + +@dataclass +class _FakeChunk: + """Minimal stand-in for ``AIMessageChunk``.""" + + content: Any = "" + additional_kwargs: dict[str, Any] = field(default_factory=dict) + tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) + + +@dataclass +class _FakeToolMessage: + """Stand-in for ``ToolMessage`` returned by ``on_tool_end``.""" + + content: Any + tool_call_id: str | None = None + + +class _FakeAgentState: + """Stand-in for ``StateSnapshot`` returned by ``aget_state``.""" + + def __init__(self) -> None: + # Empty values keeps the cloud-fallback safety-net branch a no-op, + # and an empty ``tasks`` list keeps the post-stream interrupt + # check a no-op too. + self.values: dict[str, Any] = {} + self.tasks: list[Any] = [] + + +class _FakeAgent: + """Replays a list of ``astream_events`` events.""" + + def __init__(self, events: list[dict[str, Any]]) -> None: + self._events = events + + async def astream_events( # type: ignore[no-untyped-def] + self, _input_data: Any, *, config: dict[str, Any], version: str + ) -> AsyncGenerator[dict[str, Any], None]: + del config, version # unused, contract-compatible + for ev in self._events: + yield ev + + async def aget_state(self, _config: dict[str, Any]) -> _FakeAgentState: + # Called once after astream_events drains so the cloud-fallback + # safety net can inspect staged filesystem work. The fake stays + # empty so the safety net is a no-op. + return _FakeAgentState() + + +def _model_stream( + *, + text: str = "", + reasoning: str = "", + tool_call_chunks: list[dict[str, Any]] | None = None, + tags: list[str] | None = None, +) -> dict[str, Any]: + return ( + { + "event": "on_chat_model_stream", + "tags": tags or [], + "data": { + "chunk": _FakeChunk( + content=text, + tool_call_chunks=list(tool_call_chunks or []), + ) + }, + # reasoning piggybacks via additional_kwargs path; if needed, + # override content to a typed-block list. Most tests just check + # tool_call_chunks routing so this is fine. + } + if not reasoning + else { + "event": "on_chat_model_stream", + "tags": tags or [], + "data": { + "chunk": _FakeChunk( + content=text, + additional_kwargs={"reasoning_content": reasoning}, + tool_call_chunks=list(tool_call_chunks or []), + ) + }, + } + ) + + +def _tool_start( + *, + name: str, + run_id: str, + input_payload: dict[str, Any] | None = None, +) -> dict[str, Any]: + return { + "event": "on_tool_start", + "name": name, + "run_id": run_id, + "data": {"input": input_payload or {}}, + } + + +def _tool_end( + *, + name: str, + run_id: str, + tool_call_id: str | None = None, + output: Any = "ok", +) -> dict[str, Any]: + return { + "event": "on_tool_end", + "name": name, + "run_id": run_id, + "data": { + "output": _FakeToolMessage( + content=json.dumps(output) if not isinstance(output, str) else output, + tool_call_id=tool_call_id, + ) + }, + } + + +@pytest.fixture +def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + stream_module, + "get_flags", + lambda: AgentFeatureFlags(enable_stream_parity_v2=True), + ) + + +@pytest.fixture +def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + stream_module, + "get_flags", + lambda: AgentFeatureFlags(enable_stream_parity_v2=False), + ) + + +async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Run ``_stream_agent_events`` against a fake agent and return the + SSE payloads (parsed JSON) it yielded. + """ + agent = _FakeAgent(events) + service = VercelStreamingService() + result = StreamResult() + config = {"configurable": {"thread_id": "test-thread"}} + sse_lines: list[str] = [] + async for sse in _stream_agent_events( + agent, config, {}, service, result, step_prefix="thinking" + ): + sse_lines.append(sse) + + parsed: list[dict[str, Any]] = [] + for line in sse_lines: + if not line.startswith("data: "): + continue + body = line[len("data: ") :].rstrip("\n") + if not body or body == "[DONE]": + continue + try: + parsed.append(json.loads(body)) + except json.JSONDecodeError: + continue + return parsed + + +def _types(payloads: list[dict[str, Any]]) -> list[str]: + return [p.get("type", "?") for p in payloads] + + +def _of_type(payloads: list[dict[str, Any]], type_name: str) -> list[dict[str, Any]]: + return [p for p in payloads if p.get("type") == type_name] + + +# --------------------------------------------------------------------------- +# Helper: ``_legacy_match_lc_id`` is a pure refactor; assert behaviour. +# --------------------------------------------------------------------------- + + +class TestLegacyMatch: + def test_pops_first_id_bearing_chunk_with_matching_name(self) -> None: + chunks: list[dict[str, Any]] = [ + {"id": "x1", "name": "ls"}, + {"id": "y1", "name": "write_file"}, + ] + runs: dict[str, str] = {} + result = _legacy_match_lc_id(chunks, "write_file", "run-1", runs) + assert result == "y1" + assert chunks == [{"id": "x1", "name": "ls"}] + assert runs == {"run-1": "y1"} + + def test_falls_back_to_any_id_bearing_when_name_mismatches(self) -> None: + chunks: list[dict[str, Any]] = [{"id": "anon", "name": None}] + runs: dict[str, str] = {} + out = _legacy_match_lc_id(chunks, "ls", "run-2", runs) + assert out == "anon" + assert chunks == [] + + def test_returns_none_when_no_id_bearing_chunk(self) -> None: + chunks: list[dict[str, Any]] = [{"id": None, "name": None}] + runs: dict[str, str] = {} + assert _legacy_match_lc_id(chunks, "ls", "run-3", runs) is None + assert chunks == [{"id": None, "name": None}] + assert runs == {} + + +# --------------------------------------------------------------------------- +# parity_v2 wire format tests. +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None: + """First chunk carries id+name; later idless chunks at the same + ``index`` merge into the SAME ``tool-input-start`` ui id and emit + one ``tool-input-delta`` per chunk.""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0} + ], + ), + _model_stream( + tool_call_chunks=[ + {"id": None, "name": None, "args": '_path":"/x"}', "index": 0} + ], + ), + _tool_start( + name="write_file", run_id="run-A", input_payload={"file_path": "/x"} + ), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"), + ] + + payloads = await _drain(events) + + starts = _of_type(payloads, "tool-input-start") + deltas = _of_type(payloads, "tool-input-delta") + available = _of_type(payloads, "tool-input-available") + output = _of_type(payloads, "tool-output-available") + + assert len(starts) == 1 + assert starts[0]["toolCallId"] == "lc-1" + assert starts[0]["toolName"] == "write_file" + assert starts[0]["langchainToolCallId"] == "lc-1" + + assert [d["inputTextDelta"] for d in deltas] == ['{"file', '_path":"/x"}'] + assert all(d["toolCallId"] == "lc-1" for d in deltas) + + assert len(available) == 1 + assert available[0]["toolCallId"] == "lc-1" + + assert len(output) == 1 + assert output[0]["toolCallId"] == "lc-1" + + +@pytest.mark.asyncio +async def test_two_interleaved_tool_calls_route_by_index( + parity_v2_on: None, +) -> None: + """Two same-name calls with distinct indices keep their deltas + routed to the right card.""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-A", "name": "write_file", "args": '{"a":1', "index": 0}, + {"id": "lc-B", "name": "write_file", "args": '{"b":2', "index": 1}, + ] + ), + _model_stream( + tool_call_chunks=[ + {"id": None, "name": None, "args": "}", "index": 0}, + {"id": None, "name": None, "args": "}", "index": 1}, + ] + ), + _tool_start(name="write_file", run_id="run-A", input_payload={"a": 1}), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-A"), + _tool_start(name="write_file", run_id="run-B", input_payload={"b": 2}), + _tool_end(name="write_file", run_id="run-B", tool_call_id="lc-B"), + ] + + payloads = await _drain(events) + + starts = _of_type(payloads, "tool-input-start") + deltas = _of_type(payloads, "tool-input-delta") + output = _of_type(payloads, "tool-output-available") + + assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"} + + by_id: dict[str, list[str]] = {"lc-A": [], "lc-B": []} + for d in deltas: + by_id[d["toolCallId"]].append(d["inputTextDelta"]) + assert by_id["lc-A"] == ['{"a":1', "}"] + assert by_id["lc-B"] == ['{"b":2', "}"] + + assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"} + + +@pytest.mark.asyncio +async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None: + """Whatever id ``tool-input-start`` chose must be the SAME id used + on ``tool-input-available`` AND ``tool-output-available``.""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-9", "name": "ls", "args": '{"path":"/"}', "index": 0} + ] + ), + _tool_start(name="ls", run_id="run-X", input_payload={"path": "/"}), + _tool_end(name="ls", run_id="run-X", tool_call_id="lc-9"), + ] + payloads = await _drain(events) + relevant = [ + p + for p in payloads + if p.get("type") + in {"tool-input-start", "tool-input-available", "tool-output-available"} + ] + assert {p["toolCallId"] for p in relevant} == {"lc-9"} + + +@pytest.mark.asyncio +async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None: + """When the chunk-emission loop already fired ``tool-input-start`` + for this run, ``on_tool_start`` MUST NOT emit a second one.""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0} + ] + ), + _tool_start(name="write_file", run_id="run-A", input_payload={}), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"), + ] + payloads = await _drain(events) + starts = _of_type(payloads, "tool-input-start") + assert len(starts) == 1 + assert starts[0]["toolCallId"] == "lc-1" + + +@pytest.mark.asyncio +async def test_active_text_closes_before_early_tool_input_start( + parity_v2_on: None, +) -> None: + """Streaming a text-delta then a tool-call chunk in subsequent + chunks: the wire MUST contain ``text-end`` before the FIRST + ``tool-input-start`` (clean part boundary on the frontend).""" + events = [ + _model_stream(text="Working on it"), + _model_stream( + tool_call_chunks=[ + {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0} + ] + ), + _tool_start(name="write_file", run_id="run-A", input_payload={}), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"), + ] + types = _types(await _drain(events)) + text_end_idx = types.index("text-end") + start_idx = types.index("tool-input-start") + assert text_end_idx < start_idx + + +@pytest.mark.asyncio +async def test_mixed_text_and_tool_chunk_preserve_order( + parity_v2_on: None, +) -> None: + """One AIMessageChunk that carries BOTH ``text`` content AND + ``tool_call_chunks`` should emit the text delta FIRST, then close + text, then ``tool-input-start``+``tool-input-delta``.""" + events = [ + _model_stream( + text="I'll update it", + tool_call_chunks=[ + { + "id": "lc-1", + "name": "write_file", + "args": '{"file_path":"/x"}', + "index": 0, + } + ], + ), + _tool_start( + name="write_file", run_id="run-A", input_payload={"file_path": "/x"} + ), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"), + ] + types = _types(await _drain(events)) + # text-start … text-delta … text-end … tool-input-start … tool-input-delta + assert types.index("text-start") < types.index("text-delta") + assert types.index("text-delta") < types.index("text-end") + assert types.index("text-end") < types.index("tool-input-start") + assert types.index("tool-input-start") < types.index("tool-input-delta") + + +@pytest.mark.asyncio +async def test_parity_v2_off_preserves_legacy_shape( + parity_v2_off: None, +) -> None: + """When the flag is OFF, no deltas are emitted and the ``toolCallId`` + is ``call_<run_id>`` (NOT the lc id).""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0} + ] + ), + _tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}), + _tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"), + ] + payloads = await _drain(events) + + assert _of_type(payloads, "tool-input-delta") == [] + starts = _of_type(payloads, "tool-input-start") + assert len(starts) == 1 + assert starts[0]["toolCallId"].startswith("call_run-A") + # No ``langchainToolCallId`` propagation on ``tool-input-start`` in + # legacy mode (the start event fires before the ToolMessage is + # available, so we can't extract the authoritative LangChain id yet). + assert "langchainToolCallId" not in starts[0] + output = _of_type(payloads, "tool-output-available") + assert output[0]["toolCallId"].startswith("call_run-A") + # ``tool-output-available`` MUST carry ``langchainToolCallId`` even + # in legacy mode: the chat tool card uses it to backfill the + # LangChain id and join against the ``data-action-log`` SSE event + # (keyed by ``lc_tool_call_id``) so the inline Revert button can + # light up. Sourced from the returned ``ToolMessage.tool_call_id``, + # which is populated regardless of feature-flag state. + assert output[0]["langchainToolCallId"] == "lc-1" + + +@pytest.mark.asyncio +async def test_skip_append_prevents_stale_id_reuse( + parity_v2_on: None, +) -> None: + """Two same-name tools: the SECOND tool's ``langchainToolCallId`` + must NOT come from the first tool's chunk (``pending_tool_call_chunks`` + must stay empty for indexed-registered chunks).""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-A", "name": "write_file", "args": "{}", "index": 0}, + {"id": "lc-B", "name": "write_file", "args": "{}", "index": 1}, + ] + ), + _tool_start(name="write_file", run_id="run-1", input_payload={}), + _tool_end(name="write_file", run_id="run-1", tool_call_id="lc-A"), + _tool_start(name="write_file", run_id="run-2", input_payload={}), + _tool_end(name="write_file", run_id="run-2", tool_call_id="lc-B"), + ] + payloads = await _drain(events) + + starts = _of_type(payloads, "tool-input-start") + # Two distinct lc ids, each its own card. + assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"} + # Each tool-output-available landed on its respective card. + output = _of_type(payloads, "tool-output-available") + assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"} + + +@pytest.mark.asyncio +async def test_registration_waits_for_both_id_and_name( + parity_v2_on: None, +) -> None: + """An id-only chunk (no name yet) must NOT emit ``tool-input-start``.""" + events = [ + _model_stream( + tool_call_chunks=[{"id": "lc-1", "name": None, "args": "", "index": 0}] + ), + ] + payloads = await _drain(events) + assert _of_type(payloads, "tool-input-start") == [] + + +@pytest.mark.asyncio +async def test_unmatched_fallback_still_attaches_lc_id( + parity_v2_on: None, +) -> None: + """parity_v2 ON, but the provider didn't include an ``index``: the + legacy fallback path must still emit ``tool-input-start`` with the + matching ``langchainToolCallId``.""" + events = [ + # No index on the chunk → not registered into index_to_meta; + # falls through to ``pending_tool_call_chunks`` so the legacy + # match path can pop it at on_tool_start. + _model_stream(tool_call_chunks=[{"id": "lc-orphan", "name": "ls", "args": ""}]), + _tool_start(name="ls", run_id="run-1", input_payload={"path": "/"}), + _tool_end(name="ls", run_id="run-1", tool_call_id="lc-orphan"), + ] + payloads = await _drain(events) + starts = _of_type(payloads, "tool-input-start") + assert len(starts) == 1 + assert starts[0]["toolCallId"].startswith("call_run-1") + assert starts[0]["langchainToolCallId"] == "lc-orphan" diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index c2086e80a..e5ac61cd9 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -14,13 +14,6 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { z } from "zod"; import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms"; -import { - agentActionsByChatTurnIdAtom, - markAgentActionRevertedAtom, - resetAgentActionMapAtom, - updateAgentActionReversibleAtom, - upsertAgentActionAtom, -} from "@/atoms/chat/agent-actions.atom"; import { clearTargetCommentIdAtom, currentThreadAtom, @@ -55,6 +48,12 @@ import { type TokenUsageData, TokenUsageProvider, } from "@/components/assistant-ui/token-usage-context"; +import { + applyActionLogSse, + applyActionLogUpdatedSse, + markActionRevertedInCache, + useAgentActionsQuery, +} from "@/hooks/use-agent-actions-query"; import { useChatSessionStateSync } from "@/hooks/use-chat-session-state"; import { useMessagesSync } from "@/hooks/use-messages-sync"; import { getAgentFilesystemSelection } from "@/lib/agent-filesystem"; @@ -71,12 +70,12 @@ import { addToolCall, appendReasoning, appendText, + appendToolInputDelta, buildContentForPersistence, buildContentForUI, type ContentPartsState, endReasoning, FrameBatchedUpdater, - findToolCallIdByLcId, readSSEStream, type ThinkingStepData, type ToolUIGate, @@ -246,14 +245,6 @@ export default function NewChatPage() { const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom); const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom); const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom); - // Agent action log SSE side-channel. - const upsertAgentAction = useSetAtom(upsertAgentActionAtom); - const updateAgentActionReversible = useSetAtom(updateAgentActionReversibleAtom); - const markAgentActionReverted = useSetAtom(markAgentActionRevertedAtom); - const resetAgentActionMap = useSetAtom(resetAgentActionMapAtom); - // Chat-turn-keyed action map for the edit-from-position pre-flight - // that decides whether to show the confirmation dialog. - const agentActionsByChatTurnId = useAtomValue(agentActionsByChatTurnIdAtom); // Edit dialog state. Holds the message id being edited and // the (already extracted) regenerate args so we can resume the edit // after the user picks "revert all" / "continue" / "cancel". @@ -282,6 +273,11 @@ export default function NewChatPage() { content: unknown; author_id: string | null; created_at: string; + // Forwarded so ``convertToThreadMessage`` can rebuild the + // ``metadata.custom.chatTurnId`` on the + // ``ThreadMessageLike``. Required by the inline Revert + // button's per-turn fallback. + turn_id?: string | null; }[] ) => { if (isRunning) { @@ -314,6 +310,11 @@ export default function NewChatPage() { created_at: msg.created_at, author_display_name: member?.user_display_name ?? existingAuthor?.displayName ?? null, author_avatar_url: member?.user_avatar_url ?? existingAuthor?.avatarUrl ?? null, + // Forward the per-turn correlation id so the + // inline Revert button's ``(chat_turn_id, + // tool_name, position)`` fallback survives the + // post-stream Zero re-sync. + turn_id: msg.turn_id ?? null, }); }); }); @@ -330,6 +331,13 @@ export default function NewChatPage() { return Number.isNaN(parsed) ? 0 : parsed; }, [params.search_space_id]); + // Unified store for agent-action rows (the same react-query cache + // the agent-actions sheet, the inline Revert button, and the + // per-turn Revert button all read). Hydrates from + // ``GET /threads/{id}/actions`` and is updated incrementally by the + // SSE handlers + revert-batch results below — no atom side-channel. + const { items: agentActionItems } = useAgentActionsQuery(threadId); + // Extract chat_id from URL params const urlChatId = useMemo(() => { const id = params.chat_id; @@ -357,7 +365,8 @@ export default function NewChatPage() { clearPlanOwnerRegistry(); closeReportPanel(); closeEditorPanel(); - resetAgentActionMap(); + // Note: agent-action data is keyed by threadId in react-query so + // switching threads naturally swaps caches; no explicit reset. try { if (urlChatId > 0) { @@ -426,7 +435,6 @@ export default function NewChatPage() { removeChatTab, searchSpaceId, tokenUsageStore, - resetAgentActionMap, ]); // Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same) @@ -779,6 +787,15 @@ export default function NewChatPage() { ); }; const scheduleFlush = () => batcher.schedule(flushMessages); + // Force-flush helper: ``batcher.flush()`` is a no-op when + // ``dirty=false`` (e.g. a tool starts before any text + // streamed). ``scheduleFlush(); batcher.flush()`` sets + // the dirty bit FIRST so terminal events render + // promptly without the 50ms throttle delay. + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; for await (const parsed of readSSEStream(response)) { switch (parsed.type) { @@ -815,13 +832,23 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); - batcher.flush(); + forceFlush(); + break; + + case "tool-input-delta": + // High-frequency event: deltas can fire dozens + // of times per call, so use throttled + // scheduleFlush (NOT forceFlush) to coalesce. + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); break; case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + argsText: finalArgsText, langchainToolCallId: parsed.langchainToolCallId, }); } else { @@ -834,8 +861,14 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); + // addToolCall doesn't accept argsText today; + // backfill via updateToolCall so the new card + // renders pretty-printed JSON. + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); } - batcher.flush(); + forceFlush(); break; } @@ -854,7 +887,7 @@ export default function NewChatPage() { } } } - batcher.flush(); + forceFlush(); break; } @@ -950,34 +983,17 @@ export default function NewChatPage() { } case "data-action-log": { - const al = parsed.data; - const matchedToolCallId = al.lc_tool_call_id - ? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) - : null; - upsertAgentAction({ - action: { - id: al.id, - threadId: currentThreadId, - lcToolCallId: al.lc_tool_call_id, - chatTurnId: al.chat_turn_id, - toolName: al.tool_name, - reversible: al.reversible, - reverseDescriptorPresent: al.reverse_descriptor_present, - error: al.error, - revertedByActionId: null, - isRevertAction: false, - createdAt: al.created_at, - }, - toolCallId: matchedToolCallId, - }); + applyActionLogSse(queryClient, currentThreadId, searchSpaceId, parsed.data); break; } case "data-action-log-updated": { - updateAgentActionReversible({ - id: parsed.data.id, - reversible: parsed.data.reversible, - }); + applyActionLogUpdatedSse( + queryClient, + currentThreadId, + parsed.data.id, + parsed.data.reversible + ); break; } @@ -1179,6 +1195,15 @@ export default function NewChatPage() { toolName: String(p.toolName), args: (p.args as Record<string, unknown>) ?? {}, result: p.result as unknown, + // Restore argsText so persisted pretty-printed + // JSON survives reloads (assistant-ui prefers + // supplied argsText over JSON.stringify(args)). + // langchainToolCallId restoration also fixes a + // pre-existing dropped-id bug on resume. + ...(typeof p.argsText === "string" ? { argsText: p.argsText } : {}), + ...(typeof p.langchainToolCallId === "string" + ? { langchainToolCallId: p.langchainToolCallId } + : {}), }); contentPartsState.currentTextPartIndex = -1; } else if (p.type === "data-thinking-steps") { @@ -1200,7 +1225,12 @@ export default function NewChatPage() { const editedAction = decisions[0].edited_action; for (const part of contentParts) { if (part.type === "tool-call" && part.toolName === editedAction.name) { - part.args = { ...part.args, ...editedAction.args }; + const mergedArgs = { ...part.args, ...editedAction.args }; + part.args = mergedArgs; + // Sync argsText so the rendered card shows the + // edited inputs — assistant-ui prefers caller- + // supplied argsText over JSON.stringify(args). + part.argsText = JSON.stringify(mergedArgs, null, 2); break; } } @@ -1256,6 +1286,10 @@ export default function NewChatPage() { ); }; const scheduleFlush = () => batcher.schedule(flushMessages); + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; for await (const parsed of readSSEStream(response)) { switch (parsed.type) { @@ -1292,13 +1326,20 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); - batcher.flush(); + forceFlush(); break; - case "tool-input-available": + case "tool-input-delta": + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + break; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + argsText: finalArgsText, langchainToolCallId: parsed.langchainToolCallId, }); } else { @@ -1311,9 +1352,13 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); } - batcher.flush(); + forceFlush(); break; + } case "tool-output-available": updateToolCall(contentPartsState, parsed.toolCallId, { @@ -1321,7 +1366,7 @@ export default function NewChatPage() { langchainToolCallId: parsed.langchainToolCallId, }); markInterruptsCompleted(contentParts); - batcher.flush(); + forceFlush(); break; case "data-thinking-step": { @@ -1381,34 +1426,17 @@ export default function NewChatPage() { } case "data-action-log": { - const al = parsed.data; - const matchedToolCallId = al.lc_tool_call_id - ? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) - : null; - upsertAgentAction({ - action: { - id: al.id, - threadId: resumeThreadId, - lcToolCallId: al.lc_tool_call_id, - chatTurnId: al.chat_turn_id, - toolName: al.tool_name, - reversible: al.reversible, - reverseDescriptorPresent: al.reverse_descriptor_present, - error: al.error, - revertedByActionId: null, - isRevertAction: false, - createdAt: al.created_at, - }, - toolCallId: matchedToolCallId, - }); + applyActionLogSse(queryClient, resumeThreadId, searchSpaceId, parsed.data); break; } case "data-action-log-updated": { - updateAgentActionReversible({ - id: parsed.data.id, - reversible: parsed.data.reversible, - }); + applyActionLogUpdatedSse( + queryClient, + resumeThreadId, + parsed.data.id, + parsed.data.reversible + ); break; } @@ -1502,6 +1530,11 @@ export default function NewChatPage() { return { ...part, args: decision.edited_action.args, // Update displayed args + // Sync argsText so the rendered card shows + // the edited inputs — assistant-ui prefers + // caller-supplied argsText over + // JSON.stringify(args). + argsText: JSON.stringify(decision.edited_action.args, null, 2), result: { ...(part.result as Record<string, unknown>), __decided__: decisionType, @@ -1712,6 +1745,10 @@ export default function NewChatPage() { ); }; const scheduleFlush = () => batcher.schedule(flushMessages); + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; for await (const parsed of readSSEStream(response)) { switch (parsed.type) { @@ -1748,13 +1785,20 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); - batcher.flush(); + forceFlush(); break; - case "tool-input-available": + case "tool-input-delta": + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + break; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + argsText: finalArgsText, langchainToolCallId: parsed.langchainToolCallId, }); } else { @@ -1767,9 +1811,13 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); } - batcher.flush(); + forceFlush(); break; + } case "tool-output-available": updateToolCall(contentPartsState, parsed.toolCallId, { @@ -1786,7 +1834,7 @@ export default function NewChatPage() { } } } - batcher.flush(); + forceFlush(); break; case "data-thinking-step": { @@ -1802,34 +1850,21 @@ export default function NewChatPage() { } case "data-action-log": { - const al = parsed.data; - const matchedToolCallId = al.lc_tool_call_id - ? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) - : null; - upsertAgentAction({ - action: { - id: al.id, - threadId, - lcToolCallId: al.lc_tool_call_id, - chatTurnId: al.chat_turn_id, - toolName: al.tool_name, - reversible: al.reversible, - reverseDescriptorPresent: al.reverse_descriptor_present, - error: al.error, - revertedByActionId: null, - isRevertAction: false, - createdAt: al.created_at, - }, - toolCallId: matchedToolCallId, - }); + if (threadId !== null) { + applyActionLogSse(queryClient, threadId, searchSpaceId, parsed.data); + } break; } case "data-action-log-updated": { - updateAgentActionReversible({ - id: parsed.data.id, - reversible: parsed.data.reversible, - }); + if (threadId !== null) { + applyActionLogUpdatedSse( + queryClient, + threadId, + parsed.data.id, + parsed.data.reversible + ); + } break; } @@ -1866,12 +1901,16 @@ export default function NewChatPage() { : `Reverted ${summary.reverted} downstream actions before regenerating.` ); } - for (const r of summary.results) { - if (r.status === "reverted" || r.status === "already_reverted") { - markAgentActionReverted({ - id: r.action_id, - newActionId: r.new_action_id ?? null, - }); + if (threadId !== null) { + for (const r of summary.results) { + if (r.status === "reverted" || r.status === "already_reverted") { + markActionRevertedInCache( + queryClient, + threadId, + r.action_id, + r.new_action_id ?? null + ); + } } } break; @@ -2019,16 +2058,26 @@ export default function NewChatPage() { const downstream = messages.slice(editedIndex + 1); downstreamTotalCount = downstream.length; const seenTurns = new Set<string>(); + const downstreamTurnIds = new Set<string>(); for (const m of downstream) { const meta = (m.metadata ?? {}) as { custom?: { chatTurnId?: string } }; const tid = meta.custom?.chatTurnId; if (!tid || seenTurns.has(tid)) continue; seenTurns.add(tid); - const turnActions = agentActionsByChatTurnId.get(tid) ?? []; - for (const a of turnActions) { - if (a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error) { - downstreamReversibleCount += 1; - } + downstreamTurnIds.add(tid); + } + // Source of truth: the unified react-query cache. Every + // action whose ``chat_turn_id`` belongs to the slice we're + // about to drop counts toward the prompt. + for (const a of agentActionItems) { + if (!a.chat_turn_id || !downstreamTurnIds.has(a.chat_turn_id)) continue; + if ( + a.reversible && + (a.reverted_by_action_id === null || a.reverted_by_action_id === undefined) && + !a.is_revert_action && + (a.error === null || a.error === undefined) + ) { + downstreamReversibleCount += 1; } } } @@ -2052,7 +2101,7 @@ export default function NewChatPage() { downstreamTotalCount, }); }, - [handleRegenerate, messages, agentActionsByChatTurnId] + [handleRegenerate, messages, agentActionItems] ); const handleEditDialogChoice = useCallback( diff --git a/surfsense_web/atoms/chat/agent-actions.atom.ts b/surfsense_web/atoms/chat/agent-actions.atom.ts deleted file mode 100644 index 7830c8751..000000000 --- a/surfsense_web/atoms/chat/agent-actions.atom.ts +++ /dev/null @@ -1,194 +0,0 @@ -"use client"; - -import { atom } from "jotai"; - -/** - * Minimal per-row projection of ``AgentActionLog`` that the tool card - * needs to decide whether to render a Revert button. - * - * Fields are deliberately a subset of the full ``AgentAction`` so the - * SSE side-channel (``data-action-log`` / ``data-action-log-updated``) - * can populate them without depending on the REST endpoint - * ``GET /threads/.../actions`` (which 503s when - * ``SURFSENSE_ENABLE_ACTION_LOG`` is off). - */ -export interface AgentActionLite { - id: number; - threadId: number | null; - lcToolCallId: string | null; - chatTurnId: string | null; - toolName: string; - reversible: boolean; - reverseDescriptorPresent: boolean; - error: boolean; - revertedByActionId: number | null; - isRevertAction: boolean; - createdAt: string | null; -} - -/** - * Map keyed off the LangChain ``tool_call.id`` (mirrors ``ContentPart - * tool-call.langchainToolCallId``). - */ -export const agentActionByLcIdAtom = atom<Map<string, AgentActionLite>>(new Map()); - -/** - * Parallel map keyed off the synthetic chat-card ``toolCallId`` - * (``call_<run-id>``) so ``ToolFallback`` (which only receives the - * synthetic id from assistant-ui) can join its card to the action log. - * - * Both maps are kept in sync by ``upsertAgentActionAtom``. - */ -export const agentActionByToolCallIdAtom = atom<Map<string, AgentActionLite>>(new Map()); - -/** - * Index keyed by ``chat_turn_id`` so the per-turn revert UI can answer - * "how many reversible actions does this assistant turn contain?" in - * O(1). Each entry's array is ordered by insertion (which - * for a single turn matches ``created_at`` because action-log writes - * happen synchronously). - */ -export const agentActionsByChatTurnIdAtom = atom<Map<string, AgentActionLite[]>>(new Map()); - -/** - * Action to upsert one ``AgentActionLite`` row. - * - * ``toolCallId`` is the synthetic card id (``call_<run-id>`` from - * ``stream_new_chat.py``). When provided alongside ``lcToolCallId``, the - * action is indexed under BOTH ids so the tool card can perform the - * lookup without going via the streaming state. - */ -export const upsertAgentActionAtom = atom( - null, - (_get, set, payload: { action: AgentActionLite; toolCallId?: string | null }) => { - const { action, toolCallId } = payload; - const upsertInto = ( - prev: Map<string, AgentActionLite>, - key: string - ): Map<string, AgentActionLite> => { - const next = new Map(prev); - const existing = next.get(key); - next.set(key, { - ...action, - // Preserve the local "reverted" bookkeeping if a reversibility - // flip arrives AFTER the user already reverted via the REST - // route. We never want a stale ``reversible=true`` event to - // resurrect a Reverted card. - revertedByActionId: existing?.revertedByActionId ?? action.revertedByActionId, - isRevertAction: existing?.isRevertAction ?? action.isRevertAction, - }); - return next; - }; - if (action.lcToolCallId) { - set(agentActionByLcIdAtom, (prev) => upsertInto(prev, action.lcToolCallId as string)); - } - if (toolCallId) { - set(agentActionByToolCallIdAtom, (prev) => upsertInto(prev, toolCallId)); - } - if (action.chatTurnId) { - set(agentActionsByChatTurnIdAtom, (prev) => { - const next = new Map(prev); - const turnId = action.chatTurnId as string; - const existing = next.get(turnId) ?? []; - const priorEntry = existing.find((row) => row.id === action.id); - const merged: AgentActionLite = { - ...action, - revertedByActionId: priorEntry?.revertedByActionId ?? action.revertedByActionId, - isRevertAction: priorEntry?.isRevertAction ?? action.isRevertAction, - }; - const others = existing.filter((row) => row.id !== action.id); - next.set(turnId, [...others, merged]); - return next; - }); - } - } -); - -function mutateById( - prev: Map<string, AgentActionLite>, - id: number, - mutator: (entry: AgentActionLite) => AgentActionLite -): Map<string, AgentActionLite> { - let mutated = false; - const next = new Map(prev); - for (const [key, value] of next) { - if (value.id === id) { - next.set(key, mutator(value)); - mutated = true; - } - } - return mutated ? next : prev; -} - -function mutateByIdInTurnIndex( - prev: Map<string, AgentActionLite[]>, - id: number, - mutator: (entry: AgentActionLite) => AgentActionLite -): Map<string, AgentActionLite[]> { - let mutated = false; - const next = new Map(prev); - for (const [key, list] of next) { - let listMutated = false; - const updated = list.map((row) => { - if (row.id === id) { - listMutated = true; - return mutator(row); - } - return row; - }); - if (listMutated) { - next.set(key, updated); - mutated = true; - } - } - return mutated ? next : prev; -} - -/** - * Action to flip an existing entry's ``reversible`` flag, keyed by the - * AgentActionLog row id (the SSE ``data-action-log-updated`` payload - * does NOT carry ``lcToolCallId``). - */ -export const updateAgentActionReversibleAtom = atom( - null, - (_get, set, payload: { id: number; reversible: boolean }) => { - const apply = (entry: AgentActionLite): AgentActionLite => ({ - ...entry, - reversible: payload.reversible, - }); - set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, apply)); - set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, apply)); - set(agentActionsByChatTurnIdAtom, (prev) => mutateByIdInTurnIndex(prev, payload.id, apply)); - } -); - -/** Action to mark an existing entry as reverted (post-revert call). */ -export const markAgentActionRevertedAtom = atom( - null, - (_get, set, payload: { id: number; newActionId: number | null }) => { - const apply = (entry: AgentActionLite): AgentActionLite => ({ - ...entry, - revertedByActionId: payload.newActionId ?? -1, - }); - set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, apply)); - set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, apply)); - set(agentActionsByChatTurnIdAtom, (prev) => mutateByIdInTurnIndex(prev, payload.id, apply)); - } -); - -/** Mark every action in a turn as reverted, given a list of (id, newActionId) pairs. */ -export const markAgentActionsRevertedBatchAtom = atom( - null, - (_get, set, payload: { entries: Array<{ id: number; newActionId: number | null }> }) => { - for (const entry of payload.entries) { - set(markAgentActionRevertedAtom, entry); - } - } -); - -/** Reset all maps (e.g. when the active thread changes). */ -export const resetAgentActionMapAtom = atom(null, (_get, set) => { - set(agentActionByLcIdAtom, new Map()); - set(agentActionByToolCallIdAtom, new Map()); - set(agentActionsByChatTurnIdAtom, new Map()); -}); diff --git a/surfsense_web/components/agent-action-log/action-log-sheet.tsx b/surfsense_web/components/agent-action-log/action-log-sheet.tsx index 68d2ffef3..32c25771a 100644 --- a/surfsense_web/components/agent-action-log/action-log-sheet.tsx +++ b/surfsense_web/components/agent-action-log/action-log-sheet.tsx @@ -1,9 +1,9 @@ "use client"; -import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { useQueryClient } from "@tanstack/react-query"; import { useAtom, useAtomValue } from "jotai"; import { Activity, RefreshCcw } from "lucide-react"; -import { useCallback, useMemo } from "react"; +import { useCallback } from "react"; import { actionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom"; import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; import { Badge } from "@/components/ui/badge"; @@ -17,15 +17,12 @@ import { SheetTitle, } from "@/components/ui/sheet"; import { Skeleton } from "@/components/ui/skeleton"; -import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; +import { + agentActionsQueryKey, + useAgentActionsQuery, +} from "@/hooks/use-agent-actions-query"; import { ActionLogItem } from "./action-log-item"; -const ACTION_LOG_PAGE_SIZE = 50; - -function actionLogQueryKey(threadId: number) { - return ["agent-actions", threadId] as const; -} - function EmptyState() { return ( <div className="flex flex-1 flex-col items-center justify-center gap-3 px-6 text-center"> @@ -85,25 +82,17 @@ export function ActionLogSheet() { const threadId = state.threadId; - const { data, isLoading, isFetching, isError, error, refetch } = useQuery({ - queryKey: threadId !== null ? actionLogQueryKey(threadId) : ["agent-actions", "none"], - queryFn: () => - agentActionsApiService.listForThread(threadId as number, { - page: 0, - pageSize: ACTION_LOG_PAGE_SIZE, - }), - enabled: state.open && threadId !== null && actionLogEnabled, - staleTime: 15 * 1000, - }); + const { data, items, isLoading, isFetching, isError, error, refetch } = useAgentActionsQuery( + threadId, + { enabled: state.open && actionLogEnabled } + ); const handleRevertSuccess = useCallback(() => { if (threadId !== null) { - queryClient.invalidateQueries({ queryKey: actionLogQueryKey(threadId) }); + queryClient.invalidateQueries({ queryKey: agentActionsQueryKey(threadId) }); } }, [queryClient, threadId]); - const items = useMemo(() => data?.items ?? [], [data]); - return ( <Sheet open={state.open} onOpenChange={(open) => setState((s) => ({ ...s, open }))}> <SheetContent diff --git a/surfsense_web/components/assistant-ui/revert-turn-button.tsx b/surfsense_web/components/assistant-ui/revert-turn-button.tsx index af71299d0..733162c80 100644 --- a/surfsense_web/components/assistant-ui/revert-turn-button.tsx +++ b/surfsense_web/components/assistant-ui/revert-turn-button.tsx @@ -4,26 +4,22 @@ * "Revert turn" button rendered at the bottom of every completed * assistant turn that has at least one reversible action. * - * The button reads the action map keyed by ``chat_turn_id`` from the - * SSE side-channel (``data-action-log`` events). It shows a confirmation - * dialog summarising "N reversible / M total" and, on confirm, calls - * ``POST /threads/{id}/revert-turn/{chat_turn_id}``. + * The button reads from the unified ``useAgentActionsQuery`` cache + * (the SAME react-query cache the agent-actions sheet and the inline + * Revert button consume) filtered by ``chat_turn_id``. It shows a + * confirmation dialog summarising "N reversible / M total" and, on + * confirm, calls ``POST /threads/{id}/revert-turn/{chat_turn_id}``. * * The route returns a per-action result list and never collapses the * batch into a 4xx — so we render any failed/not_reversible rows inline * with their messages. */ -import { useAtomValue, useSetAtom } from "jotai"; -import { selectAtom } from "jotai/utils"; +import { useQueryClient } from "@tanstack/react-query"; +import { useAtomValue } from "jotai"; import { CheckIcon, RotateCcw, XCircleIcon } from "lucide-react"; import { useMemo, useState } from "react"; import { toast } from "sonner"; -import { - type AgentActionLite, - agentActionsByChatTurnIdAtom, - markAgentActionsRevertedBatchAtom, -} from "@/atoms/chat/agent-actions.atom"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; import { AlertDialog, @@ -38,6 +34,10 @@ import { } from "@/components/ui/alert-dialog"; import { Button } from "@/components/ui/button"; import { getToolDisplayName } from "@/contracts/enums/toolIcons"; +import { + applyRevertTurnResultsToCache, + useAgentActionsQuery, +} from "@/hooks/use-agent-actions-query"; import { agentActionsApiService, type RevertTurnActionResult, @@ -49,49 +49,33 @@ interface RevertTurnButtonProps { chatTurnId: string | null | undefined; } -// Empty-array sentinel so the per-turn ``selectAtom`` slice returns a -// stable reference when the turn has no recorded actions yet. Without -// this every render allocates a fresh ``[]`` and Jotai's -// equality check would re-render the button on unrelated turn updates. -const EMPTY_ACTIONS: readonly AgentActionLite[] = Object.freeze([]); - export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) { const session = useAtomValue(chatSessionStateAtom); - const markRevertedBatch = useSetAtom(markAgentActionsRevertedBatchAtom); + const threadId = session?.threadId ?? null; + const queryClient = useQueryClient(); + const { findByChatTurnId } = useAgentActionsQuery(threadId); const [isReverting, setIsReverting] = useState(false); const [confirmOpen, setConfirmOpen] = useState(false); const [resultsOpen, setResultsOpen] = useState(false); const [results, setResults] = useState<RevertTurnActionResult[]>([]); - // Subscribe ONLY to the slice of the global action map that belongs - // to ``chatTurnId``. Previously the button read the whole - // ``agentActionsByChatTurnIdAtom``, which meant every action - // upsert (one per tool call) re-rendered every Revert button on - // the page. With ``selectAtom`` we re-render only when our turn's - // list reference changes — and the upsert/mark atoms produce a - // fresh list reference for the affected turn only. - const sliceAtom = useMemo( - () => - selectAtom( - agentActionsByChatTurnIdAtom, - (turnIndex) => (chatTurnId ? turnIndex.get(chatTurnId) : undefined) ?? EMPTY_ACTIONS - ), - [chatTurnId] - ); - const actions = useAtomValue(sliceAtom); + const actions = useMemo(() => findByChatTurnId(chatTurnId), [findByChatTurnId, chatTurnId]); const reversibleCount = useMemo( () => actions.filter( - (a) => a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error + (a) => + a.reversible && + (a.reverted_by_action_id === null || a.reverted_by_action_id === undefined) && + !a.is_revert_action && + (a.error === null || a.error === undefined) ).length, [actions] ); - const totalCount = useMemo(() => actions.filter((a) => !a.isRevertAction).length, [actions]); + const totalCount = useMemo(() => actions.filter((a) => !a.is_revert_action).length, [actions]); if (!chatTurnId) return null; if (reversibleCount === 0) return null; - const threadId = session?.threadId; if (!threadId) return null; const handleRevertTurn = async () => { @@ -103,7 +87,7 @@ export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) { .filter((r) => r.status === "reverted" || r.status === "already_reverted") .map((r) => ({ id: r.action_id, newActionId: r.new_action_id ?? null })); if (revertedEntries.length > 0) { - markRevertedBatch({ entries: revertedEntries }); + applyRevertTurnResultsToCache(queryClient, threadId, revertedEntries); } if (response.status === "ok") { toast.success( diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index cc7582695..66e2ebd4a 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -1,12 +1,12 @@ -import type { ToolCallMessagePartComponent } from "@assistant-ui/react"; -import { useAtomValue, useSetAtom } from "jotai"; -import { CheckIcon, ChevronDownIcon, ChevronUpIcon, RotateCcw, XCircleIcon } from "lucide-react"; -import { useMemo, useState } from "react"; -import { toast } from "sonner"; import { - agentActionByToolCallIdAtom, - markAgentActionRevertedAtom, -} from "@/atoms/chat/agent-actions.atom"; + type ToolCallMessagePartComponent, + useAuiState, +} from "@assistant-ui/react"; +import { useQueryClient } from "@tanstack/react-query"; +import { useAtomValue } from "jotai"; +import { CheckIcon, ChevronDownIcon, RotateCcw, XCircleIcon } from "lucide-react"; +import { useEffect, useMemo, useState } from "react"; +import { toast } from "sonner"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; import { DoomLoopApprovalToolUI, @@ -24,8 +24,17 @@ import { AlertDialogTitle, AlertDialogTrigger, } from "@/components/ui/alert-dialog"; +import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; -import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons"; +import { Card } from "@/components/ui/card"; +import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; +import { Separator } from "@/components/ui/separator"; +import { Spinner } from "@/components/ui/spinner"; +import { getToolDisplayName } from "@/contracts/enums/toolIcons"; +import { + markActionRevertedInCache, + useAgentActionsQuery, +} from "@/hooks/use-agent-actions-query"; import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; import { AppError } from "@/lib/error"; import { isInterruptResult } from "@/lib/hitl"; @@ -34,31 +43,128 @@ import { cn } from "@/lib/utils"; /** * Inline Revert button rendered on a tool card when the matching * ``AgentActionLog`` row is reversible and hasn't been reverted yet. - * Reads from the SSE side-channel atom keyed by the synthetic - * ``toolCallId`` so it lights up even when ``GET /threads/.../actions`` - * is gated behind ``SURFSENSE_ENABLE_ACTION_LOG=False`` (503). + * + * Reads from the unified ``useAgentActionsQuery`` cache — the SAME + * react-query cache the agent-actions sheet consumes. SSE events + * (``data-action-log`` / ``data-action-log-updated``) and + * ``POST /threads/{id}/revert/{id}`` responses both flow through the + * cache via ``setQueryData`` helpers, so the card and the sheet stay + * in lockstep on every code path: page reload, navigation, live + * stream, post-stream reversibility flip, and explicit revert clicks. + * + * Match key (in priority order): + * 1. ``a.tool_call_id === toolCallId`` — direct hit in parity_v2 when + * the model streamed ``tool_call_chunks`` so the card's synthetic + * id IS the LangChain id. + * 2. ``a.tool_call_id === langchainToolCallId`` — legacy mode (or + * parity_v2 with provider-side chunk emission) where the card's + * synthetic id is ``call_<run_id>`` and the LangChain id is + * backfilled onto the part by ``tool-output-available``. + * 3. ``(chat_turn_id, tool_name, position-within-turn)`` — fallback + * for cards whose synthetic id is ``call_<run_id>`` AND whose + * ``langchainToolCallId`` never got backfilled (provider emitted + * the tool_call as a single payload with no chunks AND streaming + * pre-dated the ``tool-output-available langchainToolCallId`` + * backfill, e.g. older threads). Reads the parent message's + * ``chatTurnId`` and ``content`` via ``useAuiState`` so we can + * match position-by-tool-name within the turn against the + * action_log rows the server returned in ``created_at`` order. */ -function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { +function ToolCardRevertButton({ + toolCallId, + toolName, + langchainToolCallId, +}: { + toolCallId: string; + toolName: string; + langchainToolCallId?: string; +}) { const session = useAtomValue(chatSessionStateAtom); - const actionMap = useAtomValue(agentActionByToolCallIdAtom); - const markReverted = useSetAtom(markAgentActionRevertedAtom); - const action = actionMap.get(toolCallId); + const threadId = session?.threadId ?? null; + const queryClient = useQueryClient(); + const { findByToolCallId, findByChatTurnAndTool } = useAgentActionsQuery(threadId); + + // Parent message metadata, read via the narrowest possible + // selectors so this card doesn't re-render on every text-delta of + // every other part in the same message during streaming. + // + // IMPORTANT — ``useAuiState`` re-renders the component whenever the + // returned slice's identity changes. Returning ``message?.content`` + // (an array) would re-render on every token because the runtime + // rebuilds the parts array. Returning a PRIMITIVE (the position + // number) lets ``useAuiState``'s ``Object.is`` check short-circuit + // when the position hasn't actually moved — which is the common + // case during text streaming, when only ``text``/``reasoning`` + // parts are mutating and the same-toolName tool-call ordering is + // stable. (See Vercel React rule ``rerender-defer-reads``.) + const chatTurnId = useAuiState(({ message }) => { + const meta = message?.metadata as { custom?: { chatTurnId?: string } } | undefined; + return meta?.custom?.chatTurnId ?? null; + }); + const positionInTurn = useAuiState(({ message }) => { + const content = message?.content; + if (!Array.isArray(content)) return -1; + let n = -1; + for (const part of content) { + if ( + part && + typeof part === "object" && + (part as { type?: string }).type === "tool-call" && + (part as { toolName?: string }).toolName === toolName + ) { + n += 1; + if ((part as { toolCallId?: string }).toolCallId === toolCallId) return n; + } + } + return -1; + }); + + const action = useMemo(() => { + // Tier 1 + 2: O(1) Map-backed direct id match. Covers + // ~all parity_v2 streams and any legacy stream that backfilled + // ``langchainToolCallId`` via ``tool-output-available``. + const direct = + findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId); + if (direct) return direct; + // Tier 3: position-within-turn fallback. Only kicks in when the + // card has a synthetic ``call_<run_id>`` id AND no + // ``langchainToolCallId`` was ever backfilled — i.e. the tool + // was emitted as a single non-chunked payload AND streaming + // pre-dated the on_tool_end backfill. + if (!chatTurnId || positionInTurn < 0) return null; + const turnSameTool = findByChatTurnAndTool(chatTurnId, toolName); + return turnSameTool[positionInTurn] ?? null; + }, [ + findByToolCallId, + findByChatTurnAndTool, + toolCallId, + langchainToolCallId, + chatTurnId, + toolName, + positionInTurn, + ]); + const [isReverting, setIsReverting] = useState(false); const [confirmOpen, setConfirmOpen] = useState(false); if (!action) return null; if (!action.reversible) return null; - if (action.revertedByActionId !== null) return null; - if (action.isRevertAction) return null; - if (action.error) return null; - const threadId = session?.threadId; + if (action.reverted_by_action_id !== null && action.reverted_by_action_id !== undefined) + return null; + if (action.is_revert_action) return null; + if (action.error !== null && action.error !== undefined) return null; if (!threadId) return null; const handleRevert = async () => { setIsReverting(true); try { const response = await agentActionsApiService.revert(threadId, action.id); - markReverted({ id: action.id, newActionId: response.new_action_id ?? null }); + markActionRevertedInCache( + queryClient, + threadId, + action.id, + response.new_action_id ?? null + ); toast.success(response.message || "Action reverted."); } catch (err) { // 503 means revert is gated off on this deployment — hide the @@ -91,8 +197,17 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { e.stopPropagation(); setConfirmOpen(true); }} + disabled={isReverting} > - <RotateCcw className="size-3.5" /> + {isReverting ? ( + // Spinner's typed props don't accept ``data-icon`` and + // it renders an <output>, not an <svg>, so Button's + // auto-sizing rule doesn't apply. Bare spinner + + // Button's gap handle layout. + <Spinner size="xs" /> + ) : ( + <RotateCcw data-icon="inline-start" /> + )} Revert </Button> </AlertDialogTrigger> @@ -101,7 +216,7 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { <AlertDialogTitle>Revert this action?</AlertDialogTitle> <AlertDialogDescription> This will undo{" "} - <span className="font-medium">{getToolDisplayName(action.toolName)}</span> and add a + <span className="font-medium">{getToolDisplayName(action.tool_name)}</span> and add a new entry to the history. Your chat is preserved — only the changes the agent made to your knowledge base or connected apps will be rolled back where possible. </AlertDialogDescription> @@ -114,8 +229,10 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { handleRevert(); }} disabled={isReverting} + className="gap-1.5" > - {isReverting ? "Reverting…" : "Revert"} + {isReverting && <Spinner size="xs" />} + Revert </AlertDialogAction> </AlertDialogFooter> </AlertDialogContent> @@ -123,18 +240,49 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { ); } -const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ - toolCallId, - toolName, - argsText, - result, - status, -}) => { - const [isExpanded, setIsExpanded] = useState(false); +/** + * Compact tool-call card. + * + * shadcn composition note: we intentionally use ``Card`` as a visual + * frame WITHOUT ``CardHeader / CardContent``. The full composition's + * ``p-6`` padding doesn't fit a compact collapsible header that IS the + * trigger; using ``Card`` alone preserves the rounded border, shadow, + * and ``bg-card`` token (semantic colors) without forcing a layout + * that doesn't fit. All status colors use semantic tokens — no manual + * dark-mode overrides, no raw hex. + */ +const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => { + const { toolCallId, toolName, argsText, result, status } = props; + // ``langchainToolCallId`` is a SurfSense-specific extension the + // streaming pipeline attaches to the tool-call content part so + // the Revert button can resolve its ``AgentActionLog`` row even + // when only the LC id is known. assistant-ui's + // ``ToolCallMessagePartProps`` doesn't list it, but the runtime + // spreads ``{...part}`` so the prop reaches us at runtime. + const langchainToolCallId = (props as { langchainToolCallId?: string }).langchainToolCallId; const isCancelled = status?.type === "incomplete" && status.reason === "cancelled"; const isError = status?.type === "incomplete" && status.reason === "error"; const isRunning = status?.type === "running" || status?.type === "requires-action"; + + /* + Per-card expansion state. Initial value is ``isRunning`` so a + card streaming in mounts already-expanded (no flash of + collapsed → expanded on first paint), while a card loaded from + history (status="complete") mounts collapsed. The useEffect + below keeps this in lockstep with this card's own ``isRunning`` + when it transitions: false → true auto-expands (e.g. a tool + that re-runs after edit), true → false auto-collapses once the + tool finishes. Because the dep is per-card ``isRunning`` and + not the chat-level streaming flag, sibling cards on the same + assistant turn each manage their own expansion independently. + Once ``isRunning`` is false the user controls expansion via + ``onOpenChange``. + */ + const [isExpanded, setIsExpanded] = useState(isRunning); + useEffect(() => { + setIsExpanded(isRunning); + }, [isRunning]); const errorData = status?.type === "incomplete" ? status.error : undefined; const serializedError = useMemo( () => (errorData && typeof errorData !== "string" ? JSON.stringify(errorData) : null), @@ -160,108 +308,207 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ : serializedError : null; - const Icon = getToolIcon(toolName); const displayName = getToolDisplayName(toolName); + const subtitle = errorReason ?? cancelledReason; return ( - <div + <Card className={cn( - "my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none", + "my-4 max-w-lg overflow-hidden", isCancelled && "opacity-60", - isError && "border-destructive/20 bg-destructive/5" + isError && "border-destructive/30" )} > - <button - type="button" - onClick={() => setIsExpanded((prev) => !prev)} - className="flex w-full items-center gap-3 px-5 py-4 text-left transition-colors hover:bg-muted/50 focus:outline-none focus-visible:outline-none" + {/* + ``group`` lets the chevron (rendered as a sibling of the + main trigger button) read the Collapsible Root's + ``data-[state=open]`` for rotation. The Collapsible is + fully controlled via ``isExpanded`` — the useEffect + above syncs it to ``isRunning`` so the card auto-opens + while a tool streams in and auto-collapses once it + finishes. We deliberately DON'T pass ``disabled`` so + both triggers stay clickable; ``onOpenChange`` is wired + to a setter that no-ops while ``isRunning`` (see + ``handleOpenChange`` below) which keeps the card pinned + open mid-stream without losing keyboard / pointer + affordance the moment streaming ends. + */} + <Collapsible + className="group" + open={isExpanded} + onOpenChange={(next) => { + // Block manual collapse while the tool is still + // streaming — otherwise a stray click on either + // trigger would close the card and hide the live + // ``argsText`` panel mid-run. After streaming the + // user has full control again. + if (isRunning) return; + setIsExpanded(next); + }} > - <div - className={cn( - "flex size-8 shrink-0 items-center justify-center rounded-lg", - isError ? "bg-destructive/10" : isCancelled ? "bg-muted" : "bg-primary/10" - )} - > - {isError ? ( - <XCircleIcon className="size-4 text-destructive" /> - ) : isCancelled ? ( - <XCircleIcon className="size-4 text-muted-foreground" /> - ) : isRunning ? ( - <Icon className="size-4 text-primary animate-pulse" /> - ) : ( - <CheckIcon className="size-4 text-primary" /> - )} - </div> + {/* + Header row: main trigger on the left (icon + title + col), Revert + chevron-trigger on the right as + siblings of the main trigger. The chevron is wrapped + in its OWN ``CollapsibleTrigger`` (Radix supports + multiple triggers per Root) so clicking the chevron + toggles the same state as clicking the title row. + The Revert button stays a separate AlertDialog + trigger and stops propagation in its onClick so it + doesn't toggle the collapsible while opening the + confirm dialog. Keeping these as flat siblings — + rather than nesting Revert / chevron inside the + title trigger — avoids invalid HTML + (button-in-button) and lets the Revert button + render in BOTH the collapsed and expanded states. + */} + <div className="flex items-stretch transition-colors hover:bg-muted/50"> + <CollapsibleTrigger asChild> + <button + type="button" + className={cn( + "flex flex-1 min-w-0 items-center gap-3 py-4 pl-5 pr-2 text-left", + // Inset ring — Card's ``overflow-hidden`` would + // clip an ``offset-2`` ring; ``ring-inset`` + // paints inside the button box. + "focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset", + "disabled:cursor-default" + )} + > + <div + className={cn( + "flex size-8 shrink-0 items-center justify-center rounded-lg", + isError ? "bg-destructive/10" : isCancelled ? "bg-muted" : "bg-primary/10" + )} + > + {isError ? ( + <XCircleIcon className="size-4 text-destructive" /> + ) : isCancelled ? ( + <XCircleIcon className="size-4 text-muted-foreground" /> + ) : isRunning ? ( + <Spinner size="sm" className="text-primary" /> + ) : ( + <CheckIcon className="size-4 text-primary" /> + )} + </div> - <div className="flex-1 min-w-0"> - <p - className={cn( - "text-sm font-semibold", - isError - ? "text-destructive" - : isCancelled - ? "text-muted-foreground line-through" - : "text-foreground" - )} - > - {isRunning - ? displayName - : isCancelled - ? `Cancelled: ${displayName}` - : isError - ? `Failed: ${displayName}` - : displayName} - </p> - {isRunning && <p className="text-xs text-muted-foreground mt-0.5">Working…</p>} - {cancelledReason && ( - <p className="text-xs text-muted-foreground mt-0.5 truncate">{cancelledReason}</p> - )} - {errorReason && ( - <p className="text-xs text-destructive/80 mt-0.5 truncate">{errorReason}</p> - )} - </div> + <div className="flex flex-1 min-w-0 flex-col gap-0.5"> + <div className="flex items-center gap-2"> + <p + className={cn( + "text-sm font-semibold truncate", + isCancelled && "text-muted-foreground line-through", + isError && "text-destructive" + )} + > + {displayName} + </p> + {isRunning && <Badge variant="secondary">Running</Badge>} + {isError && <Badge variant="destructive">Failed</Badge>} + {isCancelled && <Badge variant="outline">Cancelled</Badge>} + </div> + {subtitle && ( + <p + className={cn( + "text-xs truncate", + isError ? "text-destructive/80" : "text-muted-foreground" + )} + > + {subtitle} + </p> + )} + </div> + </button> + </CollapsibleTrigger> - {!isRunning && ( - <div className="shrink-0 text-muted-foreground"> - {isExpanded ? ( - <ChevronDownIcon className="size-4" /> - ) : ( - <ChevronUpIcon className="size-4" /> - )} + {/* + Right-side controls. The Revert button is + visible whenever the matching action is + reversible — including the collapsed state — + but ``ToolCardRevertButton`` itself returns + ``null`` while a tool is still running because + no action-log row exists yet, so it doesn't + need an explicit ``isRunning`` gate here. + */} + <div className="flex shrink-0 items-center gap-2 pl-2 pr-5"> + <ToolCardRevertButton + toolCallId={toolCallId} + toolName={toolName} + langchainToolCallId={langchainToolCallId} + /> + <CollapsibleTrigger asChild> + <button + type="button" + aria-label={isExpanded ? "Collapse details" : "Expand details"} + className={cn( + "flex size-7 shrink-0 items-center justify-center rounded-md", + "text-muted-foreground hover:bg-muted hover:text-foreground", + "focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset", + "disabled:cursor-default" + )} + > + <ChevronDownIcon + className={cn( + "size-4 transition-transform duration-200", + "group-data-[state=open]:rotate-180" + )} + /> + </button> + </CollapsibleTrigger> </div> - )} - </button> + </div> - {isExpanded && !isRunning && ( - <> - <div className="mx-5 h-px bg-border/50" /> - <div className="px-5 py-3 space-y-3"> - {argsText && ( - <div> - <p className="text-xs font-medium text-muted-foreground mb-1">Inputs</p> - <pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all"> - {argsText} - </pre> + {/* + CollapsibleContent body — auto-open while streaming + (see ``open`` prop above) so the live ``argsText`` + streams into the Inputs panel directly, no need for + a separate "Live input" panel. Native + ``overflow-auto`` instead of ``ScrollArea`` because + Radix's Viewport can let content bleed past + ``max-h-*`` in dynamic flex layouts. ``min-w-0`` on + the column wrappers guarantees ``break-all`` wraps + correctly within the bounded ``max-w-lg`` Card. + */} + <CollapsibleContent> + <Separator /> + <div className="flex flex-col gap-3 px-5 py-3"> + {(argsText || isRunning) && ( + <div className="flex flex-col gap-1 min-w-0"> + <p className="text-xs font-medium text-muted-foreground">Inputs</p> + <div className="max-h-48 overflow-auto rounded-md bg-muted/40"> + {argsText ? ( + <pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono"> + {argsText} + </pre> + ) : ( + // Bridges the brief gap between + // ``tool-input-start`` (creates the + // card, ``argsText`` undefined) and + // the first ``tool-input-delta``. + <p className="px-3 py-2 text-xs italic text-muted-foreground"> + Waiting for input… + </p> + )} + </div> </div> )} {!isCancelled && result !== undefined && ( <> - <div className="h-px bg-border/30" /> - <div> - <p className="text-xs font-medium text-muted-foreground mb-1">Result</p> - <pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all"> - {typeof result === "string" ? result : serializedResult} - </pre> + <Separator /> + <div className="flex flex-col gap-1 min-w-0"> + <p className="text-xs font-medium text-muted-foreground">Result</p> + <div className="max-h-64 overflow-auto rounded-md bg-muted/40"> + <pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono"> + {typeof result === "string" ? result : serializedResult} + </pre> + </div> </div> </> )} - <div className="flex justify-end"> - <ToolCardRevertButton toolCallId={toolCallId} /> - </div> </div> - </> - )} - </div> + </CollapsibleContent> + </Collapsible> + </Card> ); }; diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index bfdd613e2..05db99407 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -22,6 +22,7 @@ import { addToolCall, appendReasoning, appendText, + appendToolInputDelta, buildContentForUI, type ContentPartsState, endReasoning, @@ -146,6 +147,10 @@ export function FreeChatPage() { ); }; const scheduleFlush = () => batcher.schedule(flushMessages); + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; try { for await (const parsed of readSSEStream(response)) { @@ -183,13 +188,20 @@ export function FreeChatPage() { false, parsed.langchainToolCallId ); - batcher.flush(); + forceFlush(); break; - case "tool-input-available": + case "tool-input-delta": + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + break; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + argsText: finalArgsText, langchainToolCallId: parsed.langchainToolCallId, }); } else { @@ -202,16 +214,20 @@ export function FreeChatPage() { false, parsed.langchainToolCallId ); + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); } - batcher.flush(); + forceFlush(); break; + } case "tool-output-available": updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output, langchainToolCallId: parsed.langchainToolCallId, }); - batcher.flush(); + forceFlush(); break; case "data-thinking-step": { diff --git a/surfsense_web/contracts/types/chat-messages.types.ts b/surfsense_web/contracts/types/chat-messages.types.ts index 0859f9f3b..ef16bb366 100644 --- a/surfsense_web/contracts/types/chat-messages.types.ts +++ b/surfsense_web/contracts/types/chat-messages.types.ts @@ -1,7 +1,13 @@ import { z } from "zod"; /** - * Raw message from database (real-time sync) + * Raw message from database (real-time sync). + * + * ``turn_id`` is included so consumers (e.g. ``convertToThreadMessage``) + * can populate ``metadata.custom.chatTurnId`` on the + * ``ThreadMessageLike`` even after the live-collab Zero re-sync. The + * inline Revert button's ``(chat_turn_id, tool_name, position)`` + * fallback in tool-fallback.tsx depends on it. */ export const rawMessage = z.object({ id: z.number(), @@ -10,6 +16,7 @@ export const rawMessage = z.object({ content: z.unknown(), author_id: z.string().nullable(), created_at: z.string(), + turn_id: z.string().nullable().optional(), }); export type RawMessage = z.infer<typeof rawMessage>; diff --git a/surfsense_web/hooks/use-agent-actions-query.ts b/surfsense_web/hooks/use-agent-actions-query.ts new file mode 100644 index 000000000..9a722fb2e --- /dev/null +++ b/surfsense_web/hooks/use-agent-actions-query.ts @@ -0,0 +1,416 @@ +"use client"; + +import { type QueryClient, useQuery } from "@tanstack/react-query"; +import { useCallback, useEffect, useMemo, useRef } from "react"; +import { + type AgentAction, + type AgentActionListResponse, + agentActionsApiService, +} from "@/lib/apis/agent-actions-api.service"; + +// ============================================================================= +// DIAGNOSTIC LOGGING — gated behind a single switch. Flip ``RevertDebug`` +// to ``true`` to trace the full SSE → cache → card → button pipeline in +// the browser console. Off by default so we don't spam production. The +// infrastructure stays in place because the underlying id-mismatch +// failure mode is rare-but-real and surfaces only at runtime. +// ============================================================================= +const RevertDebug = false; +const dbg = (...args: unknown[]) => { + if (RevertDebug && typeof window !== "undefined") { + // eslint-disable-next-line no-console + console.log("[RevertDebug]", ...args); + } +}; + +/** + * Unified store for ``AgentActionLog`` rows scoped to one thread. + * + * Replaces the previous SSE side-channel atom mess + * (``agentActionByLcIdAtom`` / ``agentActionByToolCallIdAtom`` / + * ``agentActionsByChatTurnIdAtom``) and the standalone hydration hook. + * One react-query cache entry is now the single source of truth for: + * + * * the inline Revert button on every tool-call card + * * the per-turn "Revert turn" button under each assistant message + * * the edit-from-position pre-flight that decides whether to show + * the confirmation dialog + * * the agent-actions sheet + * + * The cache is hydrated by ``GET /threads/{id}/actions`` (sized to + * 200, the server max) and updated incrementally by helpers that turn + * SSE events / revert RPC responses into ``setQueryData`` mutations. + * That keeps the card and the sheet in lockstep on every code path — + * page reload, navigation, live stream, post-stream reversibility flip, + * and explicit revert clicks. + */ + +export const ACTION_LOG_PAGE_SIZE = 200; + +/** Stable react-query key for the per-thread action list. */ +export function agentActionsQueryKey(threadId: number | null) { + return threadId !== null + ? (["agent-actions", threadId] as const) + : (["agent-actions", "none"] as const); +} + +/** Subset of the SSE ``data-action-log`` payload we care about. */ +export interface ActionLogSseEvent { + id: number; + lc_tool_call_id: string | null; + chat_turn_id: string | null; + tool_name: string; + reversible: boolean; + reverse_descriptor_present: boolean; + error: boolean; + created_at: string | null; +} + +/** + * Append or upsert a freshly-emitted ``AgentActionLog`` row into the + * thread-scoped query cache. + * + * The SSE payload is a strict subset of ``AgentAction``; missing + * fields (``args``, ``reverse_descriptor``, ``user_id``) are filled + * with ``null`` placeholders. The next refetch (sheet open, user + * focus, route stale) backfills them — but the inline Revert button + * only reads the fields the SSE payload carries, so it lights up + * immediately. + */ +export function applyActionLogSse( + queryClient: QueryClient, + threadId: number, + searchSpaceId: number, + event: ActionLogSseEvent +): void { + dbg("applyActionLogSse: incoming SSE event", { + threadId, + searchSpaceId, + event, + }); + queryClient.setQueryData<AgentActionListResponse>( + agentActionsQueryKey(threadId), + (prev) => { + const placeholder: AgentAction = { + id: event.id, + thread_id: threadId, + user_id: null, + search_space_id: searchSpaceId, + tool_name: event.tool_name, + args: null, + result_id: null, + reversible: event.reversible, + reverse_descriptor: event.reverse_descriptor_present ? {} : null, + error: event.error ? {} : null, + reverse_of: null, + reverted_by_action_id: null, + is_revert_action: false, + tool_call_id: event.lc_tool_call_id, + chat_turn_id: event.chat_turn_id, + created_at: event.created_at ?? new Date().toISOString(), + }; + if (!prev) { + return { + items: [placeholder], + total: 1, + page: 0, + page_size: ACTION_LOG_PAGE_SIZE, + has_more: false, + }; + } + const existingIdx = prev.items.findIndex((a) => a.id === event.id); + if (existingIdx >= 0) { + const merged = [...prev.items]; + const existing = merged[existingIdx]; + if (existing) { + merged[existingIdx] = { + ...existing, + reversible: event.reversible, + tool_call_id: event.lc_tool_call_id ?? existing.tool_call_id, + chat_turn_id: event.chat_turn_id ?? existing.chat_turn_id, + }; + } + dbg("applyActionLogSse: merged into existing entry", { + id: event.id, + tool_call_id: merged[existingIdx]?.tool_call_id, + reversible: merged[existingIdx]?.reversible, + }); + return { ...prev, items: merged }; + } + dbg("applyActionLogSse: appended new placeholder", { + id: event.id, + tool_call_id: placeholder.tool_call_id, + tool_name: placeholder.tool_name, + reversible: placeholder.reversible, + cacheSizeAfter: prev.items.length + 1, + }); + // REST returns newest-first — keep that ordering when + // the server eventually refetches by prepending. + return { + ...prev, + items: [placeholder, ...prev.items], + total: prev.total + 1, + }; + } + ); +} + +/** + * Apply a post-SAVEPOINT reversibility flip + * (``data-action-log-updated`` SSE event) to the cache. + */ +export function applyActionLogUpdatedSse( + queryClient: QueryClient, + threadId: number, + id: number, + reversible: boolean +): void { + dbg("applyActionLogUpdatedSse: reversibility flip", { + threadId, + id, + reversible, + }); + queryClient.setQueryData<AgentActionListResponse>( + agentActionsQueryKey(threadId), + (prev) => { + if (!prev) { + dbg("applyActionLogUpdatedSse: NO prev cache for thread; flip dropped", { + threadId, + id, + }); + return prev; + } + let mutated = false; + const items = prev.items.map((a) => { + if (a.id !== id) return a; + mutated = true; + return { ...a, reversible }; + }); + if (!mutated) { + dbg("applyActionLogUpdatedSse: id not in cache; flip dropped", { + threadId, + id, + cacheSize: prev.items.length, + cacheIds: prev.items.map((a) => a.id), + }); + } + return mutated ? { ...prev, items } : prev; + } + ); +} + +/** + * Optimistically mark ``id`` as reverted. + * + * Used by the inline / per-turn Revert button immediately after the + * server returns success so the UI flips to "Reverted" without + * waiting for a refetch. ``newActionId`` is the id of the new + * ``is_revert_action`` row the server inserted; pass ``null`` if the + * server didn't return it. + */ +export function markActionRevertedInCache( + queryClient: QueryClient, + threadId: number, + id: number, + newActionId: number | null +): void { + queryClient.setQueryData<AgentActionListResponse>( + agentActionsQueryKey(threadId), + (prev) => { + if (!prev) return prev; + let mutated = false; + const items = prev.items.map((a) => { + if (a.id !== id) return a; + mutated = true; + // ``-1`` is a sentinel meaning "we know it was reverted + // but the server didn't tell us the new row's id". + return { + ...a, + reverted_by_action_id: newActionId ?? -1, + }; + }); + return mutated ? { ...prev, items } : prev; + } + ); +} + +/** + * Apply a batch of revert results (per-turn revert response) to the + * cache. Anything in the ``reverted`` / ``already_reverted`` buckets + * gets its ``reverted_by_action_id`` set; other rows are left alone. + */ +export function applyRevertTurnResultsToCache( + queryClient: QueryClient, + threadId: number, + entries: Array<{ id: number; newActionId: number | null }> +): void { + if (entries.length === 0) return; + queryClient.setQueryData<AgentActionListResponse>( + agentActionsQueryKey(threadId), + (prev) => { + if (!prev) return prev; + const lookup = new Map(entries.map((e) => [e.id, e.newActionId])); + let mutated = false; + const items = prev.items.map((a) => { + if (!lookup.has(a.id)) return a; + mutated = true; + const newActionId = lookup.get(a.id) ?? null; + return { ...a, reverted_by_action_id: newActionId ?? -1 }; + }); + return mutated ? { ...prev, items } : prev; + } + ); +} + +/** + * Read-side hook used by the card, the turn button, the sheet, and + * the edit-from-position pre-flight. + * + * Returns the raw query state plus convenience selectors so consumers + * don't reach into ``data.items`` directly. ``enabled`` is the only + * knob — pass ``false`` to keep the query dormant when the consumer + * doesn't yet have a thread id. + */ +export function useAgentActionsQuery( + threadId: number | null, + options: { enabled?: boolean } = {} +) { + const enabled = (options.enabled ?? true) && threadId !== null; + const query = useQuery({ + queryKey: agentActionsQueryKey(threadId), + queryFn: async () => { + dbg("useAgentActionsQuery: REST fetch START", { + threadId, + pageSize: ACTION_LOG_PAGE_SIZE, + }); + const res = await agentActionsApiService.listForThread(threadId as number, { + page: 0, + pageSize: ACTION_LOG_PAGE_SIZE, + }); + dbg("useAgentActionsQuery: REST fetch DONE", { + threadId, + total: res.total, + returned: res.items.length, + items: res.items.map((a) => ({ + id: a.id, + tool_name: a.tool_name, + tool_call_id: a.tool_call_id, + reversible: a.reversible, + reverted_by_action_id: a.reverted_by_action_id, + is_revert_action: a.is_revert_action, + })), + }); + return res; + }, + enabled, + staleTime: 15 * 1000, + }); + + const items = useMemo(() => query.data?.items ?? [], [query.data]); + + // Index ``items`` once per change so the lookups below are O(1) + // instead of O(N) per card per render. With the cache sized to 200 + // rows and many tool cards visible at once, the unindexed scan was + // the hottest path on every assistant text-delta. (Vercel React + // rule ``js-index-maps`` / ``js-set-map-lookups``.) + const byToolCallId = useMemo(() => { + const m = new Map<string, AgentAction>(); + for (const a of items) { + if (a.tool_call_id) m.set(a.tool_call_id, a); + } + return m; + }, [items]); + + // Pre-grouped + pre-sorted (oldest-first, the order the agent + // actually executed them in) so the (chat_turn_id, tool_name, + // position) fallback in ``tool-fallback.tsx`` is also O(1) per + // card. Excludes ``is_revert_action`` rows so the position index + // matches the agent's original execution order. + const byTurnAndTool = useMemo(() => { + const m = new Map<string, AgentAction[]>(); + for (const a of items) { + if (!a.chat_turn_id || a.is_revert_action) continue; + const key = `${a.chat_turn_id}::${a.tool_name}`; + const bucket = m.get(key); + if (bucket) bucket.push(a); + else m.set(key, [a]); + } + for (const bucket of m.values()) { + bucket.sort( + (a, b) => + new Date(a.created_at).getTime() - new Date(b.created_at).getTime() + ); + } + return m; + }, [items]); + + // Snapshot the cache shape when its size changes — easiest way to + // spot when the cache is empty or stale at the moment a card + // mounts. Tracked on a ref so we don't re-run the diff on + // reference-equal cache reads. + const lastSnapshotRef = useRef<{ threadId: number | null; size: number } | null>(null); + useEffect(() => { + const last = lastSnapshotRef.current; + if (!last || last.threadId !== threadId || last.size !== items.length) { + dbg("useAgentActionsQuery: cache snapshot", { + threadId, + enabled, + itemCount: items.length, + itemKeys: items.slice(0, 8).map((a) => ({ + id: a.id, + tool_name: a.tool_name, + tool_call_id: a.tool_call_id, + chat_turn_id: a.chat_turn_id, + reversible: a.reversible, + })), + }); + lastSnapshotRef.current = { threadId, size: items.length }; + } + }, [threadId, enabled, items]); + + const findByToolCallId = useCallback( + (toolCallId: string | null | undefined): AgentAction | null => { + if (!toolCallId) return null; + const found = byToolCallId.get(toolCallId) ?? null; + if (!found && items.length > 0) { + dbg("findByToolCallId: MISS", { + queriedToolCallId: toolCallId, + itemCount: items.length, + availableToolCallIds: Array.from(byToolCallId.keys()), + }); + } + return found; + }, + [byToolCallId, items.length] + ); + + const findByChatTurnId = useCallback( + (chatTurnId: string | null | undefined): AgentAction[] => { + if (!chatTurnId) return []; + // Per-turn aggregation is uncommon enough (only the + // "Revert turn" button uses it) that re-scanning is fine; + // indexing it would just bloat memory. + return items.filter((a) => a.chat_turn_id === chatTurnId); + }, + [items] + ); + + const findByChatTurnAndTool = useCallback( + ( + chatTurnId: string | null | undefined, + toolName: string | null | undefined + ): AgentAction[] => { + if (!chatTurnId || !toolName) return []; + return byTurnAndTool.get(`${chatTurnId}::${toolName}`) ?? []; + }, + [byTurnAndTool] + ); + + return { + ...query, + items, + findByToolCallId, + findByChatTurnId, + findByChatTurnAndTool, + }; +} diff --git a/surfsense_web/hooks/use-messages-sync.ts b/surfsense_web/hooks/use-messages-sync.ts index ddbe8a757..5ccda23a5 100644 --- a/surfsense_web/hooks/use-messages-sync.ts +++ b/surfsense_web/hooks/use-messages-sync.ts @@ -31,6 +31,14 @@ export function useMessagesSync( content: msg.content, author_id: msg.authorId ?? null, created_at: new Date(msg.createdAt).toISOString(), + // Forward the per-turn correlation id so post-stream Zero + // re-syncs preserve ``metadata.custom.chatTurnId`` on the + // converted ``ThreadMessageLike``. Without this the inline + // Revert button's ``(chat_turn_id, tool_name, position)`` + // fallback breaks the moment Zero overwrites the messages + // state after a live stream completes (see + // ``handleSyncedMessagesUpdate`` in the chat page). + turn_id: msg.turnId ?? null, })); onMessagesUpdateRef.current(mapped); diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index 26fd7b98c..54faf7e7c 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -16,6 +16,23 @@ export type ContentPart = toolName: string; args: Record<string, unknown>; result?: unknown; + /** + * Live / finalized JSON text for the tool's input arguments. + * + * - During streaming: accumulated partial JSON text from + * ``tool-input-delta`` events (may be invalid JSON + * mid-stream). assistant-ui's argsText parser tolerates + * invalid JSON gracefully (changelog 0.7.32 / 0.7.78). + * - On completion (``tool-input-available``): replaced with + * ``JSON.stringify(input, null, 2)`` so the post-stream + * card renders pretty-printed JSON instead of the + * model's possibly-fragmented formatting. + * + * Per assistant-ui ``ThreadMessageLike`` precedence + * (changelog 0.11.6 ``d318c83``), when ``argsText`` is + * supplied it wins over ``JSON.stringify(args)``. + */ + argsText?: string; /** * Authoritative LangChain ``tool_call.id`` propagated by the backend * via ``langchainToolCallId`` on tool-input-start/available and @@ -282,12 +299,22 @@ export function findToolCallIdByLcId( export function updateToolCall( state: ContentPartsState, toolCallId: string, - update: { args?: Record<string, unknown>; result?: unknown; langchainToolCallId?: string } + update: { + args?: Record<string, unknown>; + argsText?: string; + result?: unknown; + langchainToolCallId?: string; + } ): void { const index = state.toolCallIndices.get(toolCallId); if (index !== undefined && state.contentParts[index]?.type === "tool-call") { const tc = state.contentParts[index] as ContentPart & { type: "tool-call" }; if (update.args) tc.args = update.args; + // ``!== undefined`` (NOT a truthy check): an explicit empty + // string CAN clear, and a finalization with + // ``JSON.stringify({}, null, 2) === "{}"`` (truthy but + // represents an empty-input call) still applies. + if (update.argsText !== undefined) tc.argsText = update.argsText; if (update.result !== undefined) tc.result = update.result; // Only backfill langchainToolCallId if not already set — the // authoritative ``on_tool_end`` value should override an earlier @@ -299,6 +326,25 @@ export function updateToolCall( } } +/** + * Append a streamed args-delta chunk to the active tool call's + * ``argsText``. No-ops when no card has been registered yet for the + * given ``toolCallId`` (the matching ``tool-input-start`` either lost + * the wire race or this id never had a card — either way the deltas + * have nowhere safe to land). + */ +export function appendToolInputDelta( + state: ContentPartsState, + toolCallId: string, + delta: string +): void { + const idx = state.toolCallIndices.get(toolCallId); + if (idx === undefined) return; + const tc = state.contentParts[idx]; + if (tc?.type !== "tool-call") return; + tc.argsText = (tc.argsText ?? "") + delta; +} + function _hasInterruptResult(part: ContentPart): boolean { if (part.type !== "tool-call") return false; const r = (part as { result?: unknown }).result; @@ -371,6 +417,18 @@ export type SSEEvent = /** Authoritative LangChain ``tool_call.id``. Optional. */ langchainToolCallId?: string; } + | { + /** + * Live tool-call argument delta. Concatenated into + * ``argsText`` on the matching ``tool-call`` content part + * by ``appendToolInputDelta``. parity_v2 only — the legacy + * code path emits ``tool-input-available`` without prior + * deltas. + */ + type: "tool-input-delta"; + toolCallId: string; + inputTextDelta: string; + } | { type: "tool-input-available"; toolCallId: string; diff --git a/surfsense_web/zero/schema/chat.ts b/surfsense_web/zero/schema/chat.ts index 0293059fd..fb3d7651e 100644 --- a/surfsense_web/zero/schema/chat.ts +++ b/surfsense_web/zero/schema/chat.ts @@ -8,6 +8,13 @@ export const newChatMessageTable = table("new_chat_messages") threadId: number().from("thread_id"), authorId: string().optional().from("author_id"), createdAt: number().from("created_at"), + // Per-turn correlation id sourced from ``configurable.turn_id`` + // at streaming time. Required by the inline Revert button's + // (chat_turn_id, tool_name, position) fallback in tool-fallback.tsx + // — without it the live-collab Zero sync would clobber the + // metadata we set during streaming and the button would vanish + // the moment Zero re-syncs after the stream finishes. + turnId: string().optional().from("turn_id"), }) .primaryKey("id"); From 1ce122cc99cab31fde5323692b1c835f3a763cd5 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 16:05:58 +0530 Subject: [PATCH 244/299] feat(database): change alembic number and add idempotency --- ...34_add_thread_auto_model_pinning_fields.py | 63 ---------------- ...38_add_thread_auto_model_pinning_fields.py | 72 +++++++++++++++++++ 2 files changed, 72 insertions(+), 63 deletions(-) delete mode 100644 surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py create mode 100644 surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py diff --git a/surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py deleted file mode 100644 index ab1643b02..000000000 --- a/surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py +++ /dev/null @@ -1,63 +0,0 @@ -"""134_add_thread_auto_model_pinning_fields - -Revision ID: 134 -Revises: 133 -Create Date: 2026-04-29 - -Add thread-level fields to persist Auto (Fastest) model pinning metadata: -- pinned_llm_config_id: concrete resolved config id used for this thread -- pinned_auto_mode: auto policy identifier (currently "auto_fastest") -- pinned_at: timestamp when the pin was created/refreshed -""" - -from __future__ import annotations - -from collections.abc import Sequence - -import sqlalchemy as sa - -from alembic import op - -revision: str = "134" -down_revision: str | None = "133" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - op.add_column( - "new_chat_threads", - sa.Column("pinned_llm_config_id", sa.Integer(), nullable=True), - ) - op.add_column( - "new_chat_threads", - sa.Column("pinned_auto_mode", sa.String(length=32), nullable=True), - ) - op.add_column( - "new_chat_threads", - sa.Column("pinned_at", sa.TIMESTAMP(timezone=True), nullable=True), - ) - - op.create_index( - "ix_new_chat_threads_pinned_llm_config_id", - "new_chat_threads", - ["pinned_llm_config_id"], - unique=False, - ) - op.create_index( - "ix_new_chat_threads_pinned_auto_mode", - "new_chat_threads", - ["pinned_auto_mode"], - unique=False, - ) - - -def downgrade() -> None: - op.drop_index("ix_new_chat_threads_pinned_auto_mode", table_name="new_chat_threads") - op.drop_index( - "ix_new_chat_threads_pinned_llm_config_id", table_name="new_chat_threads" - ) - - op.drop_column("new_chat_threads", "pinned_at") - op.drop_column("new_chat_threads", "pinned_auto_mode") - op.drop_column("new_chat_threads", "pinned_llm_config_id") diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py new file mode 100644 index 000000000..6e4b77cc7 --- /dev/null +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -0,0 +1,72 @@ +"""138_add_thread_auto_model_pinning_fields + +Revision ID: 138 +Revises: 137 +Create Date: 2026-04-30 + +Add thread-level fields to persist Auto (Fastest) model pinning metadata: +- pinned_llm_config_id: concrete resolved config id used for this thread +- pinned_auto_mode: auto policy identifier (currently "auto_fastest") +- pinned_at: timestamp when the pin was created/refreshed + +Idempotent: this migration was originally numbered 134 on the +``feat/split-auto-free-premium`` branch and was renumbered to 138 during +the merge with ``upstream/dev`` (which claimed 134-137). Some databases +already have these columns/indexes from when the original 134 ran, so we +use ``IF NOT EXISTS`` to make re-application a no-op for those DBs while +still creating the schema on fresh databases. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op + +revision: str = "138" +down_revision: str | None = "137" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.execute( + "ALTER TABLE new_chat_threads " + "ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER" + ) + op.execute( + "ALTER TABLE new_chat_threads " + "ADD COLUMN IF NOT EXISTS pinned_auto_mode VARCHAR(32)" + ) + op.execute( + "ALTER TABLE new_chat_threads " + "ADD COLUMN IF NOT EXISTS pinned_at TIMESTAMP WITH TIME ZONE" + ) + + op.execute( + "CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_llm_config_id " + "ON new_chat_threads (pinned_llm_config_id)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_auto_mode " + "ON new_chat_threads (pinned_auto_mode)" + ) + + +def downgrade() -> None: + op.execute( + "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode" + ) + op.execute( + "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id" + ) + + op.execute( + "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at" + ) + op.execute( + "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode" + ) + op.execute( + "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_llm_config_id" + ) From 2a01711bc9f966f25fe652fb2063357cb74ec99b Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 16:20:14 +0530 Subject: [PATCH 245/299] feat(chat): expand error handling for chat operations by introducing a passthrough code set, improving response management and user feedback --- ...138_add_thread_auto_model_pinning_fields.py | 7 ------- .../unit/test_stream_new_chat_contract.py | 16 ++++++++++++---- .../new-chat/[[...chat_id]]/page.tsx | 18 ++++++++++++++---- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py index 6e4b77cc7..1ea549975 100644 --- a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -8,13 +8,6 @@ Add thread-level fields to persist Auto (Fastest) model pinning metadata: - pinned_llm_config_id: concrete resolved config id used for this thread - pinned_auto_mode: auto policy identifier (currently "auto_fastest") - pinned_at: timestamp when the pin was created/refreshed - -Idempotent: this migration was originally numbered 134 on the -``feat/split-auto-free-premium`` branch and was renumbered to 138 during -the merge with ``upstream/dev`` (which claimed 134-137). Some databases -already have these columns/indexes from when the original 134 ran, so we -use ``IF NOT EXISTS`` to make re-application a no-op for those DBs while -still creating the schema on fresh databases. """ from __future__ import annotations diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 9f4280063..86ea7edd1 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -231,10 +231,18 @@ def test_network_send_failures_use_unified_retry_toast_message(): assert 'userMessage: "Message not sent. Please retry."' in classifier_source assert 'userMessage: "Connection issue. Please try again."' in classifier_source assert "tagPreAcceptSendFailure(error)" in page_source - assert 'existingCode === "THREAD_BUSY"' in page_source - assert 'existingCode === "AUTH_EXPIRED"' in page_source - assert 'existingCode === "UNAUTHORIZED"' in page_source - assert 'existingCode === "RATE_LIMITED"' in page_source + assert "const passthroughCodes = new Set([" in page_source + assert '"PREMIUM_QUOTA_EXHAUSTED"' in page_source + assert '"THREAD_BUSY"' in page_source + assert '"AUTH_EXPIRED"' in page_source + assert '"UNAUTHORIZED"' in page_source + assert '"RATE_LIMITED"' in page_source + assert '"NETWORK_ERROR"' in page_source + assert '"STREAM_PARSE_ERROR"' in page_source + assert '"TOOL_EXECUTION_ERROR"' in page_source + assert '"PERSIST_MESSAGE_FAILED"' in page_source + assert '"SERVER_ERROR"' in page_source + assert "passthroughCodes.has(existingCode)" in page_source assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source assert 'errorCode: "NETWORK_ERROR"' not in page_source assert "Failed to start chat. Please try again." not in page_source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index f21a0a30b..239afaf73 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -227,11 +227,21 @@ function tagPreAcceptSendFailure(error: unknown): unknown { if (error instanceof Error) { const withCode = error as Error & { errorCode?: string; code?: string }; const existingCode = withCode.errorCode ?? withCode.code; + const passthroughCodes = new Set([ + "PREMIUM_QUOTA_EXHAUSTED", + "THREAD_BUSY", + "AUTH_EXPIRED", + "UNAUTHORIZED", + "RATE_LIMITED", + "NETWORK_ERROR", + "STREAM_PARSE_ERROR", + "TOOL_EXECUTION_ERROR", + "PERSIST_MESSAGE_FAILED", + "SERVER_ERROR", + ]); if ( - existingCode === "THREAD_BUSY" || - existingCode === "AUTH_EXPIRED" || - existingCode === "UNAUTHORIZED" || - existingCode === "RATE_LIMITED" + existingCode && + passthroughCodes.has(existingCode) ) { return Object.assign(error, { errorCode: existingCode }); } From 1d6d7e3eb10f814aafdff5430af4033af04bb176 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 16:33:13 +0530 Subject: [PATCH 246/299] refactor(chat): remove unused agent action handlers from NewChatPage component to streamline code and improve maintainability --- .../[search_space_id]/new-chat/[[...chat_id]]/page.tsx | 7 ------- 1 file changed, 7 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index df1290971..fe625f169 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -1512,8 +1512,6 @@ export default function NewChatPage() { tokenUsageStore, pendingUserImageUrls, setPendingUserImageUrls, - upsertAgentAction, - updateAgentActionReversible, handleStreamTerminalError, handleChatFailure, persistAssistantTurn, @@ -1894,8 +1892,6 @@ export default function NewChatPage() { messages, searchSpaceId, tokenUsageStore, - upsertAgentAction, - updateAgentActionReversible, handleStreamTerminalError, persistAssistantTurn, ] @@ -2433,9 +2429,6 @@ export default function NewChatPage() { messageDocumentsMap, setMessageDocumentsMap, tokenUsageStore, - upsertAgentAction, - updateAgentActionReversible, - markAgentActionReverted, handleStreamTerminalError, persistAssistantTurn, persistUserTurn, From 6465ea181a25a8c6d003572ea4707aa9e1dcf3cc Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:09:18 +0530 Subject: [PATCH 247/299] refactor(chat): streamline NewChatPage component by removing unused functions and integrating new stream handling utilities for improved performance --- .../new-chat/[[...chat_id]]/page.tsx | 625 +++++++----------- 1 file changed, 255 insertions(+), 370 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index fe625f169..d1dd14e06 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -252,6 +252,168 @@ function tagPreAcceptSendFailure(error: unknown): unknown { }); } +type SharedStreamEventContext = { + contentPartsState: ContentPartsState; + toolsWithUI: ToolUIGate; + currentThinkingSteps: Map<string, ThinkingStepData>; + scheduleFlush: () => void; + forceFlush: () => void; + onTokenUsage?: (data: TokenUsageData) => void; + onToolOutputAvailable?: ( + event: Extract<SSEEvent, { type: "tool-output-available" }>, + context: { + contentPartsState: ContentPartsState; + toolCallIndices: Map<string, number>; + } + ) => void; +}; + +function createStreamFlushHelpers(flushMessages: () => void): { + batcher: FrameBatchedUpdater; + scheduleFlush: () => void; + forceFlush: () => void; +} { + const batcher = new FrameBatchedUpdater(); + const scheduleFlush = () => batcher.schedule(flushMessages); + // Force-flush helper: ``batcher.flush()`` is a no-op when + // ``dirty=false`` (e.g. a tool starts before any text streamed). + // ``scheduleFlush(); batcher.flush()`` sets the dirty bit first so + // terminal events render promptly without the throttle delay. + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; + return { batcher, scheduleFlush, forceFlush }; +} + +function hasPersistableContent(contentParts: ContentPartsState["contentParts"], toolsWithUI: ToolUIGate) { + return contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "reasoning" && part.text.length > 0) || + (part.type === "tool-call" && (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) + ); +} + +function processSharedStreamEvent(parsed: SSEEvent, context: SharedStreamEventContext): boolean { + const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = context; + const { contentParts, toolCallIndices } = contentPartsState; + + switch (parsed.type) { + case "text-delta": + appendText(contentPartsState, parsed.delta); + scheduleFlush(); + return true; + + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + return true; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + return true; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + return true; + + case "finish-step": + return true; + + case "tool-input-start": + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); + forceFlush(); + return true; + + case "tool-input-delta": + // High-frequency event: deltas can fire dozens of times per call, + // so use throttled scheduleFlush (NOT forceFlush) to coalesce. + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + return true; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); + if (toolCallIndices.has(parsed.toolCallId)) { + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + argsText: finalArgsText, + langchainToolCallId: parsed.langchainToolCallId, + }); + } else { + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + parsed.input || {}, + false, + parsed.langchainToolCallId + ); + // addToolCall doesn't accept argsText today; backfill via + // updateToolCall so the new card renders pretty-printed JSON. + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); + } + forceFlush(); + return true; + } + + case "tool-output-available": + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); + markInterruptsCompleted(contentParts); + context.onToolOutputAvailable?.(parsed, { contentPartsState, toolCallIndices }); + forceFlush(); + return true; + + case "data-thinking-step": { + const stepData = parsed.data as ThinkingStepData; + if (stepData?.id) { + currentThinkingSteps.set(stepData.id, stepData); + const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); + if (didUpdate) { + scheduleFlush(); + } + } + return true; + } + + case "data-token-usage": + context.onTokenUsage?.(parsed.data as TokenUsageData); + return true; + + case "error": + throw toStreamTerminalError(parsed); + + default: + return false; + } +} + +async function consumeSseEvents( + response: Response, + onEvent: (event: SSEEvent) => void | Promise<void> +): Promise<void> { + for await (const parsed of readSSEStream(response)) { + await onEvent(parsed); + } +} + /** * Zod schema for mentioned document info (for type-safe parsing) */ @@ -456,7 +618,7 @@ export default function NewChatPage() { threadId: number | null; assistantMsgId: string; content: unknown; - tokenUsage?: Record<string, unknown>; + tokenUsage?: TokenUsageData; turnId?: string | null; logContext: string; onRemapped?: (newMsgId: string) => void; @@ -1055,8 +1217,6 @@ export default function NewChatPage() { // Prepare assistant message const assistantMsgId = `msg-assistant-${Date.now()}`; const currentThinkingSteps = new Map<string, ThinkingStepData>(); - const batcher = new FrameBatchedUpdater(); - const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, @@ -1065,11 +1225,12 @@ export default function NewChatPage() { }; const { contentParts, toolCallIndices } = contentPartsState; let wasInterrupted = false; - let tokenUsageData: Record<string, unknown> | null = null; + let tokenUsageData: TokenUsageData | null = null; let newAccepted = false; let userPersisted = false; // Captured from ``data-turn-info`` at stream start. let streamedChatTurnId: string | null = null; + let streamBatcher: FrameBatchedUpdater | null = null; try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; @@ -1152,123 +1313,37 @@ export default function NewChatPage() { ) ); }; - const scheduleFlush = () => batcher.schedule(flushMessages); - // Force-flush helper: ``batcher.flush()`` is a no-op when - // ``dirty=false`` (e.g. a tool starts before any text - // streamed). ``scheduleFlush(); batcher.flush()`` sets - // the dirty bit FIRST so terminal events render - // promptly without the 50ms throttle delay. - const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; + const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages); + streamBatcher = batcher; - for await (const parsed of readSSEStream(response)) { - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - break; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - break; - - case "finish-step": - break; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - break; - - case "tool-input-delta": - // High-frequency event: deltas can fire dozens - // of times per call, so use throttled - // scheduleFlush (NOT forceFlush) to coalesce. - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - break; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - // addToolCall doesn't accept argsText today; - // backfill via updateToolCall so the new card - // renders pretty-printed JSON. - updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - break; - } - - case "tool-output-available": { - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { - const idx = toolCallIndices.get(parsed.toolCallId); - if (idx !== undefined) { - const part = contentParts[idx]; - if (part?.type === "tool-call" && part.toolName === "generate_podcast") { - setActivePodcastTaskId(String(parsed.output.podcast_id)); + await consumeSseEvents(response, async (parsed) => { + if ( + processSharedStreamEvent(parsed, { + contentPartsState, + toolsWithUI, + currentThinkingSteps, + scheduleFlush, + forceFlush, + onTokenUsage: (data) => { + tokenUsageData = data; + tokenUsageStore.set(assistantMsgId, data); + }, + onToolOutputAvailable: (event, sharedCtx) => { + if (event.output?.status === "pending" && event.output?.podcast_id) { + const idx = sharedCtx.toolCallIndices.get(event.toolCallId); + if (idx !== undefined) { + const part = sharedCtx.contentPartsState.contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(String(event.output.podcast_id)); + } } } - } - forceFlush(); - break; - } - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); - } - } - break; - } - + }, + }) + ) { + return; + } + switch (parsed.type) { case "data-thread-title-update": { const titleData = parsed.data as { threadId: number; title: string }; if (titleData?.title && titleData?.threadId === currentThreadId) { @@ -1374,16 +1449,8 @@ export default function NewChatPage() { } break; } - - case "data-token-usage": - tokenUsageData = parsed.data; - tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); - break; - - case "error": - throw toStreamTerminalError(parsed); } - } + }); batcher.flush(); @@ -1425,7 +1492,7 @@ export default function NewChatPage() { trackChatResponseReceived(searchSpaceId, currentThreadId); } } catch (error) { - batcher.dispose(); + streamBatcher?.dispose(); await handleStreamTerminalError({ error, flow: "new", @@ -1448,13 +1515,7 @@ export default function NewChatPage() { } } - const hasContent = contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "reasoning" && part.text.length > 0) || - (part.type === "tool-call" && - (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) - ); + const hasContent = hasPersistableContent(contentParts, toolsWithUI); if (hasContent && currentThreadId) { const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); await persistAssistantTurn({ @@ -1543,7 +1604,6 @@ export default function NewChatPage() { abortControllerRef.current = controller; const currentThinkingSteps = new Map<string, ThinkingStepData>(); - const batcher = new FrameBatchedUpdater(); const contentPartsState: ContentPartsState = { contentParts: [], @@ -1552,10 +1612,11 @@ export default function NewChatPage() { toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; - let tokenUsageData: Record<string, unknown> | null = null; + let tokenUsageData: TokenUsageData | null = null; let resumeAccepted = false; // Captured from ``data-turn-info`` at stream start. let streamedChatTurnId: string | null = null; + let streamBatcher: FrameBatchedUpdater | null = null; const existingMsg = messages.find((m) => m.id === assistantMsgId); if (existingMsg && Array.isArray(existingMsg.content)) { @@ -1664,102 +1725,26 @@ export default function NewChatPage() { ) ); }; - const scheduleFlush = () => batcher.schedule(flushMessages); - const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; + const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages); + streamBatcher = batcher; - for await (const parsed of readSSEStream(response)) { + await consumeSseEvents(response, async (parsed) => { + if ( + processSharedStreamEvent(parsed, { + contentPartsState, + toolsWithUI, + currentThinkingSteps, + scheduleFlush, + forceFlush, + onTokenUsage: (data) => { + tokenUsageData = data; + tokenUsageStore.set(assistantMsgId, data); + }, + }) + ) { + return; + } switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - break; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - break; - - case "finish-step": - break; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - break; - - case "tool-input-delta": - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - break; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - break; - } - - case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - forceFlush(); - break; - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); - } - } - break; - } - case "data-interrupt-request": { const interruptData = parsed.data as Record<string, unknown>; const actionRequests = (interruptData.action_requests ?? []) as Array<{ @@ -1830,16 +1815,8 @@ export default function NewChatPage() { } break; } - - case "data-token-usage": - tokenUsageData = parsed.data; - tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); - break; - - case "error": - throw toStreamTerminalError(parsed); } - } + }); batcher.flush(); @@ -1855,7 +1832,7 @@ export default function NewChatPage() { }); } } catch (error) { - batcher.dispose(); + streamBatcher?.dispose(); await handleStreamTerminalError({ error, flow: "resume", @@ -1864,13 +1841,7 @@ export default function NewChatPage() { accepted: resumeAccepted, onAbort: async () => { if (!resumeAccepted) return; - const hasContent = contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "reasoning" && part.text.length > 0) || - (part.type === "tool-call" && - (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) - ); + const hasContent = hasPersistableContent(contentParts, toolsWithUI); if (!hasContent) return; const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); await persistAssistantTurn({ @@ -1891,6 +1862,7 @@ export default function NewChatPage() { pendingInterrupt, messages, searchSpaceId, + queryClient, tokenUsageStore, handleStreamTerminalError, persistAssistantTurn, @@ -2045,15 +2017,15 @@ export default function NewChatPage() { currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; - const { contentParts, toolCallIndices } = contentPartsState; - const batcher = new FrameBatchedUpdater(); - let tokenUsageData: Record<string, unknown> | null = null; + const { contentParts } = contentPartsState; + let tokenUsageData: TokenUsageData | null = null; let regenerateAccepted = false; let userPersisted = false; // Captured from ``data-turn-info`` at stream start; stamped // onto persisted messages so future edits can locate the // right LangGraph checkpoint. let streamedChatTurnId: string | null = null; + let streamBatcher: FrameBatchedUpdater | null = null; // Add placeholder messages to UI // Always add back the user message (with new query for edit, or original content for reload) @@ -2155,111 +2127,37 @@ export default function NewChatPage() { ) ); }; - const scheduleFlush = () => batcher.schedule(flushMessages); - const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; + const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages); + streamBatcher = batcher; - for await (const parsed of readSSEStream(response)) { - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - break; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - break; - - case "finish-step": - break; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - break; - - case "tool-input-delta": - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - break; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - break; - } - - case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { - const idx = toolCallIndices.get(parsed.toolCallId); - if (idx !== undefined) { - const part = contentParts[idx]; - if (part?.type === "tool-call" && part.toolName === "generate_podcast") { - setActivePodcastTaskId(String(parsed.output.podcast_id)); + await consumeSseEvents(response, async (parsed) => { + if ( + processSharedStreamEvent(parsed, { + contentPartsState, + toolsWithUI, + currentThinkingSteps, + scheduleFlush, + forceFlush, + onTokenUsage: (data) => { + tokenUsageData = data; + tokenUsageStore.set(assistantMsgId, data); + }, + onToolOutputAvailable: (event, sharedCtx) => { + if (event.output?.status === "pending" && event.output?.podcast_id) { + const idx = sharedCtx.toolCallIndices.get(event.toolCallId); + if (idx !== undefined) { + const part = sharedCtx.contentPartsState.contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(String(event.output.podcast_id)); + } } } - } - forceFlush(); - break; - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); - } - } - break; - } - + }, + }) + ) { + return; + } + switch (parsed.type) { case "data-action-log": { if (threadId !== null) { applyActionLogSse(queryClient, threadId, searchSpaceId, parsed.data); @@ -2326,16 +2224,8 @@ export default function NewChatPage() { } break; } - - case "data-token-usage": - tokenUsageData = parsed.data; - tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); - break; - - case "error": - throw toStreamTerminalError(parsed); } - } + }); batcher.flush(); @@ -2364,7 +2254,7 @@ export default function NewChatPage() { trackChatResponseReceived(searchSpaceId, threadId); } } catch (error) { - batcher.dispose(); + streamBatcher?.dispose(); await handleStreamTerminalError({ error, flow: "regenerate", @@ -2384,13 +2274,7 @@ export default function NewChatPage() { }); userPersisted = Boolean(persistedUserMsgId); } - const hasContent = contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "reasoning" && part.text.length > 0) || - (part.type === "tool-call" && - (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) - ); + const hasContent = hasPersistableContent(contentParts, toolsWithUI); if (!hasContent) return; const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); await persistAssistantTurn({ @@ -2428,6 +2312,7 @@ export default function NewChatPage() { disabledTools, messageDocumentsMap, setMessageDocumentsMap, + queryClient, tokenUsageStore, handleStreamTerminalError, persistAssistantTurn, From 86f6b285ce9cedbf529a7d8325f4457f602f997a Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:09:34 +0530 Subject: [PATCH 248/299] refactor(chat): introduce new stream handling utilities and restructure event processing for improved performance and maintainability --- .../new-chat/[[...chat_id]]/page.tsx | 205 +----------------- surfsense_web/lib/chat/stream-flush.ts | 19 ++ surfsense_web/lib/chat/stream-pipeline.ts | 191 ++++++++++++++++ 3 files changed, 217 insertions(+), 198 deletions(-) create mode 100644 surfsense_web/lib/chat/stream-flush.ts create mode 100644 surfsense_web/lib/chat/stream-pipeline.ts diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index d1dd14e06..82a12b6b1 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -71,23 +71,21 @@ import { setActivePodcastTaskId, } from "@/lib/chat/podcast-state"; import { - addStepSeparator, addToolCall, - appendReasoning, - appendText, - appendToolInputDelta, buildContentForPersistence, buildContentForUI, type ContentPartsState, - endReasoning, - FrameBatchedUpdater, - readSSEStream, - type SSEEvent, + type FrameBatchedUpdater, type ThinkingStepData, type ToolUIGate, - updateThinkingSteps, updateToolCall, } from "@/lib/chat/streaming-state"; +import { createStreamFlushHelpers } from "@/lib/chat/stream-flush"; +import { + consumeSseEvents, + hasPersistableContent, + processSharedStreamEvent, +} from "@/lib/chat/stream-pipeline"; import { appendMessage, createThread, @@ -134,33 +132,6 @@ const MobileReportPanel = dynamic( { ssr: false } ); -/** - * After a tool produces output, mark any previously-decided interrupt tool - * calls as completed so the ApprovalCard can transition from shimmer to done. - */ -function markInterruptsCompleted(contentParts: Array<{ type: string; result?: unknown }>): void { - for (const part of contentParts) { - if ( - part.type === "tool-call" && - typeof part.result === "object" && - part.result !== null && - (part.result as Record<string, unknown>).__interrupt__ === true && - (part.result as Record<string, unknown>).__decided__ && - !(part.result as Record<string, unknown>).__completed__ - ) { - part.result = { ...(part.result as Record<string, unknown>), __completed__: true }; - } - } -} - -function toStreamTerminalError( - event: Extract<SSEEvent, { type: "error" }> -): Error & { errorCode?: string } { - return Object.assign(new Error(event.errorText || "Server error"), { - errorCode: event.errorCode, - }); -} - async function toHttpResponseError(response: Response): Promise<Error & { errorCode?: string }> { const statusDefaultCode = response.status === 409 @@ -252,168 +223,6 @@ function tagPreAcceptSendFailure(error: unknown): unknown { }); } -type SharedStreamEventContext = { - contentPartsState: ContentPartsState; - toolsWithUI: ToolUIGate; - currentThinkingSteps: Map<string, ThinkingStepData>; - scheduleFlush: () => void; - forceFlush: () => void; - onTokenUsage?: (data: TokenUsageData) => void; - onToolOutputAvailable?: ( - event: Extract<SSEEvent, { type: "tool-output-available" }>, - context: { - contentPartsState: ContentPartsState; - toolCallIndices: Map<string, number>; - } - ) => void; -}; - -function createStreamFlushHelpers(flushMessages: () => void): { - batcher: FrameBatchedUpdater; - scheduleFlush: () => void; - forceFlush: () => void; -} { - const batcher = new FrameBatchedUpdater(); - const scheduleFlush = () => batcher.schedule(flushMessages); - // Force-flush helper: ``batcher.flush()`` is a no-op when - // ``dirty=false`` (e.g. a tool starts before any text streamed). - // ``scheduleFlush(); batcher.flush()`` sets the dirty bit first so - // terminal events render promptly without the throttle delay. - const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; - return { batcher, scheduleFlush, forceFlush }; -} - -function hasPersistableContent(contentParts: ContentPartsState["contentParts"], toolsWithUI: ToolUIGate) { - return contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "reasoning" && part.text.length > 0) || - (part.type === "tool-call" && (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) - ); -} - -function processSharedStreamEvent(parsed: SSEEvent, context: SharedStreamEventContext): boolean { - const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = context; - const { contentParts, toolCallIndices } = contentPartsState; - - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - return true; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - return true; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - return true; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - return true; - - case "finish-step": - return true; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - return true; - - case "tool-input-delta": - // High-frequency event: deltas can fire dozens of times per call, - // so use throttled scheduleFlush (NOT forceFlush) to coalesce. - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - return true; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - // addToolCall doesn't accept argsText today; backfill via - // updateToolCall so the new card renders pretty-printed JSON. - updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - return true; - } - - case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - context.onToolOutputAvailable?.(parsed, { contentPartsState, toolCallIndices }); - forceFlush(); - return true; - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); - } - } - return true; - } - - case "data-token-usage": - context.onTokenUsage?.(parsed.data as TokenUsageData); - return true; - - case "error": - throw toStreamTerminalError(parsed); - - default: - return false; - } -} - -async function consumeSseEvents( - response: Response, - onEvent: (event: SSEEvent) => void | Promise<void> -): Promise<void> { - for await (const parsed of readSSEStream(response)) { - await onEvent(parsed); - } -} - /** * Zod schema for mentioned document info (for type-safe parsing) */ diff --git a/surfsense_web/lib/chat/stream-flush.ts b/surfsense_web/lib/chat/stream-flush.ts new file mode 100644 index 000000000..6d13c9237 --- /dev/null +++ b/surfsense_web/lib/chat/stream-flush.ts @@ -0,0 +1,19 @@ +import { FrameBatchedUpdater } from "@/lib/chat/streaming-state"; + +export function createStreamFlushHelpers(flushMessages: () => void): { + batcher: FrameBatchedUpdater; + scheduleFlush: () => void; + forceFlush: () => void; +} { + const batcher = new FrameBatchedUpdater(); + const scheduleFlush = () => batcher.schedule(flushMessages); + // Force-flush helper: ``batcher.flush()`` is a no-op when + // ``dirty=false`` (e.g. a tool starts before any text streamed). + // ``scheduleFlush(); batcher.flush()`` sets the dirty bit first so + // terminal events render promptly without the throttle delay. + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; + return { batcher, scheduleFlush, forceFlush }; +} diff --git a/surfsense_web/lib/chat/stream-pipeline.ts b/surfsense_web/lib/chat/stream-pipeline.ts new file mode 100644 index 000000000..8957bdea3 --- /dev/null +++ b/surfsense_web/lib/chat/stream-pipeline.ts @@ -0,0 +1,191 @@ +import { + addStepSeparator, + addToolCall, + appendReasoning, + appendText, + appendToolInputDelta, + type ContentPartsState, + endReasoning, + readSSEStream, + type SSEEvent, + type ThinkingStepData, + type ToolUIGate, + updateThinkingSteps, + updateToolCall, +} from "@/lib/chat/streaming-state"; + +export type SharedStreamEventContext = { + contentPartsState: ContentPartsState; + toolsWithUI: ToolUIGate; + currentThinkingSteps: Map<string, ThinkingStepData>; + scheduleFlush: () => void; + forceFlush: () => void; + onTokenUsage?: (data: Extract<SSEEvent, { type: "data-token-usage" }>["data"]) => void; + onToolOutputAvailable?: ( + event: Extract<SSEEvent, { type: "tool-output-available" }>, + context: { + contentPartsState: ContentPartsState; + toolCallIndices: Map<string, number>; + } + ) => void; +}; + +/** + * After a tool produces output, mark any previously-decided interrupt tool + * calls as completed so the ApprovalCard can transition from shimmer to done. + */ +export function markInterruptsCompleted( + contentParts: Array<{ type: string; result?: unknown }> +): void { + for (const part of contentParts) { + if ( + part.type === "tool-call" && + typeof part.result === "object" && + part.result !== null && + (part.result as Record<string, unknown>).__interrupt__ === true && + (part.result as Record<string, unknown>).__decided__ && + !(part.result as Record<string, unknown>).__completed__ + ) { + part.result = { ...(part.result as Record<string, unknown>), __completed__: true }; + } + } +} + +export function hasPersistableContent( + contentParts: ContentPartsState["contentParts"], + toolsWithUI: ToolUIGate +) { + return contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "reasoning" && part.text.length > 0) || + (part.type === "tool-call" && (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) + ); +} + +function toStreamTerminalError( + event: Extract<SSEEvent, { type: "error" }> +): Error & { errorCode?: string } { + return Object.assign(new Error(event.errorText || "Server error"), { + errorCode: event.errorCode, + }); +} + +export function processSharedStreamEvent(parsed: SSEEvent, context: SharedStreamEventContext): boolean { + const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = context; + const { contentParts, toolCallIndices } = contentPartsState; + + switch (parsed.type) { + case "text-delta": + appendText(contentPartsState, parsed.delta); + scheduleFlush(); + return true; + + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + return true; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + return true; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + return true; + + case "finish-step": + return true; + + case "tool-input-start": + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); + forceFlush(); + return true; + + case "tool-input-delta": + // High-frequency event: deltas can fire dozens of times per call, + // so use throttled scheduleFlush (NOT forceFlush) to coalesce. + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + return true; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); + if (toolCallIndices.has(parsed.toolCallId)) { + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + argsText: finalArgsText, + langchainToolCallId: parsed.langchainToolCallId, + }); + } else { + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + parsed.input || {}, + false, + parsed.langchainToolCallId + ); + // addToolCall doesn't accept argsText today; backfill via + // updateToolCall so the new card renders pretty-printed JSON. + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); + } + forceFlush(); + return true; + } + + case "tool-output-available": + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); + markInterruptsCompleted(contentParts); + context.onToolOutputAvailable?.(parsed, { contentPartsState, toolCallIndices }); + forceFlush(); + return true; + + case "data-thinking-step": { + const stepData = parsed.data as ThinkingStepData; + if (stepData?.id) { + currentThinkingSteps.set(stepData.id, stepData); + const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); + if (didUpdate) { + scheduleFlush(); + } + } + return true; + } + + case "data-token-usage": + context.onTokenUsage?.(parsed.data); + return true; + + case "error": + throw toStreamTerminalError(parsed); + + default: + return false; + } +} + +export async function consumeSseEvents( + response: Response, + onEvent: (event: SSEEvent) => void | Promise<void> +): Promise<void> { + for await (const parsed of readSSEStream(response)) { + await onEvent(parsed); + } +} From d65a3fdf76364b0705eaff0953f4d7283ecafde2 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:22:34 +0530 Subject: [PATCH 249/299] refactor(chat): implement new error handling utilities and streamline interrupt request processing in NewChatPage for improved performance and maintainability --- .../new-chat/[[...chat_id]]/page.tsx | 238 +++--------------- surfsense_web/lib/chat/chat-request-errors.ts | 89 +++++++ surfsense_web/lib/chat/stream-side-effects.ts | 127 ++++++++++ 3 files changed, 246 insertions(+), 208 deletions(-) create mode 100644 surfsense_web/lib/chat/chat-request-errors.ts create mode 100644 surfsense_web/lib/chat/stream-side-effects.ts diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 82a12b6b1..02c2914be 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -64,6 +64,10 @@ import { classifyChatError, type ChatFlow, } from "@/lib/chat/chat-error-classifier"; +import { + tagPreAcceptSendFailure, + toHttpResponseError, +} from "@/lib/chat/chat-request-errors"; import { convertToThreadMessage } from "@/lib/chat/message-utils"; import { isPodcastGenerating, @@ -71,14 +75,12 @@ import { setActivePodcastTaskId, } from "@/lib/chat/podcast-state"; import { - addToolCall, buildContentForPersistence, buildContentForUI, type ContentPartsState, type FrameBatchedUpdater, type ThinkingStepData, type ToolUIGate, - updateToolCall, } from "@/lib/chat/streaming-state"; import { createStreamFlushHelpers } from "@/lib/chat/stream-flush"; import { @@ -86,6 +88,14 @@ import { hasPersistableContent, processSharedStreamEvent, } from "@/lib/chat/stream-pipeline"; +import { + applyTurnIdToAssistantMessageList, + applyInterruptRequestToContentParts, + mergeChatTurnIdIntoMessage, + mergeEditedInterruptAction, + markInterruptDecisionOnContentParts, + readStreamedChatTurnId, +} from "@/lib/chat/stream-side-effects"; import { appendMessage, createThread, @@ -132,97 +142,6 @@ const MobileReportPanel = dynamic( { ssr: false } ); -async function toHttpResponseError(response: Response): Promise<Error & { errorCode?: string }> { - const statusDefaultCode = - response.status === 409 - ? "THREAD_BUSY" - : response.status === 429 - ? "RATE_LIMITED" - : response.status === 401 || response.status === 403 - ? "AUTH_EXPIRED" - : "SERVER_ERROR"; - - let rawBody = ""; - try { - rawBody = await response.text(); - } catch { - // noop - } - - let parsedBody: Record<string, unknown> | null = null; - if (rawBody) { - try { - const parsed = JSON.parse(rawBody); - if (typeof parsed === "object" && parsed !== null) { - parsedBody = parsed as Record<string, unknown>; - } - } catch { - // noop - } - } - - const detail = parsedBody?.detail; - const detailObject = - typeof detail === "object" && detail !== null ? (detail as Record<string, unknown>) : null; - const detailMessage = typeof detail === "string" ? detail : undefined; - const topLevelMessage = - typeof parsedBody?.message === "string" ? (parsedBody.message as string) : undefined; - const detailNestedMessage = - typeof detailObject?.message === "string" ? (detailObject.message as string) : undefined; - - const topLevelCode = - typeof parsedBody?.errorCode === "string" - ? parsedBody.errorCode - : typeof parsedBody?.error_code === "string" - ? parsedBody.error_code - : undefined; - const detailCode = - typeof detailObject?.errorCode === "string" - ? detailObject.errorCode - : typeof detailObject?.error_code === "string" - ? detailObject.error_code - : undefined; - - const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode; - const message = - detailNestedMessage ?? - detailMessage ?? - topLevelMessage ?? - `Backend error: ${response.status}`; - - return Object.assign(new Error(message), { errorCode }); -} - -function tagPreAcceptSendFailure(error: unknown): unknown { - if (error instanceof Error) { - const withCode = error as Error & { errorCode?: string; code?: string }; - const existingCode = withCode.errorCode ?? withCode.code; - const passthroughCodes = new Set([ - "PREMIUM_QUOTA_EXHAUSTED", - "THREAD_BUSY", - "AUTH_EXPIRED", - "UNAUTHORIZED", - "RATE_LIMITED", - "NETWORK_ERROR", - "STREAM_PARSE_ERROR", - "TOOL_EXECUTION_ERROR", - "PERSIST_MESSAGE_FAILED", - "SERVER_ERROR", - ]); - if ( - existingCode && - passthroughCodes.has(existingCode) - ) { - return Object.assign(error, { errorCode: existingCode }); - } - return Object.assign(error, { errorCode: "SEND_FAILED_PRE_ACCEPT" }); - } - - return Object.assign(new Error("Failed to send message before stream acceptance"), { - errorCode: "SEND_FAILED_PRE_ACCEPT", - }); -} - /** * Zod schema for mentioned document info (for type-safe parsing) */ @@ -264,29 +183,6 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] { */ const TOOLS_WITH_UI_ALL: ToolUIGate = "all"; -/** - * When a streamed message is persisted, the backend returns the durable - * ``turn_id`` (``configurable.turn_id`` from the agent run). Merge it - * into the assistant-ui message metadata so the per-turn "Revert turn" - * button can scope to this turn's actions even after a full chat reload. - */ -function mergeChatTurnIdIntoMessage( - msg: ThreadMessageLike, - turnId: string | null | undefined -): ThreadMessageLike { - if (!turnId) return msg; - const existingMeta = (msg.metadata ?? {}) as { custom?: Record<string, unknown> }; - const existingCustom = existingMeta.custom ?? {}; - if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg; - return { - ...msg, - metadata: { - ...existingMeta, - custom: { ...existingCustom, chatTurnId: turnId }, - }, - }; -} - export default function NewChatPage() { const params = useParams(); const queryClient = useQueryClient(); @@ -1032,7 +928,7 @@ export default function NewChatPage() { currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; - const { contentParts, toolCallIndices } = contentPartsState; + const { contentParts } = contentPartsState; let wasInterrupted = false; let tokenUsageData: TokenUsageData | null = null; let newAccepted = false; @@ -1194,27 +1090,7 @@ export default function NewChatPage() { case "data-interrupt-request": { wasInterrupted = true; const interruptData = parsed.data as Record<string, unknown>; - const actionRequests = (interruptData.action_requests ?? []) as Array<{ - name: string; - args: Record<string, unknown>; - }>; - for (const action of actionRequests) { - const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => { - const part = contentParts[idx]; - return part?.type === "tool-call" && part.toolName === action.name; - }); - if (existingIdx) { - updateToolCall(contentPartsState, existingIdx[0], { - result: { __interrupt__: true, ...interruptData }, - }); - } else { - const tcId = `interrupt-${action.name}`; - addToolCall(contentPartsState, toolsWithUI, tcId, action.name, action.args, true); - updateToolCall(contentPartsState, tcId, { - result: { __interrupt__: true, ...interruptData }, - }); - } - } + applyInterruptRequestToContentParts(contentPartsState, toolsWithUI, interruptData); setMessages((prev) => prev.map((m) => m.id === assistantMsgId @@ -1248,12 +1124,11 @@ export default function NewChatPage() { } case "data-turn-info": { - streamedChatTurnId = parsed.data.chat_turn_id || null; - if (streamedChatTurnId) { + const turnId = readStreamedChatTurnId(parsed.data); + streamedChatTurnId = turnId; + if (turnId) { setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m - ) + applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) ); } break; @@ -1469,37 +1344,12 @@ export default function NewChatPage() { } // Merge edited args if present to fix race condition - if (decisions.length > 0 && decisions[0].type === "edit" && decisions[0].edited_action) { - const editedAction = decisions[0].edited_action; - for (const part of contentParts) { - if (part.type === "tool-call" && part.toolName === editedAction.name) { - const mergedArgs = { ...part.args, ...editedAction.args }; - part.args = mergedArgs; - // Sync argsText so the rendered card shows the - // edited inputs — assistant-ui prefers caller- - // supplied argsText over JSON.stringify(args). - part.argsText = JSON.stringify(mergedArgs, null, 2); - break; - } - } + if (decisions.length > 0 && decisions[0].type === "edit") { + mergeEditedInterruptAction(contentParts, decisions[0].edited_action); } const decisionType = decisions[0]?.type as "approve" | "reject" | undefined; - if (decisionType) { - for (const part of contentParts) { - if ( - part.type === "tool-call" && - typeof part.result === "object" && - part.result !== null && - "__interrupt__" in (part.result as Record<string, unknown>) - ) { - part.result = { - ...(part.result as Record<string, unknown>), - __decided__: decisionType, - }; - } - } - } + markInterruptDecisionOnContentParts(contentParts, decisionType); try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; @@ -1556,33 +1406,7 @@ export default function NewChatPage() { switch (parsed.type) { case "data-interrupt-request": { const interruptData = parsed.data as Record<string, unknown>; - const actionRequests = (interruptData.action_requests ?? []) as Array<{ - name: string; - args: Record<string, unknown>; - }>; - for (const action of actionRequests) { - const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => { - const part = contentParts[idx]; - return part?.type === "tool-call" && part.toolName === action.name; - }); - if (existingIdx) { - updateToolCall(contentPartsState, existingIdx[0], { - result: { - __interrupt__: true, - ...interruptData, - }, - }); - } else { - const tcId = `interrupt-${action.name}`; - addToolCall(contentPartsState, toolsWithUI, tcId, action.name, action.args, true); - updateToolCall(contentPartsState, tcId, { - result: { - __interrupt__: true, - ...interruptData, - }, - }); - } - } + applyInterruptRequestToContentParts(contentPartsState, toolsWithUI, interruptData); setMessages((prev) => prev.map((m) => m.id === assistantMsgId @@ -1614,12 +1438,11 @@ export default function NewChatPage() { } case "data-turn-info": { - streamedChatTurnId = parsed.data.chat_turn_id || null; - if (streamedChatTurnId) { + const turnId = readStreamedChatTurnId(parsed.data); + streamedChatTurnId = turnId; + if (turnId) { setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m - ) + applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) ); } break; @@ -1987,12 +1810,11 @@ export default function NewChatPage() { } case "data-turn-info": { - streamedChatTurnId = parsed.data.chat_turn_id || null; - if (streamedChatTurnId) { + const turnId = readStreamedChatTurnId(parsed.data); + streamedChatTurnId = turnId; + if (turnId) { setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m - ) + applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) ); } break; diff --git a/surfsense_web/lib/chat/chat-request-errors.ts b/surfsense_web/lib/chat/chat-request-errors.ts new file mode 100644 index 000000000..3026e8203 --- /dev/null +++ b/surfsense_web/lib/chat/chat-request-errors.ts @@ -0,0 +1,89 @@ +export async function toHttpResponseError( + response: Response +): Promise<Error & { errorCode?: string }> { + const statusDefaultCode = + response.status === 409 + ? "THREAD_BUSY" + : response.status === 429 + ? "RATE_LIMITED" + : response.status === 401 || response.status === 403 + ? "AUTH_EXPIRED" + : "SERVER_ERROR"; + + let rawBody = ""; + try { + rawBody = await response.text(); + } catch { + // noop + } + + let parsedBody: Record<string, unknown> | null = null; + if (rawBody) { + try { + const parsed = JSON.parse(rawBody); + if (typeof parsed === "object" && parsed !== null) { + parsedBody = parsed as Record<string, unknown>; + } + } catch { + // noop + } + } + + const detail = parsedBody?.detail; + const detailObject = + typeof detail === "object" && detail !== null ? (detail as Record<string, unknown>) : null; + const detailMessage = typeof detail === "string" ? detail : undefined; + const topLevelMessage = + typeof parsedBody?.message === "string" ? (parsedBody.message as string) : undefined; + const detailNestedMessage = + typeof detailObject?.message === "string" ? (detailObject.message as string) : undefined; + + const topLevelCode = + typeof parsedBody?.errorCode === "string" + ? parsedBody.errorCode + : typeof parsedBody?.error_code === "string" + ? parsedBody.error_code + : undefined; + const detailCode = + typeof detailObject?.errorCode === "string" + ? detailObject.errorCode + : typeof detailObject?.error_code === "string" + ? detailObject.error_code + : undefined; + + const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode; + const message = + detailNestedMessage ?? + detailMessage ?? + topLevelMessage ?? + `Backend error: ${response.status}`; + + return Object.assign(new Error(message), { errorCode }); +} + +export function tagPreAcceptSendFailure(error: unknown): unknown { + if (error instanceof Error) { + const withCode = error as Error & { errorCode?: string; code?: string }; + const existingCode = withCode.errorCode ?? withCode.code; + const passthroughCodes = new Set([ + "PREMIUM_QUOTA_EXHAUSTED", + "THREAD_BUSY", + "AUTH_EXPIRED", + "UNAUTHORIZED", + "RATE_LIMITED", + "NETWORK_ERROR", + "STREAM_PARSE_ERROR", + "TOOL_EXECUTION_ERROR", + "PERSIST_MESSAGE_FAILED", + "SERVER_ERROR", + ]); + if (existingCode && passthroughCodes.has(existingCode)) { + return Object.assign(error, { errorCode: existingCode }); + } + return Object.assign(error, { errorCode: "SEND_FAILED_PRE_ACCEPT" }); + } + + return Object.assign(new Error("Failed to send message before stream acceptance"), { + errorCode: "SEND_FAILED_PRE_ACCEPT", + }); +} diff --git a/surfsense_web/lib/chat/stream-side-effects.ts b/surfsense_web/lib/chat/stream-side-effects.ts new file mode 100644 index 000000000..9cb349458 --- /dev/null +++ b/surfsense_web/lib/chat/stream-side-effects.ts @@ -0,0 +1,127 @@ +import type { ThreadMessageLike } from "@assistant-ui/react"; +import { + addToolCall, + type ContentPartsState, + type ToolUIGate, + updateToolCall, +} from "@/lib/chat/streaming-state"; + +type InterruptActionRequest = { + name: string; + args: Record<string, unknown>; +}; + +export type EditedInterruptAction = { + name: string; + args: Record<string, unknown>; +}; + +function readInterruptActions( + interruptData: Record<string, unknown> +): InterruptActionRequest[] { + return (interruptData.action_requests ?? []) as InterruptActionRequest[]; +} + +/** + * Applies an interrupt request payload to tool-call parts. Existing tool cards + * are updated in-place; missing ones are upserted so approval UI always shows. + */ +export function applyInterruptRequestToContentParts( + contentPartsState: ContentPartsState, + toolsWithUI: ToolUIGate, + interruptData: Record<string, unknown> +): void { + const { contentParts, toolCallIndices } = contentPartsState; + const actionRequests = readInterruptActions(interruptData); + for (const action of actionRequests) { + const existingEntry = Array.from(toolCallIndices.entries()).find(([, idx]) => { + const part = contentParts[idx]; + return part?.type === "tool-call" && part.toolName === action.name; + }); + + if (existingEntry) { + updateToolCall(contentPartsState, existingEntry[0], { + result: { __interrupt__: true, ...interruptData }, + }); + } else { + const toolCallId = `interrupt-${action.name}`; + addToolCall(contentPartsState, toolsWithUI, toolCallId, action.name, action.args, true); + updateToolCall(contentPartsState, toolCallId, { + result: { __interrupt__: true, ...interruptData }, + }); + } + } +} + +export function mergeEditedInterruptAction( + contentParts: ContentPartsState["contentParts"], + editedAction: EditedInterruptAction | undefined +): void { + if (!editedAction) return; + for (const part of contentParts) { + if (part.type === "tool-call" && part.toolName === editedAction.name) { + const mergedArgs = { ...part.args, ...editedAction.args }; + part.args = mergedArgs; + // assistant-ui prefers argsText over JSON.stringify(args) + part.argsText = JSON.stringify(mergedArgs, null, 2); + break; + } + } +} + +export function markInterruptDecisionOnContentParts( + contentParts: ContentPartsState["contentParts"], + decisionType: "approve" | "reject" | undefined +): void { + if (!decisionType) return; + for (const part of contentParts) { + if ( + part.type === "tool-call" && + typeof part.result === "object" && + part.result !== null && + "__interrupt__" in (part.result as Record<string, unknown>) + ) { + part.result = { + ...(part.result as Record<string, unknown>), + __decided__: decisionType, + }; + } + } +} + +/** + * When a streamed message is persisted, the backend returns the durable + * turn_id; merge it into assistant-ui metadata for turn-scoped actions. + */ +export function mergeChatTurnIdIntoMessage( + msg: ThreadMessageLike, + turnId: string | null | undefined +): ThreadMessageLike { + if (!turnId) return msg; + const existingMeta = (msg.metadata ?? {}) as { custom?: Record<string, unknown> }; + const existingCustom = existingMeta.custom ?? {}; + if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg; + return { + ...msg, + metadata: { + ...existingMeta, + custom: { ...existingCustom, chatTurnId: turnId }, + }, + }; +} + +export function readStreamedChatTurnId(data: unknown): string | null { + if (typeof data !== "object" || data === null) return null; + const value = (data as { chat_turn_id?: unknown }).chat_turn_id; + return typeof value === "string" && value.length > 0 ? value : null; +} + +export function applyTurnIdToAssistantMessageList( + messages: ThreadMessageLike[], + assistantMsgId: string, + turnId: string +): ThreadMessageLike[] { + return messages.map((m) => + m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, turnId) : m + ); +} From 4056bd1d6947703652e612ac425dabc3ec3c67da Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 22:37:11 +0530 Subject: [PATCH 250/299] refactor(chat): update resetCurrentThreadAtom to include shareToken and contentType for enhanced report panel state management --- surfsense_web/atoms/chat/current-thread.atom.ts | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/surfsense_web/atoms/chat/current-thread.atom.ts b/surfsense_web/atoms/chat/current-thread.atom.ts index d781df8d2..131c98309 100644 --- a/surfsense_web/atoms/chat/current-thread.atom.ts +++ b/surfsense_web/atoms/chat/current-thread.atom.ts @@ -26,7 +26,14 @@ export const setThreadVisibilityAtom = atom(null, (get, set, newVisibility: Chat export const resetCurrentThreadAtom = atom(null, (_, set) => { set(currentThreadAtom, initialState); - set(reportPanelAtom, { isOpen: false, reportId: null, title: null, wordCount: null }); + set(reportPanelAtom, { + isOpen: false, + reportId: null, + title: null, + wordCount: null, + shareToken: null, + contentType: "markdown", + }); }); /** Target comment ID to scroll to (from URL navigation or inbox click) */ From af66fbf106921822a895536c358f2b1a9b93b7a8 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 01:47:52 +0530 Subject: [PATCH 251/299] refactor(chat): implement turn cancellation and status management in new chat routes for improved user experience and error handling --- .../agents/new_chat/middleware/busy_mutex.py | 56 ++++- .../app/routes/new_chat_routes.py | 169 ++++++++++++++- surfsense_backend/app/schemas/new_chat.py | 18 ++ .../app/services/new_streaming_service.py | 11 +- .../app/tasks/chat/stream_new_chat.py | 75 ++++++- .../unit/agents/new_chat/test_busy_mutex.py | 30 +++ .../unit/test_stream_new_chat_contract.py | 139 ++++++++++--- .../new-chat/[[...chat_id]]/page.tsx | 194 +++++++++++++----- .../lib/chat/chat-error-classifier.ts | 18 +- surfsense_web/lib/chat/chat-request-errors.ts | 29 ++- surfsense_web/lib/chat/stream-pipeline.ts | 5 + surfsense_web/lib/chat/streaming-state.ts | 8 + 12 files changed, 671 insertions(+), 81 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py index c57d85004..d61a56533 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py +++ b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py @@ -33,6 +33,7 @@ from __future__ import annotations import asyncio import logging +import time import weakref from typing import Any @@ -58,6 +59,8 @@ class _ThreadLockManager: weakref.WeakValueDictionary() ) self._cancel_events: dict[str, asyncio.Event] = {} + self._cancel_requested_at_ms: dict[str, int] = {} + self._cancel_attempt_count: dict[str, int] = {} def lock_for(self, thread_id: str) -> asyncio.Lock: lock = self._locks.get(thread_id) @@ -76,14 +79,45 @@ class _ThreadLockManager: def request_cancel(self, thread_id: str) -> bool: event = self._cancel_events.get(thread_id) if event is None: - return False + event = asyncio.Event() + self._cancel_events[thread_id] = event event.set() + now_ms = int(time.time() * 1000) + self._cancel_requested_at_ms[thread_id] = now_ms + self._cancel_attempt_count[thread_id] = ( + self._cancel_attempt_count.get(thread_id, 0) + 1 + ) return True + def is_cancel_requested(self, thread_id: str) -> bool: + event = self._cancel_events.get(thread_id) + return bool(event and event.is_set()) + + def cancel_state(self, thread_id: str) -> tuple[int, int] | None: + if not self.is_cancel_requested(thread_id): + return None + attempts = self._cancel_attempt_count.get(thread_id, 1) + requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0) + return attempts, requested_at_ms + def reset(self, thread_id: str) -> None: event = self._cancel_events.get(thread_id) if event is not None: event.clear() + self._cancel_requested_at_ms.pop(thread_id, None) + self._cancel_attempt_count.pop(thread_id, None) + + def end_turn(self, thread_id: str) -> None: + """Best-effort terminal cleanup for a thread turn. + + This is intentionally idempotent and safe to call from outer stream + finally-blocks where middleware teardown might be skipped due to abort + or disconnect edge-cases. + """ + lock = self._locks.get(thread_id) + if lock is not None and lock.locked(): + lock.release() + self.reset(thread_id) # Module-level singleton — process-local but reused across all agent @@ -98,15 +132,30 @@ def get_cancel_event(thread_id: str) -> asyncio.Event: def request_cancel(thread_id: str) -> bool: - """Trip the cancel event for ``thread_id``. Returns True if found.""" + """Trip the cancel event for ``thread_id``. Always returns True.""" return manager.request_cancel(thread_id) +def is_cancel_requested(thread_id: str) -> bool: + """Return whether ``thread_id`` currently has a pending cancel signal.""" + return manager.is_cancel_requested(thread_id) + + +def get_cancel_state(thread_id: str) -> tuple[int, int] | None: + """Return ``(attempt_count, requested_at_ms)`` for pending cancel state.""" + return manager.cancel_state(thread_id) + + def reset_cancel(thread_id: str) -> None: """Reset the cancel event for ``thread_id`` (called between turns).""" manager.reset(thread_id) +def end_turn(thread_id: str) -> None: + """Force end-of-turn cleanup for lock + cancel state.""" + manager.end_turn(thread_id) + + class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): """Block concurrent prompts on the same thread. @@ -229,7 +278,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo __all__ = [ "BusyMutexMiddleware", + "end_turn", "get_cancel_event", + "get_cancel_state", + "is_cancel_requested", "manager", "request_cancel", "reset_cancel", diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index e04cce1b5..28b197ca2 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -15,7 +15,7 @@ import json import logging from datetime import UTC, datetime -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi.responses import StreamingResponse from sqlalchemy import func, or_ from sqlalchemy.exc import IntegrityError, OperationalError @@ -29,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import ( FilesystemSelection, LocalFilesystemMount, ) +from app.agents.new_chat.middleware.busy_mutex import ( + get_cancel_state, + is_cancel_requested, + manager, + request_cancel, +) from app.config import config from app.db import ( ChatComment, @@ -44,6 +50,7 @@ from app.db import ( ) from app.schemas.new_chat import ( AgentToolInfo, + CancelActiveTurnResponse, LocalFilesystemMountPayload, NewChatMessageRead, NewChatRequest, @@ -60,6 +67,7 @@ from app.schemas.new_chat import ( ThreadListItem, ThreadListResponse, TokenUsageSummary, + TurnStatusResponse, ) from app.services.token_tracking_service import record_token_usage from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat @@ -72,6 +80,9 @@ from app.utils.user_message_multimodal import ( _logger = logging.getLogger(__name__) _background_tasks: set[asyncio.Task] = set() +TURN_CANCELLING_INITIAL_DELAY_MS = 200 +TURN_CANCELLING_BACKOFF_FACTOR = 2 +TURN_CANCELLING_MAX_DELAY_MS = 1500 router = APIRouter() @@ -137,6 +148,72 @@ def _resolve_filesystem_selection( ) +def _compute_turn_cancelling_retry_delay(attempt: int) -> int: + """Bounded exponential delay for TURN_CANCELLING retry hints.""" + if attempt < 1: + attempt = 1 + delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( + TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) + ) + return min(delay, TURN_CANCELLING_MAX_DELAY_MS) + + +def _build_turn_status_payload(thread_id: int) -> dict[str, object]: + lock = manager.lock_for(str(thread_id)) + if not lock.locked(): + return {"status": "idle"} + + if is_cancel_requested(str(thread_id)): + cancel_state = get_cancel_state(str(thread_id)) + attempt = cancel_state[0] if cancel_state else 1 + retry_after_ms = _compute_turn_cancelling_retry_delay(attempt) + retry_after_at = int(datetime.now(UTC).timestamp() * 1000) + retry_after_ms + return { + "status": "cancelling", + "retry_after_ms": retry_after_ms, + "retry_after_at": retry_after_at, + } + + return {"status": "busy"} + + +def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None: + response.headers["retry-after-ms"] = str(retry_after_ms) + response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000)) + + +def _raise_if_thread_busy_for_start(thread_id: int) -> None: + status_payload = _build_turn_status_payload(thread_id) + status = status_payload["status"] + if status == "idle": + return + if status == "cancelling": + retry_after_ms = int(status_payload.get("retry_after_ms") or 0) + detail = { + "errorCode": "TURN_CANCELLING", + "message": "A previous response is still stopping. Please try again in a moment.", + "retry_after_ms": retry_after_ms if retry_after_ms > 0 else None, + "retry_after_at": status_payload.get("retry_after_at"), + } + headers = ( + { + "retry-after-ms": str(retry_after_ms), + "Retry-After": str(max(1, (retry_after_ms + 999) // 1000)), + } + if retry_after_ms > 0 + else None + ) + raise HTTPException(status_code=409, detail=detail, headers=headers) + + raise HTTPException( + status_code=409, + detail={ + "errorCode": "THREAD_BUSY", + "message": "Another response is still finishing for this thread. Please try again in a moment.", + }, + ) + + def _find_pre_turn_checkpoint_id( checkpoint_tuples: list, *, @@ -1476,6 +1553,7 @@ async def handle_new_chat( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + _raise_if_thread_busy_for_start(request.chat_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, @@ -1550,6 +1628,93 @@ async def handle_new_chat( ) from None +@router.post( + "/threads/{thread_id}/cancel-active-turn", + response_model=CancelActiveTurnResponse, +) +async def cancel_active_turn( + thread_id: int, + response: Response, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Signal cancellation for the currently running turn on ``thread_id``.""" + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_UPDATE.value, + "You don't have permission to update chats in this search space", + ) + await check_thread_access(session, thread, user) + + status_payload = _build_turn_status_payload(thread_id) + if status_payload["status"] == "idle": + return CancelActiveTurnResponse( + status="idle", + error_code="NO_ACTIVE_TURN", + ) + + request_cancel(str(thread_id)) + response.status_code = 202 + updated_payload = _build_turn_status_payload(thread_id) + retry_after_ms = int(updated_payload.get("retry_after_ms") or 0) + retry_after_at = ( + int(updated_payload["retry_after_at"]) + if "retry_after_at" in updated_payload + else None + ) + if retry_after_ms > 0: + _set_retry_after_headers(response, retry_after_ms) + return CancelActiveTurnResponse( + status="cancelling", + error_code="TURN_CANCELLING", + retry_after_ms=retry_after_ms if retry_after_ms > 0 else None, + retry_after_at=retry_after_at, + ) + + +@router.get( + "/threads/{thread_id}/turn-status", + response_model=TurnStatusResponse, +) +async def get_turn_status( + thread_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_READ.value, + "You don't have permission to view chats in this search space", + ) + await check_thread_access(session, thread, user) + + status_payload = _build_turn_status_payload(thread_id) + return TurnStatusResponse( + status=status_payload["status"], # type: ignore[arg-type] + active_turn_id=None, + retry_after_ms=status_payload.get("retry_after_ms"), # type: ignore[arg-type] + retry_after_at=status_payload.get("retry_after_at"), # type: ignore[arg-type] + ) + + # ============================================================================= # Chat Regeneration Endpoint (Edit/Reload) # ============================================================================= @@ -1605,6 +1770,7 @@ async def regenerate_response( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + _raise_if_thread_busy_for_start(thread_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, @@ -2012,6 +2178,7 @@ async def resume_chat( ) await check_thread_access(session, thread, user) + _raise_if_thread_busy_for_start(thread_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index c7284e901..ec5eefc07 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -335,6 +335,24 @@ class ResumeRequest(BaseModel): local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None +class CancelActiveTurnResponse(BaseModel): + """Response for canceling an active turn on a chat thread.""" + + status: Literal["cancelling", "idle"] + error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"] + retry_after_ms: int | None = None + retry_after_at: int | None = None + + +class TurnStatusResponse(BaseModel): + """Current turn execution status for a thread.""" + + status: Literal["idle", "busy", "cancelling"] + active_turn_id: str | None = None + retry_after_ms: int | None = None + retry_after_at: int | None = None + + # ============================================================================= # Public Chat Snapshot Schemas # ============================================================================= diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 842481f1c..55129668c 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -565,7 +565,12 @@ class VercelStreamingService: # Error Part # ========================================================================= - def format_error(self, error_text: str, error_code: str | None = None) -> str: + def format_error( + self, + error_text: str, + error_code: str | None = None, + extra: dict[str, object] | None = None, + ) -> str: """ Format an error message. @@ -579,9 +584,11 @@ class VercelStreamingService: Example output: data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"} """ - payload: dict[str, str] = {"type": "error", "errorText": error_text} + payload: dict[str, object] = {"type": "error", "errorText": error_text} if error_code: payload["errorCode"] = error_code + if extra: + payload.update(extra) return self._format_sse(payload) # ========================================================================= diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 2afa851b5..63c149771 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -45,6 +45,11 @@ from app.agents.new_chat.memory_extraction import ( extract_and_save_memory, extract_and_save_team_memory, ) +from app.agents.new_chat.middleware.busy_mutex import ( + end_turn, + get_cancel_state, + is_cancel_requested, +) from app.agents.new_chat.middleware.kb_persistence import ( commit_staged_filesystem_state, ) @@ -72,6 +77,18 @@ from app.utils.user_message_multimodal import build_human_message_content _background_tasks: set[asyncio.Task] = set() _perf_log = get_perf_logger() +TURN_CANCELLING_INITIAL_DELAY_MS = 200 +TURN_CANCELLING_BACKOFF_FACTOR = 2 +TURN_CANCELLING_MAX_DELAY_MS = 1500 + + +def _compute_turn_cancelling_retry_delay(attempt: int) -> int: + if attempt < 1: + attempt = 1 + delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( + TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) + ) + return min(delay, TURN_CANCELLING_MAX_DELAY_MS) def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: @@ -401,15 +418,35 @@ def _classify_stream_exception( exc: Exception, *, flow_label: str, -) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]: +) -> tuple[ + str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None +]: raw = str(exc) if isinstance(exc, BusyError) or "Thread is busy with another request" in raw: + busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None + if busy_thread_id and is_cancel_requested(busy_thread_id): + cancel_state = get_cancel_state(busy_thread_id) + attempt = cancel_state[0] if cancel_state else 1 + retry_after_ms = _compute_turn_cancelling_retry_delay(attempt) + retry_after_at = int(time.time() * 1000) + retry_after_ms + return ( + "thread_busy", + "TURN_CANCELLING", + "info", + True, + "A previous response is still stopping. Please try again in a moment.", + { + "retry_after_ms": retry_after_ms, + "retry_after_at": retry_after_at, + }, + ) return ( "thread_busy", "THREAD_BUSY", "warn", True, "Another response is still finishing for this thread. Please try again in a moment.", + None, ) parsed = _parse_error_payload(raw) @@ -431,6 +468,7 @@ def _classify_stream_exception( "warn", True, "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + None, ) return ( @@ -439,6 +477,7 @@ def _classify_stream_exception( "error", False, f"Error during {flow_label}: {raw}", + None, ) @@ -470,7 +509,7 @@ def _emit_stream_terminal_error( message=message, extra=extra, ) - return streaming_service.format_error(message, error_code=error_code) + return streaming_service.format_error(message, error_code=error_code, extra=extra) def _legacy_match_lc_id( @@ -2497,6 +2536,7 @@ async def stream_new_chat( "turn-info", {"chat_turn_id": stream_result.turn_id}, ) + yield streaming_service.format_data("turn-status", {"status": "busy"}) # Initial thinking step - analyzing the request if mentioned_surfsense_docs: @@ -2805,6 +2845,7 @@ async def stream_new_chat( task.add_done_callback(_background_tasks.discard) # Finish the step and message + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -2819,11 +2860,19 @@ async def stream_new_chat( severity, is_expected, user_message, + error_extra, ) = _classify_stream_exception(e, flow_label="chat") error_message = f"Error during chat: {e!s}" print(f"[stream_new_chat] {error_message}") print(f"[stream_new_chat] Exception type: {type(e).__name__}") print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}") + if error_code == "TURN_CANCELLING": + status_payload: dict[str, Any] = {"status": "cancelling"} + if error_extra: + status_payload.update(error_extra) + yield streaming_service.format_data("turn-status", status_payload) + else: + yield streaming_service.format_data("turn-status", {"status": "busy"}) yield _emit_stream_error( message=user_message, @@ -2831,7 +2880,9 @@ async def stream_new_chat( error_code=error_code, severity=severity, is_expected=is_expected, + extra=error_extra, ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -2847,6 +2898,10 @@ async def stream_new_chat( # (CancelledError is a BaseException), and the rest of the # finally block — including session.close() — would never run. with anyio.CancelScope(shield=True): + # Authoritative fallback cleanup for lock/cancel state. Middleware + # teardown can be skipped on some client-abort paths. + end_turn(str(chat_id)) + # Release premium reservation if not finalized if _premium_request_id and _premium_reserved > 0 and user_id: try: @@ -3206,6 +3261,7 @@ async def stream_resume_chat( "turn-info", {"chat_turn_id": stream_result.turn_id}, ) + yield streaming_service.format_data("turn-status", {"status": "busy"}) _t_stream_start = time.perf_counter() _first_event_logged = False @@ -3305,6 +3361,7 @@ async def stream_resume_chat( }, ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -3318,23 +3375,37 @@ async def stream_resume_chat( severity, is_expected, user_message, + error_extra, ) = _classify_stream_exception(e, flow_label="resume") error_message = f"Error during resume: {e!s}" print(f"[stream_resume_chat] {error_message}") print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}") + if error_code == "TURN_CANCELLING": + status_payload: dict[str, Any] = {"status": "cancelling"} + if error_extra: + status_payload.update(error_extra) + yield streaming_service.format_data("turn-status", status_payload) + else: + yield streaming_service.format_data("turn-status", {"status": "busy"}) yield _emit_stream_error( message=user_message, error_kind=error_kind, error_code=error_code, severity=severity, is_expected=is_expected, + extra=error_extra, ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() finally: with anyio.CancelScope(shield=True): + # Authoritative fallback cleanup for lock/cancel state. Middleware + # teardown can be skipped on some client-abort paths. + end_turn(str(chat_id)) + # Release premium reservation if not finalized if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id: try: diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py index 0c7bf17f6..c923dc499 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py @@ -7,7 +7,9 @@ import pytest from app.agents.new_chat.errors import BusyError from app.agents.new_chat.middleware.busy_mutex import ( BusyMutexMiddleware, + end_turn, get_cancel_event, + is_cancel_requested, manager, request_cancel, reset_cancel, @@ -88,3 +90,31 @@ async def test_no_thread_id_skipped_when_not_required() -> None: def test_reset_cancel_idempotent() -> None: # Should not raise even if event was never created reset_cancel("never-seen") + + +def test_request_cancel_creates_event_for_unseen_thread() -> None: + thread_id = "never-seen-cancel" + reset_cancel(thread_id) + + assert request_cancel(thread_id) is True + assert get_cancel_event(thread_id).is_set() + assert is_cancel_requested(thread_id) is True + + +@pytest.mark.asyncio +async def test_end_turn_force_clears_lock_and_cancel_state() -> None: + thread_id = "forced-end-turn" + mw = BusyMutexMiddleware() + runtime = _Runtime(thread_id) + + await mw.abefore_agent({}, runtime) + assert manager.lock_for(thread_id).locked() + + request_cancel(thread_id) + assert is_cancel_requested(thread_id) is True + + end_turn(thread_id) + + assert not manager.lock_for(thread_id).locked() + assert not get_cancel_event(thread_id).is_set() + assert is_cancel_requested(thread_id) is False diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 86ea7edd1..a1345c15c 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -8,6 +8,7 @@ import pytest import app.tasks.chat.stream_new_chat as stream_new_chat_module from app.agents.new_chat.errors import BusyError +from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel from app.tasks.chat.stream_new_chat import ( StreamResult, _classify_stream_exception, @@ -147,7 +148,7 @@ def test_stream_exception_classifies_rate_limited(): exc = Exception( '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}' ) - kind, code, severity, is_expected, user_message = _classify_stream_exception( + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( exc, flow_label="chat" ) assert kind == "rate_limited" @@ -155,11 +156,12 @@ def test_stream_exception_classifies_rate_limited(): assert severity == "warn" assert is_expected is True assert "temporarily rate-limited" in user_message + assert extra is None def test_stream_exception_classifies_thread_busy(): exc = BusyError(request_id="thread-123") - kind, code, severity, is_expected, user_message = _classify_stream_exception( + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( exc, flow_label="chat" ) assert kind == "thread_busy" @@ -167,11 +169,12 @@ def test_stream_exception_classifies_thread_busy(): assert severity == "warn" assert is_expected is True assert "still finishing for this thread" in user_message + assert extra is None def test_stream_exception_classifies_thread_busy_from_message(): exc = Exception("Thread is busy with another request") - kind, code, severity, is_expected, user_message = _classify_stream_exception( + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( exc, flow_label="chat" ) assert kind == "thread_busy" @@ -179,6 +182,24 @@ def test_stream_exception_classifies_thread_busy_from_message(): assert severity == "warn" assert is_expected is True assert "still finishing for this thread" in user_message + assert extra is None + + +def test_stream_exception_classifies_turn_cancelling_when_cancel_requested(): + thread_id = "thread-cancelling-1" + reset_cancel(thread_id) + request_cancel(thread_id) + exc = BusyError(request_id=thread_id) + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "thread_busy" + assert code == "TURN_CANCELLING" + assert severity == "info" + assert is_expected is True + assert "stopping" in user_message + assert isinstance(extra, dict) + assert "retry_after_ms" in extra def test_premium_classification_is_error_code_driven(): @@ -219,33 +240,33 @@ def test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): def test_network_send_failures_use_unified_retry_toast_message(): classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" classifier_source = classifier_path.read_text(encoding="utf-8") - page_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + request_errors_path = ( + Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-request-errors.ts" ) - page_source = page_path.read_text(encoding="utf-8") + request_errors_source = request_errors_path.read_text(encoding="utf-8") assert '"send_failed_pre_accept"' in classifier_source assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source + assert 'errorCode === "TURN_CANCELLING"' in classifier_source assert "if (withCode.code) return withCode.code;" in classifier_source assert 'userMessage: "Message not sent. Please retry."' in classifier_source assert 'userMessage: "Connection issue. Please try again."' in classifier_source - assert "tagPreAcceptSendFailure(error)" in page_source - assert "const passthroughCodes = new Set([" in page_source - assert '"PREMIUM_QUOTA_EXHAUSTED"' in page_source - assert '"THREAD_BUSY"' in page_source - assert '"AUTH_EXPIRED"' in page_source - assert '"UNAUTHORIZED"' in page_source - assert '"RATE_LIMITED"' in page_source - assert '"NETWORK_ERROR"' in page_source - assert '"STREAM_PARSE_ERROR"' in page_source - assert '"TOOL_EXECUTION_ERROR"' in page_source - assert '"PERSIST_MESSAGE_FAILED"' in page_source - assert '"SERVER_ERROR"' in page_source - assert "passthroughCodes.has(existingCode)" in page_source - assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source - assert 'errorCode: "NETWORK_ERROR"' not in page_source - assert "Failed to start chat. Please try again." not in page_source + assert "const passthroughCodes = new Set([" in request_errors_source + assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source + assert '"THREAD_BUSY"' in request_errors_source + assert '"TURN_CANCELLING"' in request_errors_source + assert '"AUTH_EXPIRED"' in request_errors_source + assert '"UNAUTHORIZED"' in request_errors_source + assert '"RATE_LIMITED"' in request_errors_source + assert '"NETWORK_ERROR"' in request_errors_source + assert '"STREAM_PARSE_ERROR"' in request_errors_source + assert '"TOOL_EXECUTION_ERROR"' in request_errors_source + assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source + assert '"SERVER_ERROR"' in request_errors_source + assert "passthroughCodes.has(existingCode)" in request_errors_source + assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source + assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source + assert "Failed to start chat. Please try again." not in classifier_source def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows(): @@ -269,3 +290,75 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows() # New flow persists only when accepted and not already persisted. assert "if (newAccepted && !userPersisted) {" in source + assert "const fetchWithTurnCancellingRetry = useCallback(" in source + assert "computeFallbackTurnCancellingRetryDelay" in source + assert 'withMeta.errorCode === "TURN_CANCELLING"' in source + assert 'withMeta.errorCode === "THREAD_BUSY"' in source + assert "await fetchWithTurnCancellingRetry(() =>" in source + + +def test_cancel_active_turn_route_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + source = routes_path.read_text(encoding="utf-8") + + assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source + assert "response_model=CancelActiveTurnResponse" in source + assert 'status="cancelling",' in source + assert 'error_code="TURN_CANCELLING",' in source + assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source + assert "retry_after_at=" in source + assert 'status="idle",' in source + assert 'error_code="NO_ACTIVE_TURN",' in source + + +def test_turn_status_route_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + source = routes_path.read_text(encoding="utf-8") + + assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source + assert "response_model=TurnStatusResponse" in source + assert "_build_turn_status_payload(thread_id)" in source + assert "Permission.CHATS_READ.value" in source + assert "_raise_if_thread_busy_for_start(" in source + + +def test_turn_cancelling_retry_policy_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + source = routes_path.read_text(encoding="utf-8") + + assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source + assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source + assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source + assert "def _compute_turn_cancelling_retry_delay(" in source + assert "retry-after-ms" in source + assert '"Retry-After"' in source + assert '"errorCode": "TURN_CANCELLING"' in source + + +def test_turn_status_sse_contract_exists(): + stream_source = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/tasks/chat/stream_new_chat.py" + ).read_text(encoding="utf-8") + state_source = ( + Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/streaming-state.ts" + ).read_text(encoding="utf-8") + pipeline_source = ( + Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/stream-pipeline.ts" + ).read_text(encoding="utf-8") + + assert '"turn-status"' in stream_source + assert '"status": "busy"' in stream_source + assert '"status": "idle"' in stream_source + assert "type: \"data-turn-status\"" in state_source + assert "case \"data-turn-status\":" in pipeline_source + assert "end_turn(str(chat_id))" in stream_source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 02c2914be..1b25ca431 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -182,6 +182,20 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] { * ``stream_new_chat.py``) keep the JSON from ballooning. */ const TOOLS_WITH_UI_ALL: ToolUIGate = "all"; +const TURN_CANCELLING_INITIAL_DELAY_MS = 200; +const TURN_CANCELLING_BACKOFF_FACTOR = 2; +const TURN_CANCELLING_MAX_DELAY_MS = 1500; +const RECENT_CANCEL_WINDOW_MS = 5_000; + +function sleep(ms: number): Promise<void> { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +function computeFallbackTurnCancellingRetryDelay(attempt: number): number { + const safeAttempt = Math.max(1, attempt); + const raw = TURN_CANCELLING_INITIAL_DELAY_MS * TURN_CANCELLING_BACKOFF_FACTOR ** (safeAttempt - 1); + return Math.min(raw, TURN_CANCELLING_MAX_DELAY_MS); +} export default function NewChatPage() { const params = useParams(); @@ -193,6 +207,7 @@ export default function NewChatPage() { const [isRunning, setIsRunning] = useState(false); const [tokenUsageStore] = useState(() => createTokenUsageStore()); const abortControllerRef = useRef<AbortController | null>(null); + const recentCancelRequestedAtRef = useRef(0); const [pendingInterrupt, setPendingInterrupt] = useState<{ threadId: number; assistantMsgId: string; @@ -598,6 +613,36 @@ export default function NewChatPage() { [handleChatFailure] ); + const fetchWithTurnCancellingRetry = useCallback( + async (runFetch: () => Promise<Response>) => { + const maxAttempts = 4; + for (let attempt = 1; attempt <= maxAttempts; attempt += 1) { + const response = await runFetch(); + if (response.ok) { + return response; + } + const error = await toHttpResponseError(response); + const withMeta = error as Error & { errorCode?: string; retryAfterMs?: number }; + const isTurnCancelling = withMeta.errorCode === "TURN_CANCELLING"; + const isRecentThreadBusyAfterCancel = + withMeta.errorCode === "THREAD_BUSY" && + Date.now() - recentCancelRequestedAtRef.current <= RECENT_CANCEL_WINDOW_MS; + if ((isTurnCancelling || isRecentThreadBusyAfterCancel) && attempt < maxAttempts) { + const waitMs = + withMeta.retryAfterMs ?? computeFallbackTurnCancellingRetryDelay(attempt); + await sleep(waitMs); + continue; + } + throw error; + } + + throw Object.assign(new Error("Turn cancellation retry limit exceeded"), { + errorCode: "TURN_CANCELLING", + }); + }, + [] + ); + // Initialize thread and load messages // For new chats (no urlChatId), we use lazy creation - thread is created on first message const initializeThread = useCallback(async () => { @@ -767,12 +812,39 @@ export default function NewChatPage() { // Cancel ongoing request const cancelRun = useCallback(async () => { + if (threadId) { + const token = getBearerToken(); + if (token) { + const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + try { + const response = await fetch( + `${backendUrl}/api/v1/threads/${threadId}/cancel-active-turn`, + { + method: "POST", + headers: { + Authorization: `Bearer ${token}`, + }, + } + ); + if (response.ok) { + const payload = (await response.json()) as { + error_code?: string; + }; + if (payload.error_code === "TURN_CANCELLING") { + recentCancelRequestedAtRef.current = Date.now(); + } + } + } catch (error) { + console.warn("[NewChatPage] Failed to signal cancel-active-turn:", error); + } + } + } if (abortControllerRef.current) { abortControllerRef.current.abort(); abortControllerRef.current = null; } setIsRunning(false); - }, []); + }, [threadId]); // Handle new message from user const onNew = useCallback( @@ -971,29 +1043,33 @@ export default function NewChatPage() { setMentionedDocuments([]); } - const response = await fetch(`${backendUrl}/api/v1/new_chat`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify({ - chat_id: currentThreadId, - user_query: userQuery.trim(), - search_space_id: searchSpaceId, - filesystem_mode: selection.filesystem_mode, - client_platform: selection.client_platform, - local_filesystem_mounts: selection.local_filesystem_mounts, - messages: messageHistory, - mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined, - mentioned_surfsense_doc_ids: hasSurfsenseDocIds - ? mentionedDocumentIds.surfsense_doc_ids - : undefined, - disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, - ...(userImages.length > 0 ? { user_images: userImages } : {}), - }), - signal: controller.signal, - }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(`${backendUrl}/api/v1/new_chat`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + chat_id: currentThreadId, + user_query: userQuery.trim(), + search_space_id: searchSpaceId, + filesystem_mode: selection.filesystem_mode, + client_platform: selection.client_platform, + local_filesystem_mounts: selection.local_filesystem_mounts, + messages: messageHistory, + mentioned_document_ids: hasDocumentIds + ? mentionedDocumentIds.document_ids + : undefined, + mentioned_surfsense_doc_ids: hasSurfsenseDocIds + ? mentionedDocumentIds.surfsense_doc_ids + : undefined, + disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, + ...(userImages.length > 0 ? { user_images: userImages } : {}), + }), + signal: controller.signal, + }) + ); if (!response.ok) { throw await toHttpResponseError(response); @@ -1033,6 +1109,11 @@ export default function NewChatPage() { tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); + } + }, onToolOutputAvailable: (event, sharedCtx) => { if (event.output?.status === "pending" && event.output?.podcast_id) { const idx = sharedCtx.toolCallIndices.get(event.toolCallId); @@ -1257,6 +1338,7 @@ export default function NewChatPage() { tokenUsageStore, pendingUserImageUrls, setPendingUserImageUrls, + fetchWithTurnCancellingRetry, handleStreamTerminalError, handleChatFailure, persistAssistantTurn, @@ -1354,21 +1436,23 @@ export default function NewChatPage() { try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; const selection = await getAgentFilesystemSelection(searchSpaceId); - const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify({ - search_space_id: searchSpaceId, - decisions, - filesystem_mode: selection.filesystem_mode, - client_platform: selection.client_platform, - local_filesystem_mounts: selection.local_filesystem_mounts, - }), - signal: controller.signal, - }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + search_space_id: searchSpaceId, + decisions, + filesystem_mode: selection.filesystem_mode, + client_platform: selection.client_platform, + local_filesystem_mounts: selection.local_filesystem_mounts, + }), + signal: controller.signal, + }) + ); if (!response.ok) { throw await toHttpResponseError(response); @@ -1399,6 +1483,11 @@ export default function NewChatPage() { tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); + } + }, }) ) { return; @@ -1496,6 +1585,7 @@ export default function NewChatPage() { searchSpaceId, queryClient, tokenUsageStore, + fetchWithTurnCancellingRetry, handleStreamTerminalError, persistAssistantTurn, ] @@ -1700,15 +1790,17 @@ export default function NewChatPage() { requestBody.revert_actions = true; } } - const response = await fetch(getRegenerateUrl(threadId), { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify(requestBody), - signal: controller.signal, - }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(getRegenerateUrl(threadId), { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify(requestBody), + signal: controller.signal, + }) + ); if (!response.ok) { throw await toHttpResponseError(response); @@ -1774,6 +1866,11 @@ export default function NewChatPage() { tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); + } + }, onToolOutputAvailable: (event, sharedCtx) => { if (event.output?.status === "pending" && event.output?.podcast_id) { const idx = sharedCtx.toolCallIndices.get(event.toolCallId); @@ -1945,6 +2042,7 @@ export default function NewChatPage() { setMessageDocumentsMap, queryClient, tokenUsageStore, + fetchWithTurnCancellingRetry, handleStreamTerminalError, persistAssistantTurn, persistUserTurn, diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index 57341a4c3..7dfbfc1a1 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -147,6 +147,22 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } + if ( + errorCode === "TURN_CANCELLING" + ) { + return { + kind: "thread_busy", + channel: "toast", + severity: "info", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: "A previous response is still stopping. Please try again in a moment.", + rawMessage, + errorCode: errorCode ?? "TURN_CANCELLING", + details: { flow: input.flow }, + }; + } + if ( errorCode === "THREAD_BUSY" ) { @@ -156,7 +172,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError severity: "warn", telemetryEvent: "chat_blocked", isExpected: true, - userMessage: "A previous response is still stopping. Please try again in a moment.", + userMessage: "Another response is still finishing for this thread. Please try again in a moment.", rawMessage, errorCode: errorCode ?? "THREAD_BUSY", details: { flow: input.flow }, diff --git a/surfsense_web/lib/chat/chat-request-errors.ts b/surfsense_web/lib/chat/chat-request-errors.ts index 3026e8203..708831354 100644 --- a/surfsense_web/lib/chat/chat-request-errors.ts +++ b/surfsense_web/lib/chat/chat-request-errors.ts @@ -1,6 +1,6 @@ export async function toHttpResponseError( response: Response -): Promise<Error & { errorCode?: string }> { +): Promise<Error & { errorCode?: string; retryAfterMs?: number }> { const statusDefaultCode = response.status === 409 ? "THREAD_BUSY" @@ -52,13 +52,37 @@ export async function toHttpResponseError( : undefined; const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode; + + const detailRetryAfterMs = + typeof detailObject?.retry_after_ms === "number" + ? detailObject.retry_after_ms + : typeof detailObject?.retryAfterMs === "number" + ? detailObject.retryAfterMs + : undefined; + const topRetryAfterMs = + typeof parsedBody?.retry_after_ms === "number" + ? parsedBody.retry_after_ms + : typeof parsedBody?.retryAfterMs === "number" + ? parsedBody.retryAfterMs + : undefined; + const headerRetryAfterMsRaw = response.headers.get("retry-after-ms"); + const headerRetryAfterMs = headerRetryAfterMsRaw ? Number.parseFloat(headerRetryAfterMsRaw) : NaN; + const retryAfterHeader = response.headers.get("retry-after"); + const retryAfterSeconds = retryAfterHeader ? Number.parseFloat(retryAfterHeader) : NaN; + const retryAfterMsFromHeader = Number.isFinite(headerRetryAfterMs) + ? Math.max(0, Math.round(headerRetryAfterMs)) + : Number.isFinite(retryAfterSeconds) + ? Math.max(0, Math.round(retryAfterSeconds * 1000)) + : undefined; + const retryAfterMs = + detailRetryAfterMs ?? topRetryAfterMs ?? retryAfterMsFromHeader ?? undefined; const message = detailNestedMessage ?? detailMessage ?? topLevelMessage ?? `Backend error: ${response.status}`; - return Object.assign(new Error(message), { errorCode }); + return Object.assign(new Error(message), { errorCode, retryAfterMs }); } export function tagPreAcceptSendFailure(error: unknown): unknown { @@ -68,6 +92,7 @@ export function tagPreAcceptSendFailure(error: unknown): unknown { const passthroughCodes = new Set([ "PREMIUM_QUOTA_EXHAUSTED", "THREAD_BUSY", + "TURN_CANCELLING", "AUTH_EXPIRED", "UNAUTHORIZED", "RATE_LIMITED", diff --git a/surfsense_web/lib/chat/stream-pipeline.ts b/surfsense_web/lib/chat/stream-pipeline.ts index 8957bdea3..c9118f949 100644 --- a/surfsense_web/lib/chat/stream-pipeline.ts +++ b/surfsense_web/lib/chat/stream-pipeline.ts @@ -21,6 +21,7 @@ export type SharedStreamEventContext = { scheduleFlush: () => void; forceFlush: () => void; onTokenUsage?: (data: Extract<SSEEvent, { type: "data-token-usage" }>["data"]) => void; + onTurnStatus?: (data: Extract<SSEEvent, { type: "data-turn-status" }>["data"]) => void; onToolOutputAvailable?: ( event: Extract<SSEEvent, { type: "tool-output-available" }>, context: { @@ -173,6 +174,10 @@ export function processSharedStreamEvent(parsed: SSEEvent, context: SharedStream context.onTokenUsage?.(parsed.data); return true; + case "data-turn-status": + context.onTurnStatus?.(parsed.data); + return true; + case "error": throw toStreamTerminalError(parsed); diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index 445bbe83d..80e7bffbe 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -528,6 +528,14 @@ export type SSEEvent = }>; }; } + | { + type: "data-turn-status"; + data: { + status: "idle" | "busy" | "cancelling"; + retry_after_ms?: number; + retry_after_at?: number; + }; + } | { type: "data-token-usage"; data: { From a66c1576b965acc50ae89d8a0f71ed3db1b64077 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 03:09:53 +0530 Subject: [PATCH 252/299] refactor(chat): introduce ChatViewport and NestedScroll components for improved chat UI structure and functionality --- .../components/assistant-ui/chat-viewport.tsx | 44 +++++++ .../components/assistant-ui/nested-scroll.tsx | 24 ++++ .../assistant-ui/thread-scroll-to-bottom.tsx | 18 --- .../components/assistant-ui/thread.tsx | 108 +++--------------- .../components/assistant-ui/tool-fallback.tsx | 9 +- .../components/free-chat/free-thread.tsx | 43 ++----- .../components/public-chat/public-thread.tsx | 9 +- 7 files changed, 99 insertions(+), 156 deletions(-) create mode 100644 surfsense_web/components/assistant-ui/chat-viewport.tsx create mode 100644 surfsense_web/components/assistant-ui/nested-scroll.tsx delete mode 100644 surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx new file mode 100644 index 000000000..f91a8916a --- /dev/null +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -0,0 +1,44 @@ +"use client"; + +import { ThreadPrimitive } from "@assistant-ui/react"; +import { ArrowDownIcon } from "lucide-react"; +import type { FC, ReactNode } from "react"; +import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; + +const ChatScrollToBottom: FC = () => ( + <ThreadPrimitive.ScrollToBottom asChild> + <TooltipIconButton + tooltip="Scroll to bottom" + variant="outline" + className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent" + > + <ArrowDownIcon /> + </TooltipIconButton> + </ThreadPrimitive.ScrollToBottom> +); + +export interface ChatViewportProps { + children: ReactNode; + footer?: ReactNode; +} + +export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => ( + <ThreadPrimitive.Viewport + scrollToBottomOnRunStart + scrollToBottomOnInitialize + scrollToBottomOnThreadSwitch + className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4" + style={{ scrollbarGutter: "stable" }} + > + {children} + {footer ? ( + <ThreadPrimitive.ViewportFooter + className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6" + style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }} + > + <ChatScrollToBottom /> + {footer} + </ThreadPrimitive.ViewportFooter> + ) : null} + </ThreadPrimitive.Viewport> +); diff --git a/surfsense_web/components/assistant-ui/nested-scroll.tsx b/surfsense_web/components/assistant-ui/nested-scroll.tsx new file mode 100644 index 000000000..5a4f8d36e --- /dev/null +++ b/surfsense_web/components/assistant-ui/nested-scroll.tsx @@ -0,0 +1,24 @@ +"use client"; + +import { forwardRef, type ComponentPropsWithoutRef, type WheelEvent } from "react"; + +export type NestedScrollProps = ComponentPropsWithoutRef<"div">; + +export const NestedScroll = forwardRef<HTMLDivElement, NestedScrollProps>( + ({ onWheel, ...props }, ref) => { + const handleWheel = (event: WheelEvent<HTMLDivElement>) => { + const el = event.currentTarget; + const canScrollUp = el.scrollTop > 0; + const canScrollDown = el.scrollTop < el.scrollHeight - el.clientHeight - 1; + const goingUp = event.deltaY < 0; + const goingDown = event.deltaY > 0; + if ((goingUp && canScrollUp) || (goingDown && canScrollDown)) { + event.stopPropagation(); + } + onWheel?.(event); + }; + return <div ref={ref} onWheel={handleWheel} {...props} />; + } +); + +NestedScroll.displayName = "NestedScroll"; diff --git a/surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx b/surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx deleted file mode 100644 index 394ba5d79..000000000 --- a/surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx +++ /dev/null @@ -1,18 +0,0 @@ -import { ThreadPrimitive } from "@assistant-ui/react"; -import { ArrowDownIcon } from "lucide-react"; -import type { FC } from "react"; -import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; - -export const ThreadScrollToBottom: FC = () => { - return ( - <ThreadPrimitive.ScrollToBottom asChild> - <TooltipIconButton - tooltip="Scroll to bottom" - variant="outline" - className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent" - > - <ArrowDownIcon /> - </TooltipIconButton> - </ThreadPrimitive.ScrollToBottom> - ); -}; diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 3e27e7adb..1d24a2a39 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -5,12 +5,10 @@ import { ThreadPrimitive, useAui, useAuiState, - useThreadViewportStore, } from "@assistant-ui/react"; import { useAtom, useAtomValue, useSetAtom } from "jotai"; import { AlertCircle, - ArrowDownIcon, ArrowUpIcon, Camera, ChevronDown, @@ -55,6 +53,7 @@ import { import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { AssistantMessage } from "@/components/assistant-ui/assistant-message"; import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status"; +import { ChatViewport } from "@/components/assistant-ui/chat-viewport"; import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup"; import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup"; import { @@ -112,10 +111,17 @@ const ThreadContent: FC = () => { ["--thread-max-width" as string]: "44rem", }} > - <ThreadPrimitive.Viewport - turnAnchor="top" - className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4" - style={{ scrollbarGutter: "stable" }} + <ChatViewport + footer={ + <> + <AuiIf condition={({ thread }) => !thread.isEmpty}> + <PremiumQuotaPinnedAlert /> + </AuiIf> + <AuiIf condition={({ thread }) => !thread.isEmpty}> + <Composer /> + </AuiIf> + </> + } > <AuiIf condition={({ thread }) => thread.isEmpty}> <ThreadWelcome /> @@ -128,24 +134,7 @@ const ThreadContent: FC = () => { AssistantMessage, }} /> - - <AuiIf condition={({ thread }) => !thread.isEmpty}> - <div className="grow" /> - </AuiIf> - - <ThreadPrimitive.ViewportFooter - className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6" - style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }} - > - <ThreadScrollToBottom /> - <AuiIf condition={({ thread }) => !thread.isEmpty}> - <PremiumQuotaPinnedAlert /> - </AuiIf> - <AuiIf condition={({ thread }) => !thread.isEmpty}> - <Composer /> - </AuiIf> - </ThreadPrimitive.ViewportFooter> - </ThreadPrimitive.Viewport> + </ChatViewport> </ThreadPrimitive.Root> ); }; @@ -181,20 +170,6 @@ const PremiumQuotaPinnedAlert: FC = () => { ); }; -const ThreadScrollToBottom: FC = () => { - return ( - <ThreadPrimitive.ScrollToBottom asChild> - <TooltipIconButton - tooltip="Scroll to bottom" - variant="outline" - className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent" - > - <ArrowDownIcon /> - </TooltipIconButton> - </ThreadPrimitive.ScrollToBottom> - ); -}; - const getTimeBasedGreeting = (user?: { display_name?: string | null; email?: string }): string => { const hour = new Date().getHours(); @@ -411,23 +386,9 @@ const Composer: FC = () => { >(new Map()); const documentPickerRef = useRef<DocumentMentionPickerRef>(null); const promptPickerRef = useRef<PromptPickerRef>(null); - const viewportRef = useRef<Element | null>(null); const { search_space_id, chat_id } = useParams(); const aui = useAui(); - const threadViewportStore = useThreadViewportStore(); const hasAutoFocusedRef = useRef(false); - const submitCleanupRef = useRef<(() => void) | null>(null); - - useEffect(() => { - return () => { - submitCleanupRef.current?.(); - }; - }, []); - - // Store viewport element reference on mount - useEffect(() => { - viewportRef.current = document.querySelector(".aui-thread-viewport"); - }, []); const electronAPI = useElectronAPI(); const [clipboardInitialText, setClipboardInitialText] = useState<string | undefined>(); @@ -626,7 +587,6 @@ const Composer: FC = () => { [showDocumentPopover, showPromptPicker] ); - // Submit message (blocked during streaming, document picker open, or AI responding to another user) const handleSubmit = useCallback(() => { if (isThreadRunning || isBlockedByOtherUser) return; if (showDocumentPopover || showPromptPicker) return; @@ -638,50 +598,9 @@ const Composer: FC = () => { setClipboardInitialText(undefined); } - const viewportEl = viewportRef.current; - const heightBefore = viewportEl?.scrollHeight ?? 0; - aui.composer().send(); editorRef.current?.clear(); setMentionedDocuments([]); - - // With turnAnchor="top", ViewportSlack adds min-height to the last - // assistant message so that scrolling-to-bottom actually positions the - // user message at the TOP of the viewport. That slack height is - // calculated asynchronously (ResizeObserver → style → layout). - // Poll via rAF for ~500ms, re-scrolling whenever scrollHeight changes. - const scrollToBottom = () => - threadViewportStore.getState().scrollToBottom({ behavior: "instant" }); - - let lastHeight = heightBefore; - let frames = 0; - let cancelled = false; - const POLL_FRAMES = 30; - - const pollAndScroll = () => { - if (cancelled) return; - const el = viewportRef.current; - if (el) { - const h = el.scrollHeight; - if (h !== lastHeight) { - lastHeight = h; - scrollToBottom(); - } - } - if (++frames < POLL_FRAMES) { - requestAnimationFrame(pollAndScroll); - } - }; - requestAnimationFrame(pollAndScroll); - - const t1 = setTimeout(scrollToBottom, 100); - const t2 = setTimeout(scrollToBottom, 300); - - submitCleanupRef.current = () => { - cancelled = true; - clearTimeout(t1); - clearTimeout(t2); - }; }, [ showDocumentPopover, showPromptPicker, @@ -690,7 +609,6 @@ const Composer: FC = () => { clipboardInitialText, aui, setMentionedDocuments, - threadViewportStore, ]); const handleDocumentRemove = useCallback( diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index 66e2ebd4a..cf42cf398 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -13,6 +13,7 @@ import { isDoomLoopInterrupt, } from "@/components/tool-ui/doom-loop-approval"; import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval"; +import { NestedScroll } from "@/components/assistant-ui/nested-scroll"; import { AlertDialog, AlertDialogAction, @@ -475,7 +476,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => { {(argsText || isRunning) && ( <div className="flex flex-col gap-1 min-w-0"> <p className="text-xs font-medium text-muted-foreground">Inputs</p> - <div className="max-h-48 overflow-auto rounded-md bg-muted/40"> + <NestedScroll className="max-h-48 overflow-auto rounded-md bg-muted/40"> {argsText ? ( <pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono"> {argsText} @@ -489,7 +490,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => { Waiting for input… </p> )} - </div> + </NestedScroll> </div> )} {!isCancelled && result !== undefined && ( @@ -497,11 +498,11 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => { <Separator /> <div className="flex flex-col gap-1 min-w-0"> <p className="text-xs font-medium text-muted-foreground">Result</p> - <div className="max-h-64 overflow-auto rounded-md bg-muted/40"> + <NestedScroll className="max-h-64 overflow-auto rounded-md bg-muted/40"> <pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono"> {typeof result === "string" ? result : serializedResult} </pre> - </div> + </NestedScroll> </div> </> )} diff --git a/surfsense_web/components/free-chat/free-thread.tsx b/surfsense_web/components/free-chat/free-thread.tsx index bd237004a..933847b2b 100644 --- a/surfsense_web/components/free-chat/free-thread.tsx +++ b/surfsense_web/components/free-chat/free-thread.tsx @@ -1,11 +1,10 @@ "use client"; import { AuiIf, ThreadPrimitive } from "@assistant-ui/react"; -import { ArrowDownIcon } from "lucide-react"; import type { FC } from "react"; import { AssistantMessage } from "@/components/assistant-ui/assistant-message"; +import { ChatViewport } from "@/components/assistant-ui/chat-viewport"; import { EditComposer } from "@/components/assistant-ui/edit-composer"; -import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { UserMessage } from "@/components/assistant-ui/user-message"; import { FreeComposer } from "./free-composer"; @@ -24,20 +23,6 @@ const FreeThreadWelcome: FC = () => { ); }; -const ThreadScrollToBottom: FC = () => { - return ( - <ThreadPrimitive.ScrollToBottom asChild> - <TooltipIconButton - tooltip="Scroll to bottom" - variant="outline" - className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent" - > - <ArrowDownIcon /> - </TooltipIconButton> - </ThreadPrimitive.ScrollToBottom> - ); -}; - export const FreeThread: FC = () => { return ( <ThreadPrimitive.Root @@ -46,10 +31,12 @@ export const FreeThread: FC = () => { ["--thread-max-width" as string]: "44rem", }} > - <ThreadPrimitive.Viewport - turnAnchor="top" - className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4" - style={{ scrollbarGutter: "stable" }} + <ChatViewport + footer={ + <AuiIf condition={({ thread }) => !thread.isEmpty}> + <FreeComposer /> + </AuiIf> + } > <AuiIf condition={({ thread }) => thread.isEmpty}> <FreeThreadWelcome /> @@ -62,21 +49,7 @@ export const FreeThread: FC = () => { AssistantMessage, }} /> - - <AuiIf condition={({ thread }) => !thread.isEmpty}> - <div className="grow" /> - </AuiIf> - - <ThreadPrimitive.ViewportFooter - className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6" - style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }} - > - <ThreadScrollToBottom /> - <AuiIf condition={({ thread }) => !thread.isEmpty}> - <FreeComposer /> - </AuiIf> - </ThreadPrimitive.ViewportFooter> - </ThreadPrimitive.Viewport> + </ChatViewport> </ThreadPrimitive.Root> ); }; diff --git a/surfsense_web/components/public-chat/public-thread.tsx b/surfsense_web/components/public-chat/public-thread.tsx index 22e914988..de91b4451 100644 --- a/surfsense_web/components/public-chat/public-thread.tsx +++ b/surfsense_web/components/public-chat/public-thread.tsx @@ -45,16 +45,17 @@ export const PublicThread: FC<PublicThreadProps> = ({ footer }) => { ["--thread-max-width" as string]: "44rem", }} > - <ThreadPrimitive.Viewport className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4"> + <ThreadPrimitive.Viewport + scrollToBottomOnInitialize + scrollToBottomOnThreadSwitch + className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4 pb-6" + > <ThreadPrimitive.Messages components={{ UserMessage: PublicUserMessage, AssistantMessage: PublicAssistantMessage, }} /> - - {/* Spacer to ensure footer doesn't overlap last message */} - <div className="h-24" /> </ThreadPrimitive.Viewport> {footer && ( From 833b4dd441d0e8053bd2399076fedcf067917617 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 03:10:21 +0530 Subject: [PATCH 253/299] refactor(chat): simplify ChatViewport and footer structure for improved readability and maintainability --- .../components/assistant-ui/chat-viewport.tsx | 26 ++++++++++--------- .../components/assistant-ui/thread.tsx | 12 +++------ .../components/public-chat/public-thread.tsx | 2 +- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx index f91a8916a..a1534df01 100644 --- a/surfsense_web/components/assistant-ui/chat-viewport.tsx +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -23,22 +23,24 @@ export interface ChatViewportProps { } export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => ( - <ThreadPrimitive.Viewport - scrollToBottomOnRunStart - scrollToBottomOnInitialize - scrollToBottomOnThreadSwitch - className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4" - style={{ scrollbarGutter: "stable" }} - > - {children} + <> + <ThreadPrimitive.Viewport + scrollToBottomOnRunStart + scrollToBottomOnInitialize + scrollToBottomOnThreadSwitch + className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4" + style={{ scrollbarGutter: "stable" }} + > + {children} + </ThreadPrimitive.Viewport> {footer ? ( - <ThreadPrimitive.ViewportFooter - className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6" + <div + className="aui-chat-composer-area relative mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible bg-main-panel px-4 pt-2 pb-4 md:pb-6" style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }} > <ChatScrollToBottom /> {footer} - </ThreadPrimitive.ViewportFooter> + </div> ) : null} - </ThreadPrimitive.Viewport> + </> ); diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 1d24a2a39..6c02a1efa 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -113,14 +113,10 @@ const ThreadContent: FC = () => { > <ChatViewport footer={ - <> - <AuiIf condition={({ thread }) => !thread.isEmpty}> - <PremiumQuotaPinnedAlert /> - </AuiIf> - <AuiIf condition={({ thread }) => !thread.isEmpty}> - <Composer /> - </AuiIf> - </> + <AuiIf condition={({ thread }) => !thread.isEmpty}> + <PremiumQuotaPinnedAlert /> + <Composer /> + </AuiIf> } > <AuiIf condition={({ thread }) => thread.isEmpty}> diff --git a/surfsense_web/components/public-chat/public-thread.tsx b/surfsense_web/components/public-chat/public-thread.tsx index de91b4451..750b7410e 100644 --- a/surfsense_web/components/public-chat/public-thread.tsx +++ b/surfsense_web/components/public-chat/public-thread.tsx @@ -59,7 +59,7 @@ export const PublicThread: FC<PublicThreadProps> = ({ footer }) => { </ThreadPrimitive.Viewport> {footer && ( - <div className="sticky bottom-0 z-20 border-t bg-main-panel/95 backdrop-blur supports-backdrop-filter:bg-main-panel/60"> + <div className="border-t bg-main-panel/95 backdrop-blur supports-backdrop-filter:bg-main-panel/60"> {footer} </div> )} From b2f487bf36829b3ceb61c654ab6557561a3483ac Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Thu, 30 Apr 2026 15:03:10 -0700 Subject: [PATCH 254/299] feat: added mac signing & notarization for desktop app --- .github/workflows/desktop-release.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/desktop-release.yml b/.github/workflows/desktop-release.yml index b955e5014..e356bd3e5 100644 --- a/.github/workflows/desktop-release.yml +++ b/.github/workflows/desktop-release.yml @@ -136,6 +136,14 @@ jobs: AZURE_CODESIGN_ENDPOINT: ${{ vars.AZURE_CODESIGN_ENDPOINT }} AZURE_CODESIGN_ACCOUNT: ${{ vars.AZURE_CODESIGN_ACCOUNT }} AZURE_CODESIGN_PROFILE: ${{ vars.AZURE_CODESIGN_PROFILE }} + # macOS Developer ID signing + notarization. Only the macos-latest runner + # consumes these; Windows/Linux runners ignore them. CSC_LINK accepts either + # a file path or a base64-encoded .p12 blob — electron-builder auto-detects. + CSC_LINK: ${{ secrets.MAC_CERT_P12_BASE64 }} + CSC_KEY_PASSWORD: ${{ secrets.MAC_CERT_PASSWORD }} + APPLE_ID: ${{ secrets.APPLE_ID }} + APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }} + APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} # Service principal credentials for Azure.Identity EnvironmentCredential used by the # TrustedSigning PowerShell module. Only populated when signing is enabled. # electron-builder 26 does not yet support OIDC federated tokens for Azure signing, From 7b549f84445ef158b97d0270143a88c623d89ab7 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 03:38:21 +0530 Subject: [PATCH 255/299] refactor(chat): enhance ChatViewport with auto-scroll and top fade effect for improved user experience --- .../components/assistant-ui/chat-viewport.tsx | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx index a1534df01..d3d664ace 100644 --- a/surfsense_web/components/assistant-ui/chat-viewport.tsx +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -23,24 +23,30 @@ export interface ChatViewportProps { } export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => ( - <> - <ThreadPrimitive.Viewport - scrollToBottomOnRunStart - scrollToBottomOnInitialize - scrollToBottomOnThreadSwitch - className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4" - style={{ scrollbarGutter: "stable" }} - > - {children} - </ThreadPrimitive.Viewport> + <ThreadPrimitive.Viewport + turnAnchor="top" + autoScroll + scrollToBottomOnRunStart + scrollToBottomOnInitialize + scrollToBottomOnThreadSwitch + className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 [scroll-behavior:smooth]" + style={{ scrollbarGutter: "stable" }} + > + <div + aria-hidden + className="aui-chat-viewport-top-fade pointer-events-none sticky top-0 z-10 -mx-4 h-2 shrink-0 bg-gradient-to-b from-main-panel from-20% to-transparent" + /> + {children} {footer ? ( - <div - className="aui-chat-composer-area relative mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible bg-main-panel px-4 pt-2 pb-4 md:pb-6" - style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }} + <ThreadPrimitive.ViewportFooter + className="aui-chat-composer-footer sticky bottom-0 z-20 -mx-4 flex flex-col items-stretch bg-gradient-to-t from-main-panel from-60% to-transparent px-4 pt-6" + style={{ paddingBottom: "max(0.5rem, env(safe-area-inset-bottom))" }} > - <ChatScrollToBottom /> - {footer} - </div> + <div className="aui-chat-composer-area relative mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-3 overflow-visible"> + <ChatScrollToBottom /> + {footer} + </div> + </ThreadPrimitive.ViewportFooter> ) : null} - </> + </ThreadPrimitive.Viewport> ); From 511f4fde6440378a111fb7bdc3f84cbf4b9c85c1 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 03:40:14 +0530 Subject: [PATCH 256/299] refactor(chat): update ChatViewport className for improved scroll behavior consistency --- surfsense_web/components/assistant-ui/chat-viewport.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx index d3d664ace..f7f1ac188 100644 --- a/surfsense_web/components/assistant-ui/chat-viewport.tsx +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -29,7 +29,7 @@ export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => ( scrollToBottomOnRunStart scrollToBottomOnInitialize scrollToBottomOnThreadSwitch - className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 [scroll-behavior:smooth]" + className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 scroll-smooth" style={{ scrollbarGutter: "stable" }} > <div From 8b4f1366684e69cfb403ec9942a7ac0d2cc677d9 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 04:02:24 +0530 Subject: [PATCH 257/299] refactor(chat): enhance UserMessage component with mention parsing and segment rendering for improved message display --- .../components/assistant-ui/chat-viewport.tsx | 2 +- .../components/assistant-ui/user-message.tsx | 121 ++++++------------ .../lib/chat/parse-mention-segments.ts | 54 ++++++++ 3 files changed, 97 insertions(+), 80 deletions(-) create mode 100644 surfsense_web/lib/chat/parse-mention-segments.ts diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx index f7f1ac188..c0684407e 100644 --- a/surfsense_web/components/assistant-ui/chat-viewport.tsx +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -39,7 +39,7 @@ export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => ( {children} {footer ? ( <ThreadPrimitive.ViewportFooter - className="aui-chat-composer-footer sticky bottom-0 z-20 -mx-4 flex flex-col items-stretch bg-gradient-to-t from-main-panel from-60% to-transparent px-4 pt-6" + className="aui-chat-composer-footer sticky bottom-0 z-20 -mx-4 mt-auto flex flex-col items-stretch bg-gradient-to-t from-main-panel from-60% to-transparent px-4 pt-6" style={{ paddingBottom: "max(0.5rem, env(safe-area-inset-bottom))" }} > <div className="aui-chat-composer-area relative mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-3 overflow-visible"> diff --git a/surfsense_web/components/assistant-ui/user-message.tsx b/surfsense_web/components/assistant-ui/user-message.tsx index fb7212119..145ac2d7e 100644 --- a/surfsense_web/components/assistant-ui/user-message.tsx +++ b/surfsense_web/components/assistant-ui/user-message.tsx @@ -1,4 +1,10 @@ -import { ActionBarPrimitive, AuiIf, MessagePrimitive, useAuiState } from "@assistant-ui/react"; +import { + ActionBarPrimitive, + AuiIf, + MessagePrimitive, + useAuiState, + useMessagePartText, +} from "@assistant-ui/react"; import { useAtomValue } from "jotai"; import { CheckIcon, CopyIcon, Pencil } from "lucide-react"; import Image from "next/image"; @@ -7,6 +13,8 @@ import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; +import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; +import { parseMentionSegments } from "@/lib/chat/parse-mention-segments"; interface AuthorMetadata { displayName: string | null; @@ -47,23 +55,40 @@ const UserAvatar: FC<AuthorMetadata> = ({ displayName, avatarUrl }) => { ); }; -export const UserMessage: FC = () => { +const UserTextPart: FC = () => { const messageId = useAuiState(({ message }) => message?.id); - const messageText = useAuiState(({ message }) => - (message?.content ?? []) - .map((part) => - typeof part === "object" && - part !== null && - "type" in part && - (part as { type?: string }).type === "text" && - "text" in part - ? String((part as { text?: string }).text ?? "") - : "" - ) - .join("") - ); + const part = useMessagePartText(); + const text = (part as { text?: string }).text ?? ""; const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); - const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined; + const mentionedDocs = (messageId ? messageDocumentsMap[messageId] : undefined) ?? []; + + const segments = parseMentionSegments(text, mentionedDocs); + + return ( + <p style={{ whiteSpace: "pre-line" }} className="break-words"> + {segments.map((segment) => + segment.type === "text" ? ( + <span key={`txt-${segment.start}`}>{segment.value}</span> + ) : ( + <span + key={`mention-${getMentionDocKey(segment.doc)}-${segment.start}`} + className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-middle leading-none" + title={segment.doc.title} + > + <span className="flex items-center text-muted-foreground"> + {getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")} + </span> + <span className="max-w-[120px] truncate">{segment.doc.title}</span> + </span> + ) + )} + </p> + ); +}; + +const userMessageParts = { Text: UserTextPart }; + +export const UserMessage: FC = () => { const metadata = useAuiState(({ message }) => message?.metadata); const author = metadata?.custom?.author as AuthorMetadata | undefined; const isSharedChat = useAtomValue(currentThreadAtom).visibility === "SEARCH_SPACE"; @@ -78,11 +103,7 @@ export const UserMessage: FC = () => { <div className="aui-user-message-content-wrapper flex items-end gap-2"> <div className="relative flex-1 min-w-0"> <div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground"> - {mentionedDocs && mentionedDocs.length > 0 ? ( - <UserMessageWithMentionChips text={messageText} mentionedDocs={mentionedDocs} /> - ) : ( - <MessagePrimitive.Parts /> - )} + <MessagePrimitive.Parts components={userMessageParts} /> </div> <div className="absolute right-0 top-full mt-1 z-10 opacity-100 pointer-events-auto md:opacity-0 md:pointer-events-none md:transition-opacity md:duration-200 md:delay-300 md:group-hover/user-msg:opacity-100 md:group-hover/user-msg:delay-0 md:group-hover/user-msg:pointer-events-auto"> <UserActionBar /> @@ -99,64 +120,6 @@ export const UserMessage: FC = () => { ); }; -const UserMessageWithMentionChips: FC<{ - text: string; - mentionedDocs: { id: number; title: string; document_type: string }[]; -}> = ({ text, mentionedDocs }) => { - type Segment = - | { type: "text"; value: string; start: number } - | { type: "mention"; doc: { id: number; title: string; document_type: string }; start: number }; - - const tokens = mentionedDocs - .map((doc) => ({ doc, token: `@${doc.title}` })) - .sort((a, b) => b.token.length - a.token.length); - - const segments: Segment[] = []; - let i = 0; - let buffer = ""; - let bufferStart = 0; - while (i < text.length) { - const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i)); - if (tokenMatch) { - if (buffer) { - segments.push({ type: "text", value: buffer, start: bufferStart }); - buffer = ""; - } - segments.push({ type: "mention", doc: tokenMatch.doc, start: i }); - i += tokenMatch.token.length; - bufferStart = i; - continue; - } - if (!buffer) bufferStart = i; - buffer += text[i]; - i += 1; - } - if (buffer) { - segments.push({ type: "text", value: buffer, start: bufferStart }); - } - - return ( - <span className="whitespace-pre-wrap break-words"> - {segments.map((segment) => - segment.type === "text" ? ( - <span key={`txt-${segment.start}`}>{segment.value}</span> - ) : ( - <span - key={`mention-${segment.doc.document_type}:${segment.doc.id}-${segment.start}`} - className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-baseline" - title={segment.doc.title} - > - <span className="flex items-center text-muted-foreground"> - {getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")} - </span> - <span className="max-w-[120px] truncate">{segment.doc.title}</span> - </span> - ) - )} - </span> - ); -}; - const UserActionBar: FC = () => { const isThreadRunning = useAuiState(({ thread }) => thread.isRunning); diff --git a/surfsense_web/lib/chat/parse-mention-segments.ts b/surfsense_web/lib/chat/parse-mention-segments.ts new file mode 100644 index 000000000..b9cf59792 --- /dev/null +++ b/surfsense_web/lib/chat/parse-mention-segments.ts @@ -0,0 +1,54 @@ +import type { MentionedDocumentInfo } from "@/atoms/chat/mentioned-documents.atom"; + +export type MentionSegment = + | { type: "text"; value: string; start: number } + | { type: "mention"; doc: MentionedDocumentInfo; start: number }; + +/** + * Tokenizes a user message into text and `@mention` segments. + * + * Pure: no React, no DOM, no side effects. Safe to unit-test and reuse. + * + * Mentions are matched greedily by longest title first so that a longer title + * (e.g. `@Project Roadmap`) is never shadowed by a shorter prefix + * (e.g. `@Project`). + */ +export function parseMentionSegments( + text: string, + docs: ReadonlyArray<MentionedDocumentInfo> +): MentionSegment[] { + if (text.length === 0) return []; + if (docs.length === 0) return [{ type: "text", value: text, start: 0 }]; + + const tokens = docs + .map((doc) => ({ doc, token: `@${doc.title}` })) + .sort((a, b) => b.token.length - a.token.length); + + const segments: MentionSegment[] = []; + let i = 0; + let buffer = ""; + let bufferStart = 0; + + while (i < text.length) { + const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i)); + if (tokenMatch) { + if (buffer) { + segments.push({ type: "text", value: buffer, start: bufferStart }); + buffer = ""; + } + segments.push({ type: "mention", doc: tokenMatch.doc, start: i }); + i += tokenMatch.token.length; + bufferStart = i; + continue; + } + if (!buffer) bufferStart = i; + buffer += text[i]; + i += 1; + } + + if (buffer) { + segments.push({ type: "text", value: buffer, start: bufferStart }); + } + + return segments; +} From 3a73912a86f9c7fac85cc1bb1fd228563874d385 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Thu, 30 Apr 2026 15:39:12 -0700 Subject: [PATCH 258/299] feat(desktop): enable hardened runtime and entitlements for mac signing Made-with: Cursor --- .../build/entitlements.mac.plist | 35 +++++++++++++++++++ surfsense_desktop/electron-builder.yml | 5 ++- 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 surfsense_desktop/build/entitlements.mac.plist diff --git a/surfsense_desktop/build/entitlements.mac.plist b/surfsense_desktop/build/entitlements.mac.plist new file mode 100644 index 000000000..5647e7759 --- /dev/null +++ b/surfsense_desktop/build/entitlements.mac.plist @@ -0,0 +1,35 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd"> +<plist version="1.0"> +<dict> + <!-- Required for Electron's V8 JIT under hardened runtime --> + <key>com.apple.security.cs.allow-jit</key> + <true/> + <key>com.apple.security.cs.allow-unsigned-executable-memory</key> + <true/> + + <!-- node-mac-permissions and other native deps load dylibs at runtime --> + <key>com.apple.security.cs.allow-dyld-environment-variables</key> + <true/> + <key>com.apple.security.cs.disable-library-validation</key> + <true/> + + <!-- Networking (OAuth, API calls, auto-updater, deep links) --> + <key>com.apple.security.network.client</key> + <true/> + <key>com.apple.security.network.server</key> + <true/> + + <!-- Screen Capture / Screenshot Assist --> + <key>com.apple.security.device.camera</key> + <true/> + + <!-- Accessibility / Apple Events used by general-assist --> + <key>com.apple.security.automation.apple-events</key> + <true/> + + <!-- File access for folder watcher / agent filesystem features --> + <key>com.apple.security.files.user-selected.read-write</key> + <true/> +</dict> +</plist> diff --git a/surfsense_desktop/electron-builder.yml b/surfsense_desktop/electron-builder.yml index b0014a57b..e4e7670ec 100644 --- a/surfsense_desktop/electron-builder.yml +++ b/surfsense_desktop/electron-builder.yml @@ -46,8 +46,11 @@ mac: icon: assets/icon.icns category: public.app-category.productivity artifactName: "${productName}-${version}-${arch}.${ext}" - hardenedRuntime: false + hardenedRuntime: true gatekeeperAssess: false + entitlements: build/entitlements.mac.plist + entitlementsInherit: build/entitlements.mac.plist + notarize: true extendInfo: NSAccessibilityUsageDescription: "SurfSense uses accessibility features to bring the app to the foreground and interact with the active application when you use desktop assists." NSScreenCaptureUsageDescription: "SurfSense uses screen capture so you can attach a selected region to chat (Screenshot Assist) or capture the full screen from the composer." From 0883ac88fb54653223ff2477724d372531fa1301 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 04:23:59 +0530 Subject: [PATCH 259/299] refactor(chat): enhance InlineMentionEditor with improved mention handling and text processing for better user interaction --- .../assistant-ui/inline-mention-editor.tsx | 1078 ++++++----------- 1 file changed, 391 insertions(+), 687 deletions(-) diff --git a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx index 05277f508..d92348080 100644 --- a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx +++ b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx @@ -1,26 +1,13 @@ "use client"; -import { X } from "lucide-react"; -import type { ReactElement } from "react"; -import { - createElement, - forwardRef, - useCallback, - useEffect, - useImperativeHandle, - useRef, - useState, -} from "react"; -import { renderToStaticMarkup } from "react-dom/server"; +import { type FC, forwardRef, useCallback, useImperativeHandle, useMemo, useRef } from "react"; +import { Plate, PlateContent, ParagraphPlugin, createPlatePlugin, usePlateEditor } from "platejs/react"; +import type { PlateElementProps } from "platejs/react"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { Document } from "@/contracts/types/document.types"; import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { cn } from "@/lib/utils"; -function renderElementToHTML(element: ReactElement): string { - return renderToStaticMarkup(element); -} - export interface MentionedDocument { id: number; title: string; @@ -61,38 +48,174 @@ interface InlineMentionEditorProps { initialText?: string; } -// Unique data attribute to identify chip elements -const CHIP_DATA_ATTR = "data-mention-chip"; -const CHIP_ID_ATTR = "data-mention-id"; -const CHIP_DOCTYPE_ATTR = "data-mention-doctype"; -const CHIP_STATUS_ATTR = "data-mention-status"; +type MentionStatusKind = "pending" | "processing" | "ready" | "failed"; +type ComposerTextNode = { text: string }; +type MentionElementNode = { + type: "mention"; + id: number; + title: string; + document_type?: string; + statusLabel?: string | null; + statusKind?: MentionStatusKind; + children: [{ text: "" }]; +}; +type ComposerNode = ComposerTextNode | MentionElementNode; +type ComposerParagraph = { type: "p"; children: ComposerNode[] }; +type ComposerValue = ComposerParagraph[]; + +const MENTION_TYPE = "mention"; +const MENTION_CHIP_CLASSNAME = + "inline-flex h-5 items-center gap-1 mx-0.5 rounded bg-primary/10 px-1 text-xs font-bold text-primary/60 select-none align-middle leading-none"; +const MENTION_CHIP_ICON_CLASSNAME = "flex items-center text-muted-foreground leading-none"; +const MENTION_CHIP_TITLE_CLASSNAME = "max-w-[120px] truncate leading-none"; +const COMPOSER_TEXT_METRICS_CLASSNAME = "text-sm leading-6"; + +const EMPTY_VALUE: ComposerValue = [{ type: "p", children: [{ text: "" }] }]; + +const MentionElement: FC<PlateElementProps<MentionElementNode>> = ({ attributes, children, element }) => { + const statusClass = + element.statusKind === "failed" + ? "text-destructive" + : element.statusKind === "ready" + ? "text-emerald-700" + : "text-amber-700"; -/** - * Type guard to check if a node is a chip element - */ -function isChipElement(node: Node | null): node is HTMLSpanElement { return ( - node !== null && - node.nodeType === Node.ELEMENT_NODE && - (node as Element).hasAttribute(CHIP_DATA_ATTR) + <span {...attributes} className="inline-flex align-middle"> + <span contentEditable={false} className={`${MENTION_CHIP_CLASSNAME} cursor-default`}> + <span className={MENTION_CHIP_ICON_CLASSNAME}> + {getConnectorIcon(element.document_type ?? "UNKNOWN", "h-3 w-3")} + </span> + <span className={MENTION_CHIP_TITLE_CLASSNAME} title={element.title}> + {element.title} + </span> + {element.statusLabel ? ( + <span className={cn("text-[10px] font-semibold opacity-80", statusClass)}> + {element.statusLabel} + </span> + ) : null} + </span> + {children} + </span> ); +}; + +const MentionPlugin = createPlatePlugin({ + key: MENTION_TYPE, + node: { + isElement: true, + isInline: true, + isVoid: true, + type: MENTION_TYPE, + component: MentionElement, + }, +}); + +function isMentionNode(node: ComposerNode): node is MentionElementNode { + return typeof node === "object" && "type" in node && node.type === MENTION_TYPE; } -/** - * Safely parse chip ID from element attribute - */ -function getChipId(element: Element): number | null { - const idStr = element.getAttribute(CHIP_ID_ATTR); - if (!idStr) return null; - const id = parseInt(idStr, 10); - return Number.isNaN(id) ? null : id; +function getTextNode(node: ComposerNode): ComposerTextNode | null { + if (typeof node === "object" && "text" in node && typeof node.text === "string") return node; + return null; } -/** - * Get chip document type from element attribute - */ -function getChipDocType(element: Element): string { - return element.getAttribute(CHIP_DOCTYPE_ATTR) ?? "UNKNOWN"; +function toValueFromText(text: string): ComposerValue { + const lines = text.split("\n"); + if (lines.length === 0) return EMPTY_VALUE; + return lines.map((line) => ({ type: "p", children: [{ text: line }] })) as ComposerValue; +} + +function getPlainText(value: ComposerValue): string { + const lines = value.map((block) => + block.children + .map((node) => { + if (isMentionNode(node)) return `@${node.title}`; + return getTextNode(node)?.text ?? ""; + }) + .join("") + ); + return lines.join("\n").trim(); +} + +function getMentionedDocuments(value: ComposerValue): MentionedDocument[] { + const map = new Map<string, MentionedDocument>(); + for (const block of value) { + for (const node of block.children) { + if (!isMentionNode(node)) continue; + const doc: MentionedDocument = { + id: node.id, + title: node.title, + document_type: node.document_type, + }; + map.set(getMentionDocKey(doc), doc); + } + } + return Array.from(map.values()); +} + +type EditorSelection = { + anchor: { path: number[]; offset: number }; + focus: { path: number[]; offset: number }; +} | null; + +function getCursorTextContext(value: ComposerValue, selection: EditorSelection) { + if (!selection || !selection.anchor || !selection.focus) return null; + if ( + selection.anchor.path.length < 2 || + selection.focus.path.length < 2 || + selection.anchor.path[0] !== selection.focus.path[0] || + selection.anchor.path[1] !== selection.focus.path[1] + ) { + return null; + } + + const block = value[selection.anchor.path[0]]; + if (!block) return null; + const child = block.children[selection.anchor.path[1]]; + const textNode = getTextNode(child); + if (!textNode) return null; + + return { + blockIndex: selection.anchor.path[0], + childIndex: selection.anchor.path[1], + text: textNode.text, + cursor: selection.anchor.offset, + }; +} + +function scanActiveTrigger(text: string, cursor: number) { + let wordStart = 0; + for (let i = cursor - 1; i >= 0; i--) { + if (text[i] === " " || text[i] === "\n") { + wordStart = i + 1; + break; + } + } + + let triggerChar: "@" | "/" | null = null; + let triggerIndex = -1; + for (let i = wordStart; i < cursor; i++) { + if (text[i] === "@" || text[i] === "/") { + triggerChar = text[i] as "@" | "/"; + triggerIndex = i; + break; + } + } + if (!triggerChar || triggerIndex === -1) return null; + + const query = text.slice(triggerIndex + 1, cursor); + if (query.startsWith(" ")) return null; + if ( + triggerChar === "/" && + triggerIndex > 0 && + text[triggerIndex - 1] !== " " && + text[triggerIndex - 1] !== "\n" + ) { + return null; + } + + return { triggerChar, query }; } export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMentionEditorProps>( @@ -113,393 +236,159 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent }, ref ) => { - const editorRef = useRef<HTMLDivElement>(null); - const [isEmpty, setIsEmpty] = useState(true); - const [mentionedDocs, setMentionedDocs] = useState<Map<string, MentionedDocument>>( - () => new Map() - ); - const isComposingRef = useRef(false); - const lastSelectionRangeRef = useRef<Range | null>(null); - const isRangeInsideEditor = useCallback((range: Range | null): range is Range => { - if (!range || !editorRef.current) return false; - return ( - editorRef.current.contains(range.startContainer) && - editorRef.current.contains(range.endContainer) - ); - }, []); - const isSelectionInsideEditor = useCallback( - (selection: Selection | null): selection is Selection => { - if (!selection || selection.rangeCount === 0 || !editorRef.current) return false; - const range = selection.getRangeAt(0); - return isRangeInsideEditor(range); - }, - [isRangeInsideEditor] - ); + const editableRef = useRef<HTMLDivElement | null>(null); + const editor = usePlateEditor({ + readOnly: disabled, + plugins: [ParagraphPlugin, MentionPlugin], + value: initialText ? toValueFromText(initialText) : EMPTY_VALUE, + }); - const rememberSelection = useCallback(() => { - const selection = window.getSelection(); - if (!isSelectionInsideEditor(selection)) return; - lastSelectionRangeRef.current = selection.getRangeAt(0).cloneRange(); - }, [isSelectionInsideEditor]); - - const restoreRememberedSelection = useCallback((): Selection | null => { - const selection = window.getSelection(); - if (!selection) return null; - if (!isRangeInsideEditor(lastSelectionRangeRef.current)) return null; - selection.removeAllRanges(); - selection.addRange(lastSelectionRangeRef.current.cloneRange()); - return selection; - }, [isRangeInsideEditor]); - - useEffect(() => { - const handleSelectionChange = () => { - if (document.activeElement !== editorRef.current) return; - rememberSelection(); - }; - document.addEventListener("selectionchange", handleSelectionChange); - return () => document.removeEventListener("selectionchange", handleSelectionChange); - }, [rememberSelection]); - - useEffect(() => { - if (!initialText || !editorRef.current) return; - editorRef.current.innerText = initialText; - editorRef.current.appendChild(document.createElement("br")); - editorRef.current.appendChild(document.createElement("br")); - setIsEmpty(false); - onChange?.(initialText, []); - editorRef.current.focus(); - const sel = window.getSelection(); - const range = document.createRange(); - range.selectNodeContents(editorRef.current); - range.collapse(false); - sel?.removeAllRanges(); - sel?.addRange(range); - const anchor = document.createElement("span"); - range.insertNode(anchor); - anchor.scrollIntoView({ block: "end" }); - anchor.remove(); - }, [initialText, onChange]); - - // Focus at the end of the editor const focusAtEnd = useCallback(() => { - if (!editorRef.current) return; - editorRef.current.focus(); + const el = editableRef.current; + if (!el) return; + el.focus(); const selection = window.getSelection(); const range = document.createRange(); - range.selectNodeContents(editorRef.current); + range.selectNodeContents(el); range.collapse(false); selection?.removeAllRanges(); selection?.addRange(range); }, []); - // Get plain text content with inline mention tokens for chips. - // This preserves the original query structure sent to the backend/LLM. - const getText = useCallback((): string => { - if (!editorRef.current) return ""; + const getCurrentValue = useCallback(() => (editor.children as ComposerValue) ?? EMPTY_VALUE, [editor]); - const extractText = (node: Node): string => { - if (node.nodeType === Node.TEXT_NODE) { - return node.textContent ?? ""; - } - - if (node.nodeType === Node.ELEMENT_NODE) { - const element = node as Element; - - // Preserve mention chips as inline @title tokens. - if (element.hasAttribute(CHIP_DATA_ATTR)) { - const title = element.querySelector("[data-mention-title='true']")?.textContent?.trim(); - if (title) { - return `@${title}`; - } - return ""; - } - - let result = ""; - for (const child of Array.from(element.childNodes)) { - result += extractText(child); - } - return result; - } - - return ""; - }; - - return extractText(editorRef.current).trim(); - }, []); - - // Get all mentioned documents - const getMentionedDocuments = useCallback((): MentionedDocument[] => { - return Array.from(mentionedDocs.values()); - }, [mentionedDocs]); - - const syncEditorState = useCallback( - (docsOverride?: Map<string, MentionedDocument>) => { - const docs = docsOverride - ? Array.from(docsOverride.values()) - : Array.from(mentionedDocs.values()); - const text = getText(); - const empty = text.length === 0 && docs.length === 0; - setIsEmpty(empty); + const emitState = useCallback( + (nextValue: ComposerValue) => { + const text = getPlainText(nextValue); + const docs = getMentionedDocuments(nextValue); onChange?.(text, docs); - }, - [getText, mentionedDocs, onChange] - ); - // Create a chip element for a document - const createChipElement = useCallback( - (doc: MentionedDocument): HTMLSpanElement => { - const chip = document.createElement("span"); - chip.setAttribute(CHIP_DATA_ATTR, "true"); - chip.setAttribute(CHIP_ID_ATTR, String(doc.id)); - chip.setAttribute(CHIP_DOCTYPE_ATTR, doc.document_type ?? "UNKNOWN"); - chip.contentEditable = "false"; - chip.className = - "inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none cursor-default"; - chip.style.userSelect = "none"; - chip.style.verticalAlign = "baseline"; - - // Container that swaps between icon and remove button on hover - const iconContainer = document.createElement("span"); - iconContainer.className = "shrink-0 flex items-center size-3 relative"; - - const iconSpan = document.createElement("span"); - iconSpan.className = "flex items-center text-muted-foreground"; - iconSpan.innerHTML = renderElementToHTML( - getConnectorIcon(doc.document_type ?? "UNKNOWN", "h-3 w-3") - ); - - const removeBtn = document.createElement("button"); - removeBtn.type = "button"; - removeBtn.className = - "size-3 items-center justify-center rounded-full text-muted-foreground transition-colors"; - removeBtn.style.display = "none"; - removeBtn.innerHTML = renderElementToHTML( - createElement(X, { className: "h-3 w-3", strokeWidth: 2.5 }) - ); - removeBtn.onclick = (e) => { - e.preventDefault(); - e.stopPropagation(); - chip.remove(); - const docKey = getMentionDocKey(doc); - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(docKey); - syncEditorState(next); - return next; - }); - onDocumentRemove?.(doc.id, doc.document_type); - focusAtEnd(); - }; - - const titleSpan = document.createElement("span"); - titleSpan.className = "max-w-[120px] truncate"; - titleSpan.textContent = doc.title; - titleSpan.title = doc.title; - titleSpan.setAttribute("data-mention-title", "true"); - - const statusSpan = document.createElement("span"); - statusSpan.setAttribute(CHIP_STATUS_ATTR, "true"); - statusSpan.className = "text-[10px] font-semibold opacity-80 hidden"; - - const isTouchDevice = window.matchMedia("(hover: none)").matches; - if (isTouchDevice) { - // Mobile: icon on left, title, X on right - chip.appendChild(iconSpan); - chip.appendChild(titleSpan); - chip.appendChild(statusSpan); - removeBtn.style.display = "flex"; - removeBtn.className += " ml-0.5"; - chip.appendChild(removeBtn); - } else { - // Desktop: icon/X swap on hover in the same slot - iconContainer.appendChild(iconSpan); - iconContainer.appendChild(removeBtn); - chip.addEventListener("mouseenter", () => { - iconSpan.style.display = "none"; - removeBtn.style.display = "flex"; - }); - chip.addEventListener("mouseleave", () => { - iconSpan.style.display = ""; - removeBtn.style.display = "none"; - }); - chip.appendChild(iconContainer); - chip.appendChild(titleSpan); - chip.appendChild(statusSpan); + const cursorCtx = getCursorTextContext(nextValue, editor.selection); + if (!cursorCtx) { + onMentionClose?.(); + onActionClose?.(); + return; } - return chip; + const trigger = scanActiveTrigger(cursorCtx.text, cursorCtx.cursor); + if (!trigger) { + onMentionClose?.(); + onActionClose?.(); + return; + } + + if (trigger.triggerChar === "@") { + onMentionTrigger?.(trigger.query); + onActionClose?.(); + return; + } + + onActionTrigger?.(trigger.query); + onMentionClose?.(); }, - [focusAtEnd, onDocumentRemove, syncEditorState] + [editor.selection, onActionClose, onActionTrigger, onChange, onMentionClose, onMentionTrigger] + ); + + const setValue = useCallback( + (nextValue: ComposerValue) => { + const tf = editor.tf as { setValue: (value: ComposerValue) => void }; + tf.setValue(nextValue); + emitState(nextValue); + }, + [editor, emitState] ); - // Insert a document chip at the current cursor position const insertDocumentChip = useCallback( ( doc: Pick<Document, "id" | "title" | "document_type">, options?: { removeTriggerText?: boolean } ) => { - if (!editorRef.current) return; + if (typeof doc.id !== "number" || typeof doc.title !== "string") return; + const removeTriggerText = options?.removeTriggerText ?? true; - - // Validate required fields for type safety - if (typeof doc.id !== "number" || typeof doc.title !== "string") { - console.warn("[InlineMentionEditor] Invalid document passed to insertDocumentChip:", doc); - return; - } - - const mentionDoc: MentionedDocument = { + const current = getCurrentValue(); + const selection = editor.selection; + const mentionNode: MentionElementNode = { + type: MENTION_TYPE, id: doc.id, title: doc.title, document_type: doc.document_type, + children: [{ text: "" }], }; - // Add to mentioned docs map using unique key - const docKey = getMentionDocKey(doc); - setMentionedDocs((prev) => new Map(prev).set(docKey, mentionDoc)); - const nextDocs = new Map(mentionedDocs); - nextDocs.set(docKey, mentionDoc); - - // Find and remove the @query text - const selection = window.getSelection(); - const hasActiveSelection = isSelectionInsideEditor(selection); - const resolvedSelection = hasActiveSelection ? selection : restoreRememberedSelection(); - if ( - !resolvedSelection || - resolvedSelection.rangeCount === 0 || - !isSelectionInsideEditor(resolvedSelection) - ) { - // No valid in-editor selection: deterministically insert at end. - editorRef.current.focus(); - const endSelection = window.getSelection(); - if (!endSelection) return; - const endRange = document.createRange(); - endRange.selectNodeContents(editorRef.current); - endRange.collapse(false); - endSelection.removeAllRanges(); - endSelection.addRange(endRange); - - const chip = createChipElement(mentionDoc); - endRange.insertNode(chip); - endRange.setStartAfter(chip); - endRange.collapse(true); - const space = document.createTextNode(" "); - endRange.insertNode(space); - endRange.setStartAfter(space); - endRange.collapse(true); - endSelection.removeAllRanges(); - endSelection.addRange(endRange); - - syncEditorState(nextDocs); - rememberSelection(); + const cursorCtx = getCursorTextContext(current, selection); + if (!cursorCtx) { + const lastBlock = current[current.length - 1] ?? { type: "p", children: [{ text: "" }] }; + const appended: ComposerValue = [ + ...current.slice(0, -1), + { + ...lastBlock, + children: [...lastBlock.children, mentionNode, { text: " " }], + }, + ]; + setValue(appended); + requestAnimationFrame(focusAtEnd); return; } - // Find the @ symbol before the cursor and remove it along with any query text - const range = resolvedSelection.getRangeAt(0); - const textNode = range.startContainer; - - if (textNode.nodeType === Node.TEXT_NODE && removeTriggerText) { - const text = textNode.textContent || ""; - const cursorPos = range.startOffset; - - // Find the @ symbol before cursor - let atIndex = -1; - for (let i = cursorPos - 1; i >= 0; i--) { - if (text[i] === "@") { - atIndex = i; - break; - } - } - - if (atIndex !== -1) { - // Remove @query and insert chip - const beforeAt = text.slice(0, atIndex); - const afterCursor = text.slice(cursorPos); - - // Create chip - const chip = createChipElement(mentionDoc); - - // Replace text node content - const parent = textNode.parentNode; - if (parent) { - const beforeNode = document.createTextNode(beforeAt); - const afterNode = document.createTextNode(` ${afterCursor}`); - - parent.insertBefore(beforeNode, textNode); - parent.insertBefore(chip, textNode); - parent.insertBefore(afterNode, textNode); - parent.removeChild(textNode); - - // Set cursor after the chip - const newRange = document.createRange(); - newRange.setStart(afterNode, 1); - newRange.collapse(true); - resolvedSelection.removeAllRanges(); - resolvedSelection.addRange(newRange); - rememberSelection(); - } - } else { - // No @ found, just insert at cursor - const chip = createChipElement(mentionDoc); - range.insertNode(chip); - range.setStartAfter(chip); - range.collapse(true); - - // Add space after chip - const space = document.createTextNode(" "); - range.insertNode(space); - range.setStartAfter(space); - range.collapse(true); - resolvedSelection.removeAllRanges(); - resolvedSelection.addRange(range); - rememberSelection(); - } - } else { - // Either explicit non-trigger insertion or no @query present. - const chip = createChipElement(mentionDoc); - range.insertNode(chip); - range.setStartAfter(chip); - range.collapse(true); - const space = document.createTextNode(" "); - range.insertNode(space); - range.setStartAfter(space); - range.collapse(true); - resolvedSelection.removeAllRanges(); - resolvedSelection.addRange(range); - rememberSelection(); + const block = current[cursorCtx.blockIndex]; + const currentChild = getTextNode(block.children[cursorCtx.childIndex]); + if (!currentChild) { + const children = [...block.children]; + children.splice(cursorCtx.childIndex + 1, 0, mentionNode, { text: " " }); + const next = [...current]; + next[cursorCtx.blockIndex] = { ...block, children }; + setValue(next as ComposerValue); + requestAnimationFrame(focusAtEnd); + return; } - syncEditorState(nextDocs); + const text = currentChild.text; + let removeStart = cursorCtx.cursor; + if (removeTriggerText) { + for (let i = cursorCtx.cursor - 1; i >= 0; i--) { + if (text[i] === "@") { + removeStart = i; + break; + } + if (text[i] === " " || text[i] === "\n") break; + } + } + + const before = text.slice(0, removeStart); + const after = text.slice(cursorCtx.cursor); + const replacement: ComposerNode[] = []; + if (before.length > 0) replacement.push({ text: before }); + replacement.push(mentionNode); + replacement.push({ text: ` ${after}` }); + + const children = [...block.children]; + children.splice(cursorCtx.childIndex, 1, ...replacement); + const next = [...current]; + next[cursorCtx.blockIndex] = { ...block, children }; + setValue(next as ComposerValue); + requestAnimationFrame(focusAtEnd); }, - [ - createChipElement, - isSelectionInsideEditor, - mentionedDocs, - rememberSelection, - restoreRememberedSelection, - syncEditorState, - ] + [editor.selection, focusAtEnd, getCurrentValue, setValue] ); - // Clear the editor - const clear = useCallback(() => { - if (editorRef.current) { - editorRef.current.innerHTML = ""; - const emptyDocs = new Map<string, MentionedDocument>(); - setMentionedDocs(emptyDocs); - syncEditorState(emptyDocs); - } - }, [syncEditorState]); - - // Replace editor content with plain text and place cursor at end - const setText = useCallback( - (text: string) => { - if (!editorRef.current) return; - editorRef.current.innerText = text; - syncEditorState(); - focusAtEnd(); + const removeDocumentChip = useCallback( + (docId: number, docType?: string) => { + const current = getCurrentValue(); + let changed = false; + const next = current.map((block) => { + const children = block.children.filter((node) => { + if (!isMentionNode(node)) return true; + const match = node.id === docId && (node.document_type ?? "UNKNOWN") === (docType ?? "UNKNOWN"); + if (match) changed = true; + return !match; + }); + return { ...block, children: children.length ? children : [{ text: "" }] }; + }); + if (!changed) return; + setValue(next as ComposerValue); }, - [focusAtEnd, syncEditorState] + [getCurrentValue, setValue] ); const setDocumentChipStatus = useCallback( @@ -507,327 +396,142 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent docId: number, docType: string | undefined, statusLabel: string | null, - statusKind: "pending" | "processing" | "ready" | "failed" = "pending" + statusKind: MentionStatusKind = "pending" ) => { - if (!editorRef.current) return; - - const chips = editorRef.current.querySelectorAll<HTMLSpanElement>( - `span[${CHIP_DATA_ATTR}="true"]` - ); - for (const chip of chips) { - const chipId = getChipId(chip); - const chipType = getChipDocType(chip); - if (chipId !== docId) continue; - if ((docType ?? "UNKNOWN") !== chipType) continue; - - const statusEl = chip.querySelector<HTMLSpanElement>(`span[${CHIP_STATUS_ATTR}="true"]`); - if (!statusEl) continue; - - if (!statusLabel) { - statusEl.textContent = ""; - statusEl.className = "text-[10px] font-semibold opacity-80 hidden"; - continue; - } - - const statusClass = - statusKind === "failed" - ? "text-destructive" - : statusKind === "processing" - ? "text-amber-700" - : statusKind === "ready" - ? "text-emerald-700" - : "text-amber-700"; - statusEl.textContent = statusLabel; - statusEl.className = `text-[10px] font-semibold opacity-80 ${statusClass}`; - } + const current = getCurrentValue(); + let changed = false; + const next = current.map((block) => ({ + ...block, + children: block.children.map((node) => { + if (!isMentionNode(node)) return node; + const sameType = (node.document_type ?? "UNKNOWN") === (docType ?? "UNKNOWN"); + if (node.id !== docId || !sameType) return node; + changed = true; + return { + ...node, + statusLabel, + statusKind: statusLabel ? statusKind : undefined, + }; + }), + })); + if (!changed) return; + setValue(next as ComposerValue); }, - [] + [getCurrentValue, setValue] ); - const removeDocumentChip = useCallback( - (docId: number, docType?: string) => { - if (!editorRef.current) return; - const chipKey = getMentionDocKey({ id: docId, document_type: docType }); - const chips = editorRef.current.querySelectorAll<HTMLSpanElement>( - `span[${CHIP_DATA_ATTR}="true"]` - ); - for (const chip of chips) { - if (getChipId(chip) === docId && getChipDocType(chip) === (docType ?? "UNKNOWN")) { - chip.remove(); - break; - } - } - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(chipKey); - syncEditorState(next); - return next; - }); + const clear = useCallback(() => { + setValue(EMPTY_VALUE); + }, [setValue]); + + const setText = useCallback( + (text: string) => { + setValue(toValueFromText(text)); + requestAnimationFrame(focusAtEnd); }, - [syncEditorState] + [focusAtEnd, setValue] ); - // Expose methods via ref - useImperativeHandle(ref, () => ({ - focus: () => editorRef.current?.focus(), - clear, - setText, - getText, - getMentionedDocuments, - insertDocumentChip, - removeDocumentChip, - setDocumentChipStatus, - })); + const getText = useCallback(() => getPlainText(getCurrentValue()), [getCurrentValue]); + const getMentionedDocs = useCallback( + () => getMentionedDocuments(getCurrentValue()), + [getCurrentValue] + ); - // Handle input changes - const handleInput = useCallback(() => { - if (!editorRef.current) return; + useImperativeHandle( + ref, + () => ({ + focus: () => editableRef.current?.focus(), + clear, + setText, + getText, + getMentionedDocuments: getMentionedDocs, + insertDocumentChip, + removeDocumentChip, + setDocumentChipStatus, + }), + [clear, getMentionedDocs, getText, insertDocumentChip, removeDocumentChip, setDocumentChipStatus, setText] + ); - const text = getText(); - const empty = text.length === 0 && mentionedDocs.size === 0; - setIsEmpty(empty); - - // Unified trigger scan: find the leftmost @ or / in the current word. - // Whichever trigger was typed first owns the token — the other character - // is treated as part of the query, not as a separate trigger. - const selection = window.getSelection(); - let shouldTriggerMention = false; - let mentionQuery = ""; - let shouldTriggerAction = false; - let actionQuery = ""; - - if (selection && selection.rangeCount > 0) { - const range = selection.getRangeAt(0); - const textNode = range.startContainer; - - if (textNode.nodeType === Node.TEXT_NODE) { - const textContent = textNode.textContent || ""; - const cursorPos = range.startOffset; - - let wordStart = 0; - for (let i = cursorPos - 1; i >= 0; i--) { - if (textContent[i] === " " || textContent[i] === "\n") { - wordStart = i + 1; - break; - } - } - - let triggerChar: "@" | "/" | null = null; - let triggerIndex = -1; - for (let i = wordStart; i < cursorPos; i++) { - if (textContent[i] === "@" || textContent[i] === "/") { - triggerChar = textContent[i] as "@" | "/"; - triggerIndex = i; - break; - } - } - - if (triggerChar === "@" && triggerIndex !== -1) { - const query = textContent.slice(triggerIndex + 1, cursorPos); - if (!query.startsWith(" ")) { - shouldTriggerMention = true; - mentionQuery = query; - } - } else if (triggerChar === "/" && triggerIndex !== -1) { - if ( - triggerIndex === 0 || - textContent[triggerIndex - 1] === " " || - textContent[triggerIndex - 1] === "\n" - ) { - const query = textContent.slice(triggerIndex + 1, cursorPos); - if (!query.startsWith(" ")) { - shouldTriggerAction = true; - actionQuery = query; - } - } - } - } - } - - // If no @ found before cursor, check if text contains @ at all - // If text is empty or doesn't contain @, close the mention - if (!shouldTriggerMention) { - if (text.length === 0 || !text.includes("@")) { - onMentionClose?.(); - } else { - // Text contains @ but not before cursor, close mention - onMentionClose?.(); - } - } else { - onMentionTrigger?.(mentionQuery); - } - - if (!shouldTriggerAction) { - onActionClose?.(); - } else { - onActionTrigger?.(actionQuery); - } - - // Notify parent of change - onChange?.(text, Array.from(mentionedDocs.values())); - rememberSelection(); - }, [ - getText, - mentionedDocs, - onChange, - onMentionTrigger, - onMentionClose, - onActionTrigger, - onActionClose, - rememberSelection, - ]); - - // Handle keydown const handleKeyDown = useCallback( (e: React.KeyboardEvent<HTMLDivElement>) => { - // Let parent handle navigation keys when mention popover is open - if (onKeyDown) { - onKeyDown(e); - if (e.defaultPrevented) return; - } + onKeyDown?.(e); + if (e.defaultPrevented) return; - // Handle Enter for submit (without shift) if (e.key === "Enter" && !e.shiftKey) { e.preventDefault(); onSubmit?.(); return; } - // Handle backspace on chips - if (e.key === "Backspace") { - const selection = window.getSelection(); - if (selection && selection.rangeCount > 0) { - const range = selection.getRangeAt(0); - if (range.collapsed) { - // Check if cursor is right after a chip - const node = range.startContainer; - const offset = range.startOffset; - - if (node.nodeType === Node.TEXT_NODE && offset === 0) { - // Check previous sibling using type guard - const prevSibling = node.previousSibling; - if (isChipElement(prevSibling)) { - e.preventDefault(); - const chipId = getChipId(prevSibling); - const chipDocType = getChipDocType(prevSibling); - if (chipId !== null) { - prevSibling.remove(); - const chipKey = getMentionDocKey({ - id: chipId, - document_type: chipDocType, - }); - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(chipKey); - syncEditorState(next); - return next; - }); - // Notify parent that a document was removed - onDocumentRemove?.(chipId, chipDocType); - } - return; - } - // Check if we're about to delete @ at the start - const textContent = node.textContent || ""; - if (textContent.length > 0 && textContent[0] === "@") { - // Will delete @, close mention popover - setTimeout(() => { - onMentionClose?.(); - }, 0); - } - } else if (node.nodeType === Node.TEXT_NODE && offset > 0) { - // Check if we're about to delete @ - const textContent = node.textContent || ""; - if (textContent[offset - 1] === "@") { - // Will delete @, close mention popover - setTimeout(() => { - onMentionClose?.(); - }, 0); - } - } else if (node.nodeType === Node.ELEMENT_NODE && offset > 0) { - // Check if previous child is a chip using type guard - const prevChild = (node as Element).childNodes[offset - 1]; - if (isChipElement(prevChild)) { - e.preventDefault(); - const chipId = getChipId(prevChild); - const chipDocType = getChipDocType(prevChild); - if (chipId !== null) { - prevChild.remove(); - const chipKey = getMentionDocKey({ - id: chipId, - document_type: chipDocType, - }); - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(chipKey); - syncEditorState(next); - return next; - }); - // Notify parent that a document was removed - onDocumentRemove?.(chipId, chipDocType); - } - } - } - } - } + if (e.key !== "Backspace") return; + const selection = editor.selection; + if (!selection || !selection.anchor || !selection.focus) return; + if ( + selection.anchor.path.length < 2 || + selection.focus.path.length < 2 || + selection.anchor.path[0] !== selection.focus.path[0] + ) { + return; } + if (selection.anchor.offset !== 0 || selection.focus.offset !== 0) return; + + const value = getCurrentValue(); + const block = value[selection.anchor.path[0]]; + if (!block) return; + const childIndex = selection.anchor.path[1]; + if (childIndex <= 0) return; + const prev = block.children[childIndex - 1]; + if (!isMentionNode(prev)) return; + + e.preventDefault(); + removeDocumentChip(prev.id, prev.document_type); + onDocumentRemove?.(prev.id, prev.document_type); }, - [onKeyDown, onSubmit, onDocumentRemove, onMentionClose, syncEditorState] + [ + editor.selection, + getCurrentValue, + onDocumentRemove, + onKeyDown, + onSubmit, + removeDocumentChip, + ] ); - // Handle paste - strip formatting - const handlePaste = useCallback((e: React.ClipboardEvent) => { - e.preventDefault(); - const text = e.clipboardData.getData("text/plain"); - document.execCommand("insertText", false, text); - }, []); - - // Handle composition (for IME input) - const handleCompositionStart = useCallback(() => { - isComposingRef.current = true; - }, []); - - const handleCompositionEnd = useCallback(() => { - isComposingRef.current = false; - handleInput(); - }, [handleInput]); + const editableProps = useMemo( + () => ({ + placeholder, + onPaste: (e: React.ClipboardEvent<HTMLDivElement>) => { + e.preventDefault(); + const text = e.clipboardData.getData("text/plain"); + const tf = editor.tf as { insertText: (value: string) => void }; + tf.insertText(text); + }, + onKeyDown: handleKeyDown, + }), + [editor, handleKeyDown, placeholder] + ); return ( <div className="relative w-full"> - {/* biome-ignore lint/a11y/noStaticElementInteractions: contenteditable mention editor requires a div for inline chips */} - <div - ref={editorRef} - contentEditable={!disabled} - suppressContentEditableWarning - tabIndex={disabled ? -1 : 0} - onInput={handleInput} - onKeyDown={handleKeyDown} - onPaste={handlePaste} - onCompositionStart={handleCompositionStart} - onCompositionEnd={handleCompositionEnd} - onKeyUp={rememberSelection} - onMouseUp={rememberSelection} - onBlur={rememberSelection} - className={cn( - "min-h-[24px] max-h-32 overflow-y-auto", - "text-sm outline-none", - "whitespace-pre-wrap wrap-break-word", - disabled && "opacity-50 cursor-not-allowed", - className - )} - style={{ wordBreak: "break-word" }} - data-placeholder={placeholder} - /> - {/* Placeholder with fade animation on change */} - {isEmpty && ( - <div - key={placeholder} - className="absolute top-0 left-0 pointer-events-none text-muted-foreground text-sm animate-in fade-in duration-1000" - aria-hidden="true" - > - {placeholder} - </div> - )} + <Plate + editor={editor} + onChange={({ value }) => { + emitState(value as ComposerValue); + }} + > + <PlateContent + ref={editableRef} + readOnly={disabled} + {...editableProps} + className={cn( + "min-h-[24px] max-h-32 overflow-y-auto outline-none whitespace-pre-wrap wrap-break-word", + COMPOSER_TEXT_METRICS_CLASSNAME, + disabled && "opacity-50 cursor-not-allowed", + className + )} + /> + </Plate> </div> ); } From 04da62a5541d446ccb2111dc4caed69f188806cc Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 04:28:24 +0530 Subject: [PATCH 260/299] refactor(chat): improve AssistantMessage component with fixed comment trigger slot and enhanced visibility handling --- .../assistant-ui/assistant-message.tsx | 70 +++++++++++-------- 1 file changed, 39 insertions(+), 31 deletions(-) diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index bfe0434b4..711bb2fe2 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -548,8 +548,10 @@ const AssistantMessageInner: FC = () => { </div> )} - <div className="aui-assistant-message-footer mt-3 mb-5 ml-2 flex items-center gap-2"> - <AssistantActionBar /> + <div className="aui-assistant-message-footer mt-3 mb-5 ml-2 h-6"> + <div className="h-full opacity-100 transition-opacity"> + <AssistantActionBar /> + </div> </div> </CitationMetadataProvider> ); @@ -642,35 +644,41 @@ export const AssistantMessage: FC = () => { className="aui-assistant-message-root group fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150" data-role="assistant" > - {/* Comment trigger — right-aligned, just below user query on all screen sizes */} - {showCommentTrigger && ( - <div className="mr-2 mb-1 flex justify-end"> - <button - ref={isDesktop ? commentTriggerRef : undefined} - type="button" - onClick={ - isDesktop ? () => setIsInlineOpen((prev) => !prev) : () => setIsSheetOpen(true) - } - className={cn( - "flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors", - isDesktop && isInlineOpen - ? "bg-primary/10 text-primary" - : hasComments - ? "text-primary hover:bg-primary/10" - : "text-muted-foreground hover:text-foreground hover:bg-muted" - )} - > - <MessageCircleReply className={cn("size-3.5", hasComments && "fill-current")} /> - {hasComments ? ( - <span> - {commentCount} {commentCount === 1 ? "comment" : "comments"} - </span> - ) : ( - <span>Add comment</span> - )} - </button> - </div> - )} + {/* Fixed trigger slot prevents any vertical reflow when visibility changes */} + <div className="mr-2 mb-1 flex h-7 justify-end"> + <button + ref={isDesktop ? commentTriggerRef : undefined} + type="button" + onClick={ + showCommentTrigger + ? isDesktop + ? () => setIsInlineOpen((prev) => !prev) + : () => setIsSheetOpen(true) + : undefined + } + aria-hidden={!showCommentTrigger} + tabIndex={showCommentTrigger ? 0 : -1} + className={cn( + "flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors", + "opacity-0 pointer-events-none", + showCommentTrigger && "opacity-100 pointer-events-auto", + isDesktop && isInlineOpen + ? "bg-primary/10 text-primary" + : hasComments + ? "text-primary hover:bg-primary/10" + : "text-muted-foreground hover:text-foreground hover:bg-muted" + )} + > + <MessageCircleReply className={cn("size-3.5", hasComments && "fill-current")} /> + {hasComments ? ( + <span> + {commentCount} {commentCount === 1 ? "comment" : "comments"} + </span> + ) : ( + <span>Add comment</span> + )} + </button> + </div> {/* Desktop floating comment panel — overlays on top of chat content */} {showCommentTrigger && isDesktop && isInlineOpen && dbMessageId && ( From 5826e5264d68595fcf7b0e67c03739109ae05e50 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 04:39:33 +0530 Subject: [PATCH 261/299] refactor(chat): add TruncatedNameWithTooltip component in model selector --- .../components/new-chat/model-selector.tsx | 93 ++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 9fe9dd8da..1a0f8c5ba 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -236,6 +236,93 @@ interface DisplayItem { isAutoMode: boolean; } +const TruncatedNameWithTooltip: React.FC<{ + text: string; + className?: string; + enableTooltip: boolean; +}> = ({ text, className, enableTooltip }) => { + const textRef = useRef<HTMLSpanElement>(null); + const openTimerRef = useRef<number | undefined>(undefined); + const [isTruncated, setIsTruncated] = useState(false); + const [open, setOpen] = useState(false); + + const recalcTruncation = useCallback(() => { + const el = textRef.current; + if (!el) return; + setIsTruncated(el.scrollWidth > el.clientWidth + 1); + }, []); + + useEffect(() => { + if (!enableTooltip) return; + const el = textRef.current; + if (!el) return; + + const raf = requestAnimationFrame(recalcTruncation); + recalcTruncation(); + + const observer = new ResizeObserver(recalcTruncation); + observer.observe(el); + if (el.parentElement) observer.observe(el.parentElement); + window.addEventListener("resize", recalcTruncation); + + return () => { + cancelAnimationFrame(raf); + observer.disconnect(); + window.removeEventListener("resize", recalcTruncation); + }; + }, [enableTooltip, recalcTruncation]); + + useEffect(() => { + // Recompute when row text changes. + void text; + requestAnimationFrame(recalcTruncation); + }, [text, recalcTruncation]); + + useEffect( + () => () => { + if (openTimerRef.current) window.clearTimeout(openTimerRef.current); + }, + [] + ); + + if (!enableTooltip) { + return ( + <span ref={textRef} className={cn("block max-w-full", className)}> + {text} + </span> + ); + } + + const handleOpenChange = (nextOpen: boolean) => { + if (openTimerRef.current) { + window.clearTimeout(openTimerRef.current); + openTimerRef.current = undefined; + } + if (!nextOpen) { + setOpen(false); + return; + } + if (!isTruncated) return; + openTimerRef.current = window.setTimeout(() => { + setOpen(true); + openTimerRef.current = undefined; + }, 220); + }; + + return ( + <Tooltip open={open} onOpenChange={handleOpenChange}> + <TooltipTrigger asChild> + <span ref={textRef} className={cn("block max-w-full", className)}> + {text} + </span> + </TooltipTrigger> + <TooltipContent side="top" align="start"> + {text} + </TooltipContent> + </Tooltip> + ); +}; + // ─── Component ────────────────────────────────────────────────────── interface ModelSelectorProps { @@ -936,7 +1023,11 @@ export function ModelSelector({ {/* Model info */} <div className="flex-1 min-w-0"> <div className="flex items-center gap-1.5"> - <span className="font-medium text-sm truncate">{config.name}</span> + <TruncatedNameWithTooltip + text={config.name} + enableTooltip={!isMobile} + className="font-medium text-sm truncate" + /> {isAutoMode && ( <Badge variant="secondary" From 7aeb8bb0a88c84afb6f23cc438d75be266701031 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Thu, 30 Apr 2026 18:40:55 -0700 Subject: [PATCH 262/299] feat(markdown): enable citation rendering in MarkdownViewer and related components - Added `enableCitations` prop to `MarkdownViewer` to support interactive citation badges. - Updated instances of `MarkdownViewer` across various components to utilize the new citation feature. - Enhanced citation processing in `PlateEditor` for read-only views, ensuring citations are rendered correctly without affecting markdown serialization. - Refactored citation handling in `InlineCitation` and `MarkdownText` to improve citation context management. --- .../assistant-ui/inline-citation.tsx | 18 +- .../components/assistant-ui/markdown-text.tsx | 428 ++++++++---------- .../citation-panel/citation-panel.tsx | 2 +- .../citations/citation-renderer.tsx | 79 ++++ surfsense_web/components/document-viewer.tsx | 2 +- .../components/editor-panel/editor-panel.tsx | 9 +- .../components/editor/plate-editor.tsx | 56 ++- .../editor/plugins/citation-kit.tsx | 222 +++++++++ .../components/editor/utils/escape-mdx.ts | 2 +- .../layout/ui/tabs/DocumentTabContent.tsx | 4 +- surfsense_web/components/markdown-viewer.tsx | 100 +++- .../components/report-panel/report-panel.tsx | 5 +- .../lib/citations/citation-parser.ts | 134 ++++++ surfsense_web/lib/markdown/code-regions.ts | 8 + 14 files changed, 809 insertions(+), 260 deletions(-) create mode 100644 surfsense_web/components/citations/citation-renderer.tsx create mode 100644 surfsense_web/components/editor/plugins/citation-kit.tsx create mode 100644 surfsense_web/lib/citations/citation-parser.ts create mode 100644 surfsense_web/lib/markdown/code-regions.ts diff --git a/surfsense_web/components/assistant-ui/inline-citation.tsx b/surfsense_web/components/assistant-ui/inline-citation.tsx index 2aeba89ca..e299f2373 100644 --- a/surfsense_web/components/assistant-ui/inline-citation.tsx +++ b/surfsense_web/components/assistant-ui/inline-citation.tsx @@ -3,11 +3,11 @@ import { useQuery } from "@tanstack/react-query"; import { useSetAtom } from "jotai"; import { ExternalLink, FileText } from "lucide-react"; +import dynamic from "next/dynamic"; import type { FC } from "react"; import { useCallback, useEffect, useRef, useState } from "react"; import { openCitationPanelAtom } from "@/atoms/citation/citation-panel.atom"; import { useCitationMetadata } from "@/components/assistant-ui/citation-metadata-context"; -import { MarkdownViewer } from "@/components/markdown-viewer"; import { Citation } from "@/components/tool-ui/citation"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Spinner } from "@/components/ui/spinner"; @@ -15,6 +15,16 @@ import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip import { documentsApiService } from "@/lib/apis/documents-api.service"; import { cacheKeys } from "@/lib/query-client/cache-keys"; +// Lazily load MarkdownViewer here to break the static import cycle: +// `markdown-viewer.tsx` → `citation-renderer.tsx` → `inline-citation.tsx` +// would otherwise pull `markdown-viewer.tsx` back in at module-init time. +// Only `SurfsenseDocCitation` (popover body) ever renders this viewer, so +// the lazy boundary is invisible to most call paths. +const MarkdownViewer = dynamic( + () => import("@/components/markdown-viewer").then((m) => m.MarkdownViewer), + { ssr: false, loading: () => <Spinner size="xs" /> } +); + interface InlineCitationProps { chunkId: number; isDocsChunk?: boolean; @@ -172,7 +182,11 @@ const SurfsenseDocCitation: FC<{ chunkId: number }> = ({ chunkId }) => { </p> )} {!isLoading && !error && citedChunk?.content && ( - <MarkdownViewer content={citedChunk.content} maxLength={1500} /> + <MarkdownViewer + content={citedChunk.content} + maxLength={1500} + enableCitations + /> )} {!isLoading && !error && !citedChunk?.content && ( <p className="py-4 text-xs text-muted-foreground">No content available.</p> diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index 7655e10cc..2b788e88b 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -12,15 +12,26 @@ import { ExternalLinkIcon } from "lucide-react"; import dynamic from "next/dynamic"; import { useParams } from "next/navigation"; import { useTheme } from "next-themes"; -import { memo, type ReactNode } from "react"; +import { + createContext, + memo, + type ReactNode, + useCallback, + useContext, + useRef, +} from "react"; import rehypeKatex from "rehype-katex"; import remarkGfm from "remark-gfm"; import remarkMath from "remark-math"; import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/image"; import "katex/dist/katex.min.css"; -import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation"; +import { processChildrenWithCitations } from "@/components/citations/citation-renderer"; import { Skeleton } from "@/components/ui/skeleton"; +import { + type CitationUrlMap, + preprocessCitationMarkdown, +} from "@/lib/citations/citation-parser"; import { Table, TableBody, @@ -59,31 +70,30 @@ const LazyMarkdownCodeBlock = dynamic( } ); -// Storage for URL citations replaced during preprocess to avoid GFM autolink interference. -// Populated in preprocessMarkdown, consumed in parseTextWithCitations. -let _pendingUrlCitations = new Map<string, string>(); -let _urlCiteIdx = 0; +// Per-render URL placeholder map propagated to component overrides via +// React Context. Replaces the previous module-level `_pendingUrlCitations` +// state, which was unsafe under concurrent renders / SSR. +type CitationUrlMapRef = { current: CitationUrlMap }; +const EMPTY_URL_MAP: CitationUrlMap = new Map(); +const CitationUrlMapContext = createContext<CitationUrlMapRef>({ current: EMPTY_URL_MAP }); + +function useCitationUrlMap(): CitationUrlMap { + return useContext(CitationUrlMapContext).current; +} /** * Preprocess raw markdown before it reaches the remark/rehype pipeline. * - Replaces URL-based citations with safe placeholders (prevents GFM autolinks) * - Normalises LaTeX delimiters to dollar-sign syntax for remark-math */ -function preprocessMarkdown(content: string): string { +function preprocessMarkdown(content: string, urlMapRef: CitationUrlMapRef): string { // Replace URL-based citations with safe placeholders BEFORE markdown parsing. // GFM autolinks would otherwise convert the https://... inside [citation:URL] // into an <a> element, splitting the text and preventing our citation regex // from matching the full pattern. - _pendingUrlCitations = new Map(); - _urlCiteIdx = 0; - content = content.replace( - /[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+)\s*\u200B?[\]】]/g, - (_, url) => { - const key = `urlcite${_urlCiteIdx++}`; - _pendingUrlCitations.set(key, url.trim()); - return `[citation:${key}]`; - } - ); + const { content: rewritten, urlMap } = preprocessCitationMarkdown(content); + urlMapRef.current = urlMap; + content = rewritten; // All math forms are normalised to $$...$$ so we can disable single-dollar // inline math in remark-math (otherwise currency like "$3,120.00 and $0.00" @@ -116,113 +126,28 @@ function preprocessMarkdown(content: string): string { return content; } -// Matches [citation:...] with numeric IDs (incl. negative, doc- prefix, comma-separated), -// URL-based IDs from live web search, or urlciteN placeholders from preprocess. -// Also matches Chinese brackets 【】 and handles zero-width spaces that LLM sometimes inserts. -const CITATION_REGEX = - /[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+|urlcite\d+|(?:doc-)?-?\d+(?:\s*,\s*(?:doc-)?-?\d+)*)\s*\u200B?[\]】]/g; - -/** - * Parses text and replaces [citation:XXX] patterns with citation components. - * Supports: - * - Numeric chunk IDs: [citation:123] - * - Doc-prefixed IDs: [citation:doc-123] - * - Comma-separated IDs: [citation:4149, 4150, 4151] - * - URL-based citations from live search: [citation:https://example.com/page] - */ -function parseTextWithCitations(text: string): ReactNode[] { - const parts: ReactNode[] = []; - let lastIndex = 0; - let match: RegExpExecArray | null; - let instanceIndex = 0; - - CITATION_REGEX.lastIndex = 0; - - match = CITATION_REGEX.exec(text); - while (match !== null) { - if (match.index > lastIndex) { - parts.push(text.substring(lastIndex, match.index)); - } - - const captured = match[1]; - - if (captured.startsWith("http://") || captured.startsWith("https://")) { - parts.push(<UrlCitation key={`citation-url-${instanceIndex}`} url={captured.trim()} />); - instanceIndex++; - } else if (captured.startsWith("urlcite")) { - const url = _pendingUrlCitations.get(captured); - if (url) { - parts.push(<UrlCitation key={`citation-url-${instanceIndex}`} url={url} />); - } - instanceIndex++; - } else { - const rawIds = captured.split(",").map((s) => s.trim()); - for (const rawId of rawIds) { - const isDocsChunk = rawId.startsWith("doc-"); - const chunkId = Number.parseInt(isDocsChunk ? rawId.slice(4) : rawId, 10); - parts.push( - <InlineCitation - key={`citation-${isDocsChunk ? "doc-" : ""}${chunkId}-${instanceIndex}`} - chunkId={chunkId} - isDocsChunk={isDocsChunk} - /> - ); - instanceIndex++; - } - } - - lastIndex = match.index + match[0].length; - match = CITATION_REGEX.exec(text); - } - - if (lastIndex < text.length) { - parts.push(text.substring(lastIndex)); - } - - return parts.length > 0 ? parts : [text]; -} - const MarkdownTextImpl = () => { + const urlMapRef = useRef<CitationUrlMap>(EMPTY_URL_MAP); + const preprocess = useCallback( + (content: string) => preprocessMarkdown(content, urlMapRef), + [] + ); return ( - <MarkdownTextPrimitive - smooth={false} - remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]} - rehypePlugins={[rehypeKatex]} - className="aui-md" - components={defaultComponents} - preprocess={preprocessMarkdown} - /> + <CitationUrlMapContext.Provider value={urlMapRef}> + <MarkdownTextPrimitive + smooth={false} + remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]} + rehypePlugins={[rehypeKatex]} + className="aui-md" + components={defaultComponents} + preprocess={preprocess} + /> + </CitationUrlMapContext.Provider> ); }; export const MarkdownText = memo(MarkdownTextImpl); -/** - * Helper to process children and replace citation patterns with components - */ -function processChildrenWithCitations(children: ReactNode): ReactNode { - if (typeof children === "string") { - const parsed = parseTextWithCitations(children); - return parsed.length === 1 && typeof parsed[0] === "string" ? children : parsed; - } - - if (Array.isArray(children)) { - return children.map((child) => { - if (typeof child === "string") { - const parsed = parseTextWithCitations(child); - return parsed.length === 1 && typeof parsed[0] === "string" ? ( - child - ) : ( - <span key={child}>{parsed}</span> - ); - } - return child; - }); - } - - return children; -} - function extractDomain(url: string): string { try { const parsed = new URL(url); @@ -322,92 +247,125 @@ function MarkdownImage({ src, alt }: { src?: string; alt?: string }) { } const defaultComponents = memoizeMarkdownComponents({ - h1: ({ className, children, ...props }) => ( - <h1 - className={cn( - "aui-md-h1 mb-8 scroll-m-20 font-extrabold text-4xl tracking-tight last:mb-0", - className - )} - {...props} - > - {processChildrenWithCitations(children)} - </h1> - ), - h2: ({ className, children, ...props }) => ( - <h2 - className={cn( - "aui-md-h2 mt-8 mb-4 scroll-m-20 font-semibold text-3xl tracking-tight first:mt-0 last:mb-0", - className - )} - {...props} - > - {processChildrenWithCitations(children)} - </h2> - ), - h3: ({ className, children, ...props }) => ( - <h3 - className={cn( - "aui-md-h3 mt-6 mb-4 scroll-m-20 font-semibold text-2xl tracking-tight first:mt-0 last:mb-0", - className - )} - {...props} - > - {processChildrenWithCitations(children)} - </h3> - ), - h4: ({ className, children, ...props }) => ( - <h4 - className={cn( - "aui-md-h4 mt-6 mb-4 scroll-m-20 font-semibold text-xl tracking-tight first:mt-0 last:mb-0", - className - )} - {...props} - > - {processChildrenWithCitations(children)} - </h4> - ), - h5: ({ className, children, ...props }) => ( - <h5 - className={cn("aui-md-h5 my-4 font-semibold text-lg first:mt-0 last:mb-0", className)} - {...props} - > - {processChildrenWithCitations(children)} - </h5> - ), - h6: ({ className, children, ...props }) => ( - <h6 className={cn("aui-md-h6 my-4 font-semibold first:mt-0 last:mb-0", className)} {...props}> - {processChildrenWithCitations(children)} - </h6> - ), - p: ({ className, children, ...props }) => ( - <p className={cn("aui-md-p mt-5 mb-5 leading-7 first:mt-0 last:mb-0", className)} {...props}> - {processChildrenWithCitations(children)} - </p> - ), - a: ({ className, children, ...props }) => ( - <a - className={cn("aui-md-a font-medium text-primary underline underline-offset-4", className)} - {...props} - > - {processChildrenWithCitations(children)} - </a> - ), - blockquote: ({ className, children, ...props }) => ( - <blockquote className={cn("aui-md-blockquote border-l-2 pl-6 italic", className)} {...props}> - {processChildrenWithCitations(children)} - </blockquote> - ), + h1: function H1({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <h1 + className={cn( + "aui-md-h1 mb-8 scroll-m-20 font-extrabold text-4xl tracking-tight last:mb-0", + className + )} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </h1> + ); + }, + h2: function H2({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <h2 + className={cn( + "aui-md-h2 mt-8 mb-4 scroll-m-20 font-semibold text-3xl tracking-tight first:mt-0 last:mb-0", + className + )} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </h2> + ); + }, + h3: function H3({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <h3 + className={cn( + "aui-md-h3 mt-6 mb-4 scroll-m-20 font-semibold text-2xl tracking-tight first:mt-0 last:mb-0", + className + )} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </h3> + ); + }, + h4: function H4({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <h4 + className={cn( + "aui-md-h4 mt-6 mb-4 scroll-m-20 font-semibold text-xl tracking-tight first:mt-0 last:mb-0", + className + )} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </h4> + ); + }, + h5: function H5({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <h5 + className={cn("aui-md-h5 my-4 font-semibold text-lg first:mt-0 last:mb-0", className)} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </h5> + ); + }, + h6: function H6({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <h6 className={cn("aui-md-h6 my-4 font-semibold first:mt-0 last:mb-0", className)} {...props}> + {processChildrenWithCitations(children, urlMap)} + </h6> + ); + }, + p: function P({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <p className={cn("aui-md-p mt-5 mb-5 leading-7 first:mt-0 last:mb-0", className)} {...props}> + {processChildrenWithCitations(children, urlMap)} + </p> + ); + }, + a: function A({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <a + className={cn( + "aui-md-a font-medium text-primary underline underline-offset-4", + className + )} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </a> + ); + }, + blockquote: function Blockquote({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <blockquote className={cn("aui-md-blockquote border-l-2 pl-6 italic", className)} {...props}> + {processChildrenWithCitations(children, urlMap)} + </blockquote> + ); + }, ul: ({ className, ...props }) => ( <ul className={cn("aui-md-ul my-5 ml-6 list-disc [&>li]:mt-2", className)} {...props} /> ), ol: ({ className, ...props }) => ( <ol className={cn("aui-md-ol my-5 ml-6 list-decimal [&>li]:mt-2", className)} {...props} /> ), - li: ({ className, children, ...props }) => ( - <li className={cn("aui-md-li", className)} {...props}> - {processChildrenWithCitations(children)} - </li> - ), + li: function Li({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <li className={cn("aui-md-li", className)} {...props}> + {processChildrenWithCitations(children, urlMap)} + </li> + ); + }, hr: ({ className, ...props }) => ( <hr className={cn("aui-md-hr my-5 border-b", className)} {...props} /> ), @@ -422,28 +380,34 @@ const defaultComponents = memoizeMarkdownComponents({ tbody: ({ className, ...props }) => ( <TableBody className={cn("aui-md-tbody", className)} {...props} /> ), - th: ({ className, children, ...props }) => ( - <TableHead - className={cn( - "aui-md-th bg-muted/50 whitespace-normal [[align=center]]:text-center [[align=right]]:text-right", - className - )} - {...props} - > - {processChildrenWithCitations(children)} - </TableHead> - ), - td: ({ className, children, ...props }) => ( - <TableCell - className={cn( - "aui-md-td whitespace-normal [[align=center]]:text-center [[align=right]]:text-right", - className - )} - {...props} - > - {processChildrenWithCitations(children)} - </TableCell> - ), + th: function Th({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <TableHead + className={cn( + "aui-md-th bg-muted/50 whitespace-normal [[align=center]]:text-center [[align=right]]:text-right", + className + )} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </TableHead> + ); + }, + td: function Td({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <TableCell + className={cn( + "aui-md-td whitespace-normal [[align=center]]:text-center [[align=right]]:text-right", + className + )} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </TableCell> + ); + }, tr: ({ className, ...props }) => <TableRow className={cn("aui-md-tr", className)} {...props} />, sup: ({ className, ...props }) => ( <sup className={cn("aui-md-sup [&>a]:text-xs [&>a]:no-underline", className)} {...props} /> @@ -552,16 +516,22 @@ const defaultComponents = memoizeMarkdownComponents({ /> ); }, - strong: ({ className, children, ...props }) => ( - <strong className={cn("aui-md-strong font-semibold", className)} {...props}> - {processChildrenWithCitations(children)} - </strong> - ), - em: ({ className, children, ...props }) => ( - <em className={cn("aui-md-em", className)} {...props}> - {processChildrenWithCitations(children)} - </em> - ), + strong: function Strong({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <strong className={cn("aui-md-strong font-semibold", className)} {...props}> + {processChildrenWithCitations(children, urlMap)} + </strong> + ); + }, + em: function Em({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <em className={cn("aui-md-em", className)} {...props}> + {processChildrenWithCitations(children, urlMap)} + </em> + ); + }, img: ({ src, alt }) => ( <MarkdownImage src={typeof src === "string" ? src : undefined} alt={alt} /> ), diff --git a/surfsense_web/components/citation-panel/citation-panel.tsx b/surfsense_web/components/citation-panel/citation-panel.tsx index cec07b9cf..ed8acd656 100644 --- a/surfsense_web/components/citation-panel/citation-panel.tsx +++ b/surfsense_web/components/citation-panel/citation-panel.tsx @@ -169,7 +169,7 @@ export const CitationPanelContent: FC<CitationPanelContentProps> = ({ chunkId, o )} </div> <div className="text-sm"> - <MarkdownViewer content={chunk.content} /> + <MarkdownViewer content={chunk.content} enableCitations /> </div> </div> ); diff --git a/surfsense_web/components/citations/citation-renderer.tsx b/surfsense_web/components/citations/citation-renderer.tsx new file mode 100644 index 000000000..bf877f03f --- /dev/null +++ b/surfsense_web/components/citations/citation-renderer.tsx @@ -0,0 +1,79 @@ +"use client"; + +import type { ReactNode } from "react"; +import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation"; +import { + type CitationToken, + type CitationUrlMap, + parseTextWithCitations, +} from "@/lib/citations/citation-parser"; + +/** + * Render a single parsed citation token as JSX. + * + * `ordinalKey` should be a stable per-render counter so duplicate identical + * citations within the same parent don't collide on `key`. The previous + * implementation in `markdown-text.tsx` used the source string itself as + * the key, which produced React warnings when two segments rendered the + * same `[citation:N]` text. + */ +export function renderCitationToken(token: CitationToken, ordinalKey: number): ReactNode { + if (token.kind === "url") { + return <UrlCitation key={`citation-url-${ordinalKey}`} url={token.url} />; + } + return ( + <InlineCitation + key={`citation-${token.isDocsChunk ? "doc-" : ""}${token.chunkId}-${ordinalKey}`} + chunkId={token.chunkId} + isDocsChunk={token.isDocsChunk} + /> + ); +} + +/** + * Walk a `ReactNode` (string, array, or arbitrary node) and replace any + * `[citation:...]` tokens inside string children with citation badges. + * + * Designed for use inside `Streamdown`/`react-markdown` `components` + * overrides where the renderer hands you `children`. Non-string children + * are returned untouched so block/phrasing structure is preserved. + */ +export function processChildrenWithCitations( + children: ReactNode, + urlMap: CitationUrlMap +): ReactNode { + if (typeof children === "string") { + const segments = parseTextWithCitations(children, urlMap); + if (segments.length === 1 && typeof segments[0] === "string") { + return children; + } + let ordinal = 0; + return segments.map((segment) => + typeof segment === "string" ? segment : renderCitationToken(segment, ordinal++) + ); + } + + if (Array.isArray(children)) { + let ordinal = 0; + return children.map((child, childIndex) => { + if (typeof child === "string") { + const segments = parseTextWithCitations(child, urlMap); + if (segments.length === 1 && typeof segments[0] === "string") { + return child; + } + return ( + <span key={`citation-seg-${childIndex}`}> + {segments.map((segment) => + typeof segment === "string" + ? segment + : renderCitationToken(segment, ordinal++) + )} + </span> + ); + } + return child; + }); + } + + return children; +} diff --git a/surfsense_web/components/document-viewer.tsx b/surfsense_web/components/document-viewer.tsx index 0f283e567..710a04ba3 100644 --- a/surfsense_web/components/document-viewer.tsx +++ b/surfsense_web/components/document-viewer.tsx @@ -32,7 +32,7 @@ export function DocumentViewer({ title, content, trigger }: DocumentViewerProps) <DialogTitle>{title}</DialogTitle> </DialogHeader> <div className="mt-4"> - <MarkdownViewer content={content} /> + <MarkdownViewer content={content} enableCitations /> </div> </DialogContent> </Dialog> diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index df138e97e..eab07a91b 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -652,7 +652,7 @@ export function EditorPanelContent({ // Plate is heavy on multi-MB docs. <div className="h-full overflow-y-auto px-5 py-4"> {largeDocAlert} - <MarkdownViewer content={editorDoc.source_markdown} /> + <MarkdownViewer content={editorDoc.source_markdown} enableCitations /> </div> ) : renderInPlateEditor ? ( // Editable doc (FILE/NOTE) — Plate editing UX. @@ -670,12 +670,17 @@ export function EditorPanelContent({ reserveToolbarSpace defaultEditing={isEditing} className="**:[[role=toolbar]]:bg-sidebar!" + // Render `[citation:N]` badges in view mode only. + // Edit mode keeps raw text so the user can edit/delete + // tokens directly. `local_file` never reaches this branch + // (handled by the source_code editor above). + enableCitations={!isEditing && !isLocalFileMode} /> </div> </div> ) : ( <div className="h-full overflow-y-auto px-5 py-4"> - <MarkdownViewer content={editorDoc.source_markdown} /> + <MarkdownViewer content={editorDoc.source_markdown} enableCitations /> </div> )} </div> diff --git a/surfsense_web/components/editor/plate-editor.tsx b/surfsense_web/components/editor/plate-editor.tsx index 7f12d3cae..c42cb991e 100644 --- a/surfsense_web/components/editor/plate-editor.tsx +++ b/surfsense_web/components/editor/plate-editor.tsx @@ -8,9 +8,11 @@ import { useEffect, useMemo, useRef } from "react"; import remarkGfm from "remark-gfm"; import remarkMath from "remark-math"; import { EditorSaveContext } from "@/components/editor/editor-save-context"; +import { CitationKit, injectCitationNodes } from "@/components/editor/plugins/citation-kit"; import { type EditorPreset, presetMap } from "@/components/editor/presets"; import { escapeMdxExpressions } from "@/components/editor/utils/escape-mdx"; import { Editor, EditorContainer } from "@/components/ui/editor"; +import { preprocessCitationMarkdown } from "@/lib/citations/citation-parser"; /** Live editor instance returned by `usePlateEditor`. */ export type PlateEditorInstance = ReturnType<typeof usePlateEditor>; @@ -65,6 +67,14 @@ export interface PlateEditorProps { * without modifying the core editor component. */ extraPlugins?: AnyPluginConfig[]; + /** + * Render `[citation:N]` and `[citation:URL]` tokens in the deserialized + * markdown as interactive citation badges/popovers (mirrors chat). Only + * meant for read-only views — when true, `onMarkdownChange` is suppressed + * because the in-memory tree contains custom inline-void elements that + * have no markdown serialize rule. + */ + enableCitations?: boolean; } function PlateEditorContent({ @@ -103,6 +113,7 @@ export function PlateEditor({ defaultEditing = false, preset = "full", extraPlugins = [], + enableCitations = false, }: PlateEditorProps) { const lastMarkdownRef = useRef(markdown); const lastHtmlRef = useRef(html); @@ -145,6 +156,8 @@ export function PlateEditor({ ...(onSave ? [SaveShortcutPlugin] : []), // Consumer-provided extra plugins ...extraPlugins, + // Citation void inline element (read-only document viewer). + ...(enableCitations ? CitationKit : []), MarkdownPlugin.configure({ options: { remarkPlugins: [remarkGfm, remarkMath, remarkMdx], @@ -154,8 +167,18 @@ export function PlateEditor({ value: html ? (editor) => editor.api.html.deserialize({ element: html }) as Value : markdown - ? (editor) => - editor.getApi(MarkdownPlugin).markdown.deserialize(escapeMdxExpressions(markdown)) + ? (editor) => { + if (!enableCitations) { + return editor + .getApi(MarkdownPlugin) + .markdown.deserialize(escapeMdxExpressions(markdown)); + } + const { content: rewritten, urlMap } = preprocessCitationMarkdown(markdown); + const value = editor + .getApi(MarkdownPlugin) + .markdown.deserialize(escapeMdxExpressions(rewritten)); + return injectCitationNodes(value as Descendant[], urlMap) as Value; + } : undefined, }); @@ -174,13 +197,22 @@ export function PlateEditor({ useEffect(() => { if (!html && markdown !== undefined && markdown !== lastMarkdownRef.current) { lastMarkdownRef.current = markdown; - const newValue = editor - .getApi(MarkdownPlugin) - .markdown.deserialize(escapeMdxExpressions(markdown)); + let newValue: Descendant[]; + if (enableCitations) { + const { content: rewritten, urlMap } = preprocessCitationMarkdown(markdown); + const deserialized = editor + .getApi(MarkdownPlugin) + .markdown.deserialize(escapeMdxExpressions(rewritten)) as Descendant[]; + newValue = injectCitationNodes(deserialized, urlMap); + } else { + newValue = editor + .getApi(MarkdownPlugin) + .markdown.deserialize(escapeMdxExpressions(markdown)) as Descendant[]; + } editor.tf.reset(); - editor.tf.setValue(newValue); + editor.tf.setValue(newValue as Value); } - }, [html, markdown, editor]); + }, [html, markdown, editor, enableCitations]); // When not forced read-only, the user can toggle between editing/viewing. const canToggleMode = !readOnly && allowModeToggle; @@ -205,6 +237,16 @@ export function PlateEditor({ // (initialized to true via usePlateEditor, toggled via ModeToolbarButton). {...(readOnly ? { readOnly: true } : {})} onChange={({ value }) => { + // View-only citation mode: skip serialization. The custom + // `citation` inline-void element has no markdown serialize + // rule, so emitting changes here would overwrite + // `lastMarkdownRef.current` (and downstream copy-to-clipboard + // state in EditorPanelContent) with a tree that loses every + // citation token. `enableCitations` is only ever set in + // read-only paths, so user input cannot reach this branch + // in practice — the guard exists for the initial Plate + // normalize emit. + if (enableCitations) return; if (onHtmlChange && html) { const serialized = slateToHtml(value as Descendant[]); onHtmlChange(serialized); diff --git a/surfsense_web/components/editor/plugins/citation-kit.tsx b/surfsense_web/components/editor/plugins/citation-kit.tsx new file mode 100644 index 000000000..c90cb5e28 --- /dev/null +++ b/surfsense_web/components/editor/plugins/citation-kit.tsx @@ -0,0 +1,222 @@ +"use client"; + +import { type FC } from "react"; +import { KEYS, type Descendant } from "platejs"; +import { createPlatePlugin, type PlateElementProps } from "platejs/react"; +import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation"; +import { + CITATION_REGEX, + type CitationUrlMap, + parseTextWithCitations, +} from "@/lib/citations/citation-parser"; + +/** + * Plate inline-void node modeling a single `[citation:...]` reference. + * + * Modeled after the existing `MentionPlugin` pattern in + * `inline-mention-editor.tsx` — the only confirmed pattern in this repo + * for non-text inline UI. Inline-void elements satisfy Slate's invariant + * that the editor renders both atomic widgets and surrounding text + * cleanly without breaking selection / caret semantics. + */ +export type CitationElementNode = { + type: "citation"; + kind: "chunk" | "doc" | "url"; + chunkId?: number; + url?: string; + /** Original `[citation:...]` substring for traceability/debugging. */ + rawText: string; + children: [{ text: "" }]; +}; + +const CITATION_TYPE = "citation"; + +const CitationElement: FC<PlateElementProps<CitationElementNode>> = ({ + attributes, + children, + element, +}) => { + const isUrl = element.kind === "url"; + return ( + <span {...attributes} className="inline-flex align-baseline"> + <span contentEditable={false}> + {isUrl && element.url ? ( + <UrlCitation url={element.url} /> + ) : element.chunkId !== undefined ? ( + <InlineCitation chunkId={element.chunkId} isDocsChunk={element.kind === "doc"} /> + ) : null} + </span> + {children} + </span> + ); +}; + +const CitationPlugin = createPlatePlugin({ + key: CITATION_TYPE, + node: { + isElement: true, + isInline: true, + isVoid: true, + type: CITATION_TYPE, + component: CitationElement, + }, +}); + +/** Plugin kit shape used elsewhere in the editor. */ +export const CitationKit = [CitationPlugin]; + +// --------------------------------------------------------------------------- +// Slate value transform — runs after MarkdownPlugin.deserialize +// --------------------------------------------------------------------------- + +// Structural shapes used by the value transform. We cannot use Plate's +// generic Element / Text type predicates directly because `Descendant` is a +// constrained union and our predicates would over-narrow. Casting through +// these row types keeps the walker readable without fighting the types. +type SlateText = { text: string } & Record<string, unknown>; +type SlateElement = { type?: string; children: Descendant[] } & Record<string, unknown>; + +function isText(node: Descendant): boolean { + return typeof (node as { text?: unknown }).text === "string"; +} + +function asText(node: Descendant): SlateText { + return node as unknown as SlateText; +} + +function asElement(node: Descendant): SlateElement { + return node as unknown as SlateElement; +} + +/** + * Element types whose subtrees we MUST NOT inject citation void elements + * into. Each rationale documented in the citation plan: + * - `KEYS.codeBlock` / `code_line` — Plate's schema rejects inline elements + * inside code containers; the user expects literal text inside code. + * - `KEYS.link` — `<button>` inside `<a>` is invalid HTML and the link + * swallows the citation click. Mirrors the `<a>` skip in + * `MarkdownViewer`. + */ +const SKIP_SUBTREE_TYPES = new Set<string>([ + KEYS.codeBlock, + "code_line", + KEYS.link, +]); + +/** + * Build the marks portion of a Slate text node so we can preserve formatting + * (bold/italic/etc.) on the surrounding text fragments after we split. + */ +function copyMarks(textNode: SlateText): Record<string, unknown> { + const { text: _text, ...marks } = textNode; + return marks; +} + +function makeCitationElement( + rawText: string, + segment: { kind: "url"; url: string } | { kind: "chunk"; chunkId: number; isDocsChunk: boolean } +): CitationElementNode { + if (segment.kind === "url") { + return { + type: CITATION_TYPE, + kind: "url", + url: segment.url, + rawText, + children: [{ text: "" }], + }; + } + return { + type: CITATION_TYPE, + kind: segment.isDocsChunk ? "doc" : "chunk", + chunkId: segment.chunkId, + rawText, + children: [{ text: "" }], + }; +} + +/** + * Re-extract the raw `[citation:...]` substrings that produced each parsed + * segment, in source order. Lets us preserve the original literal for + * `rawText` on the inline-void element. + */ +function extractRawCitationMatches(text: string): string[] { + const matches: string[] = []; + CITATION_REGEX.lastIndex = 0; + let m: RegExpExecArray | null = CITATION_REGEX.exec(text); + while (m !== null) { + matches.push(m[0]); + m = CITATION_REGEX.exec(text); + } + return matches; +} + +function transformTextNode(node: SlateText, urlMap: CitationUrlMap): Descendant[] { + const segments = parseTextWithCitations(node.text, urlMap); + if (segments.length === 1 && typeof segments[0] === "string") { + return [node as unknown as Descendant]; + } + + const marks = copyMarks(node); + const rawMatches = extractRawCitationMatches(node.text); + const out: Descendant[] = []; + let citationIdx = 0; + let pendingText: string | null = null; + + const flushText = () => { + // Slate inline-void adjacency: emit an empty text node (with copied + // marks) when the citation appears at the very start/end of the text + // node so neighbours of the void always have a text sibling. + out.push({ ...marks, text: pendingText ?? "" } as unknown as Descendant); + pendingText = null; + }; + + for (const segment of segments) { + if (typeof segment === "string") { + pendingText = (pendingText ?? "") + segment; + } else { + flushText(); + const raw = rawMatches[citationIdx] ?? ""; + out.push(makeCitationElement(raw, segment) as unknown as Descendant); + citationIdx += 1; + // Always reset pendingText so the next loop iteration emits a + // trailing empty text node if no further plain text follows. + pendingText = ""; + } + } + flushText(); + + return out; +} + +function transformChildren(children: Descendant[], urlMap: CitationUrlMap): Descendant[] { + const out: Descendant[] = []; + for (const child of children) { + if (isText(child)) { + out.push(...transformTextNode(asText(child), urlMap)); + continue; + } + const elementChild = asElement(child); + const elementType = (elementChild.type ?? "") as string; + if (elementType && SKIP_SUBTREE_TYPES.has(elementType)) { + out.push(child); + continue; + } + out.push({ + ...elementChild, + children: transformChildren(elementChild.children, urlMap), + } as unknown as Descendant); + } + return out; +} + +/** + * Walk a deserialized Slate value and replace every `[citation:...]` + * substring with a `citation` inline-void element. URL placeholders + * created by `preprocessCitationMarkdown` are resolved through `urlMap`. + * + * Subtrees of `code_block`, `code_line`, and `link` are returned as-is — + * see `SKIP_SUBTREE_TYPES` above. + */ +export function injectCitationNodes(value: Descendant[], urlMap: CitationUrlMap): Descendant[] { + return transformChildren(value, urlMap); +} diff --git a/surfsense_web/components/editor/utils/escape-mdx.ts b/surfsense_web/components/editor/utils/escape-mdx.ts index cd5294b11..14839b9fc 100644 --- a/surfsense_web/components/editor/utils/escape-mdx.ts +++ b/surfsense_web/components/editor/utils/escape-mdx.ts @@ -7,7 +7,7 @@ // break the MDX parser. This module sanitises them before deserialization. // --------------------------------------------------------------------------- -const FENCED_OR_INLINE_CODE = /(```[\s\S]*?```|`[^`\n]+`)/g; +import { FENCED_OR_INLINE_CODE } from "@/lib/markdown/code-regions"; // Strip HTML comments that MDX cannot parse. // PDF converters emit <!-- PageHeader="..." -->, <!-- PageBreak -->, etc. diff --git a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx index ac5463873..7ad78be41 100644 --- a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx +++ b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx @@ -316,10 +316,10 @@ export function DocumentTabContent({ documentId, searchSpaceId, title }: Documen </Button> </AlertDescription> </Alert> - <MarkdownViewer content={doc.source_markdown} /> + <MarkdownViewer content={doc.source_markdown} enableCitations /> </> ) : ( - <MarkdownViewer content={doc.source_markdown} /> + <MarkdownViewer content={doc.source_markdown} enableCitations /> )} </div> </div> diff --git a/surfsense_web/components/markdown-viewer.tsx b/surfsense_web/components/markdown-viewer.tsx index c4d73e30b..b2420711a 100644 --- a/surfsense_web/components/markdown-viewer.tsx +++ b/surfsense_web/components/markdown-viewer.tsx @@ -3,6 +3,12 @@ import { createMathPlugin } from "@streamdown/math"; import { Streamdown, type StreamdownProps } from "streamdown"; import "katex/dist/katex.min.css"; import Image from "next/image"; +import { useMemo } from "react"; +import { processChildrenWithCitations } from "@/components/citations/citation-renderer"; +import { + type CitationUrlMap, + preprocessCitationMarkdown, +} from "@/lib/citations/citation-parser"; import { cn } from "@/lib/utils"; const code = createCodePlugin({ @@ -21,8 +27,21 @@ interface MarkdownViewerProps { content: string; className?: string; maxLength?: number; + /** + * When true, render `[citation:N]` / `[citation:URL]` tokens as the + * interactive citation badges/popovers used in chat. Default `false` + * so callers that don't need citations are unchanged. + * + * Note: we deliberately do NOT override `<a>` to inject citations into + * link text — that would produce `<button>` inside `<a>` (invalid + * HTML). A `[citation:N]` token literally placed inside markdown link + * text stays as raw text. + */ + enableCitations?: boolean; } +const EMPTY_URL_MAP: CitationUrlMap = new Map(); + /** * If the entire content is wrapped in a single ```markdown or ```md * code fence, strip the fence so the inner markdown renders properly. @@ -85,14 +104,45 @@ function convertLatexDelimiters(content: string): string { return content; } -export function MarkdownViewer({ content, className, maxLength }: MarkdownViewerProps) { +export function MarkdownViewer({ + content, + className, + maxLength, + enableCitations = false, +}: MarkdownViewerProps) { const isTruncated = maxLength != null && content.length > maxLength; const displayContent = isTruncated ? content.slice(0, maxLength) : content; - const processedContent = convertLatexDelimiters(stripOuterMarkdownFence(displayContent)); + + // Preprocess for URL placeholders BEFORE LaTeX so GFM autolinks don't + // split `[citation:https://…]` apart. The preprocess is code-fence + // aware so citations inside fenced code stay literal. + const { processedContent, urlMap } = useMemo(() => { + const stripped = stripOuterMarkdownFence(displayContent); + if (!enableCitations) { + return { + processedContent: convertLatexDelimiters(stripped), + urlMap: EMPTY_URL_MAP, + }; + } + const { content: rewritten, urlMap: map } = preprocessCitationMarkdown(stripped); + return { + processedContent: convertLatexDelimiters(rewritten), + urlMap: map, + }; + }, [displayContent, enableCitations]); + + // Phrasing/block renderers wrap their string children through the + // citation renderer when `enableCitations` is on. We deliberately do + // NOT override `<a>` (would produce <button> inside <a>) and we do + // NOT touch the inline/fenced `code` paths (citations stay literal + // inside code, matching markdown-text.tsx behavior). + const wrap = (children: React.ReactNode): React.ReactNode => + enableCitations ? processChildrenWithCitations(children, urlMap) : children; + const components: StreamdownProps["components"] = { p: ({ children, ...props }) => ( <p className="my-2" {...props}> - {children} + {wrap(children)} </p> ), a: ({ children, ...props }) => ( @@ -105,31 +155,49 @@ export function MarkdownViewer({ content, className, maxLength }: MarkdownViewer {children} </a> ), - li: ({ children, ...props }) => <li {...props}>{children}</li>, + li: ({ children, ...props }) => <li {...props}>{wrap(children)}</li>, ul: ({ ...props }) => <ul className="list-disc pl-5 my-2" {...props} />, ol: ({ ...props }) => <ol className="list-decimal pl-5 my-2" {...props} />, h1: ({ children, ...props }) => ( <h1 className="text-2xl font-bold mt-6 mb-2" {...props}> - {children} + {wrap(children)} </h1> ), h2: ({ children, ...props }) => ( <h2 className="text-xl font-bold mt-5 mb-2" {...props}> - {children} + {wrap(children)} </h2> ), h3: ({ children, ...props }) => ( <h3 className="text-lg font-bold mt-4 mb-2" {...props}> - {children} + {wrap(children)} </h3> ), h4: ({ children, ...props }) => ( <h4 className="text-base font-bold mt-3 mb-1" {...props}> - {children} + {wrap(children)} </h4> ), - blockquote: ({ ...props }) => ( - <blockquote className="border-l-4 border-muted pl-4 italic my-2" {...props} /> + h5: ({ children, ...props }) => ( + <h5 className="text-sm font-bold mt-3 mb-1" {...props}> + {wrap(children)} + </h5> + ), + h6: ({ children, ...props }) => ( + <h6 className="text-xs font-bold mt-3 mb-1" {...props}> + {wrap(children)} + </h6> + ), + strong: ({ children, ...props }) => ( + <strong className="font-semibold" {...props}> + {wrap(children)} + </strong> + ), + em: ({ children, ...props }) => <em {...props}>{wrap(children)}</em>, + blockquote: ({ children, ...props }) => ( + <blockquote className="border-l-4 border-muted pl-4 italic my-2" {...props}> + {wrap(children)} + </blockquote> ), hr: ({ ...props }) => <hr className="my-4 border-muted" {...props} />, img: ({ src, alt, width: _w, height: _h, ...props }) => { @@ -163,17 +231,21 @@ export function MarkdownViewer({ content, className, maxLength }: MarkdownViewer <table className="w-full divide-y divide-border" {...props} /> </div> ), - th: ({ ...props }) => ( + th: ({ children, ...props }) => ( <th className="px-4 py-2.5 text-left text-sm font-semibold text-muted-foreground/80 bg-muted/30 border-r border-border/40 last:border-r-0" {...props} - /> + > + {wrap(children)} + </th> ), - td: ({ ...props }) => ( + td: ({ children, ...props }) => ( <td className="px-4 py-2.5 text-sm border-t border-r border-border/40 last:border-r-0" {...props} - /> + > + {wrap(children)} + </td> ), }; diff --git a/surfsense_web/components/report-panel/report-panel.tsx b/surfsense_web/components/report-panel/report-panel.tsx index 621cf13ce..7fafc9c3b 100644 --- a/surfsense_web/components/report-panel/report-panel.tsx +++ b/surfsense_web/components/report-panel/report-panel.tsx @@ -516,7 +516,7 @@ export function ReportPanelContent({ ) : reportContent.content ? ( isReadOnly ? ( <div className="h-full overflow-y-auto px-5 py-4"> - <MarkdownViewer content={reportContent.content} /> + <MarkdownViewer content={reportContent.content} enableCitations /> </div> ) : ( <PlateEditor @@ -531,6 +531,9 @@ export function ReportPanelContent({ reserveToolbarSpace defaultEditing={isEditing} className="[&_[role=toolbar]]:!bg-sidebar" + // Show citation badges in view mode; raw `[citation:N]` + // text in edit mode so users can edit/delete tokens. + enableCitations={!isEditing} /> ) ) : ( diff --git a/surfsense_web/lib/citations/citation-parser.ts b/surfsense_web/lib/citations/citation-parser.ts new file mode 100644 index 000000000..6333b0f97 --- /dev/null +++ b/surfsense_web/lib/citations/citation-parser.ts @@ -0,0 +1,134 @@ +// Pure citation parsing for `[citation:...]` tokens emitted by SurfSense +// agents. No React imports — consumed by both the React renderer +// (markdown surfaces) and the Plate value transform (document viewer). +// +// The same logic previously lived inline in +// `components/assistant-ui/markdown-text.tsx` with module-level mutable +// state. This module exposes a per-call URL map so multiple concurrent +// renderers / SSR contexts can't race each other. + +import { FENCED_OR_INLINE_CODE } from "@/lib/markdown/code-regions"; + +/** + * Matches `[citation:...]` with numeric IDs (incl. negative, doc- prefix, + * comma-separated), URL-based IDs from live web search, or `urlciteN` + * placeholders produced by `preprocessCitationMarkdown`. + * + * Also matches Chinese brackets 【】 and zero-width spaces that LLMs + * sometimes emit. + */ +export const CITATION_REGEX = + /[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+|urlcite\d+|(?:doc-)?-?\d+(?:\s*,\s*(?:doc-)?-?\d+)*)\s*\u200B?[\]】]/g; + +/** A single parsed citation reference. */ +export type CitationToken = + | { kind: "url"; url: string } + | { kind: "chunk"; chunkId: number; isDocsChunk: boolean }; + +/** Output of `parseTextWithCitations` — interleaved text + citation tokens. */ +export type ParsedSegment = string | CitationToken; + +/** Per-call URL placeholder map; key is `urlciteN`, value is the original URL. */ +export type CitationUrlMap = Map<string, string>; + +/** Result of preprocessing raw markdown for downstream parsing. */ +export interface PreprocessedCitations { + /** Markdown with `[citation:URL]` tokens rewritten to `[citation:urlciteN]`. */ + content: string; + /** Lookup table to recover the original URL from each placeholder. */ + urlMap: CitationUrlMap; +} + +/** Pattern matching only URL-form citations (used during preprocessing). */ +const URL_CITATION_REGEX = + /[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+)\s*\u200B?[\]】]/g; + +/** + * Replace `[citation:URL]` tokens with `[citation:urlciteN]` placeholders so + * GFM autolinks don't split the URL out of the brackets during markdown + * parsing. Returns both the rewritten content and a map for later lookup. + * + * Code-fence aware: skips fenced (``` ``` ```) and inline (`` ` ``) code + * regions so citation-shaped strings inside example code remain literal. + * + * Known limitations: `~~~` fences, 4-space indented code, and LaTeX math + * blocks are not skipped. Citation tokens inside those regions are rare in + * practice; documented in the plan. + */ +export function preprocessCitationMarkdown(content: string): PreprocessedCitations { + const urlMap: CitationUrlMap = new Map(); + let counter = 0; + + // Splitting on a regex with one capture group puts code regions at odd + // indexes (matched delimiters) and the surrounding text at even indexes. + // Only transform the even-indexed parts. + const parts = content.split(FENCED_OR_INLINE_CODE); + const transformed = parts.map((part, index) => { + if (index % 2 === 1) return part; + return part.replace(URL_CITATION_REGEX, (_match, url: string) => { + const key = `urlcite${counter++}`; + urlMap.set(key, url.trim()); + return `[citation:${key}]`; + }); + }); + + return { content: transformed.join(""), urlMap }; +} + +/** + * Parse a string into an array of plain text segments and citation tokens. + * + * Pure data — no React. The renderer module is responsible for mapping + * tokens to JSX. Negative chunk IDs are forwarded as-is so the consumer + * can decide how to render anonymous documents. + */ +export function parseTextWithCitations( + text: string, + urlMap: CitationUrlMap +): ParsedSegment[] { + const segments: ParsedSegment[] = []; + let lastIndex = 0; + let match: RegExpExecArray | null; + + CITATION_REGEX.lastIndex = 0; + match = CITATION_REGEX.exec(text); + while (match !== null) { + if (match.index > lastIndex) { + segments.push(text.substring(lastIndex, match.index)); + } + + const captured = match[1]; + + if (captured.startsWith("http://") || captured.startsWith("https://")) { + segments.push({ kind: "url", url: captured.trim() }); + } else if (captured.startsWith("urlcite")) { + const url = urlMap.get(captured); + if (url) { + segments.push({ kind: "url", url }); + } + } else { + const rawIds = captured.split(",").map((s) => s.trim()); + for (const rawId of rawIds) { + const isDocsChunk = rawId.startsWith("doc-"); + const chunkId = Number.parseInt(isDocsChunk ? rawId.slice(4) : rawId, 10); + if (!Number.isNaN(chunkId)) { + segments.push({ kind: "chunk", chunkId, isDocsChunk }); + } + } + } + + lastIndex = match.index + match[0].length; + match = CITATION_REGEX.exec(text); + } + + if (lastIndex < text.length) { + segments.push(text.substring(lastIndex)); + } + + return segments.length > 0 ? segments : [text]; +} + +/** Type guard for the citation branch of `ParsedSegment`. */ +export function isCitationToken(segment: ParsedSegment): segment is CitationToken { + return typeof segment !== "string"; +} diff --git a/surfsense_web/lib/markdown/code-regions.ts b/surfsense_web/lib/markdown/code-regions.ts new file mode 100644 index 000000000..336a87acb --- /dev/null +++ b/surfsense_web/lib/markdown/code-regions.ts @@ -0,0 +1,8 @@ +// Matches fenced (```...```) and inline (`...`) code regions. Used by MDX +// escaping and citation preprocessing — single source of truth so future +// edits stay in sync. +// +// String.split() with this capturing pattern places non-code parts at even +// indexes and matched code regions at odd indexes — preserve odd-indexed +// segments verbatim when transforming markdown. +export const FENCED_OR_INLINE_CODE = /(```[\s\S]*?```|`[^`\n]+`)/g; From c644f02d0575473118076aa55b762aaa638a56b3 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Thu, 30 Apr 2026 18:42:38 -0700 Subject: [PATCH 263/299] chore: linting --- ...38_add_thread_auto_model_pinning_fields.py | 16 +- .../app/services/auto_model_pin_service.py | 15 +- .../app/tasks/chat/stream_new_chat.py | 16 +- .../services/test_auto_model_pin_service.py | 76 +++++- .../unit/test_stream_new_chat_contract.py | 26 +- .../new-chat/[[...chat_id]]/page.tsx | 120 ++++----- .../app/desktop/permissions/page.tsx | 4 +- .../agent-action-log/action-log-sheet.tsx | 5 +- .../assistant-ui/inline-citation.tsx | 6 +- .../assistant-ui/inline-mention-editor.tsx | 43 +++- .../components/assistant-ui/markdown-text.tsx | 24 +- .../components/assistant-ui/nested-scroll.tsx | 2 +- .../components/assistant-ui/thread.tsx | 2 +- .../components/assistant-ui/tool-fallback.tsx | 22 +- .../citations/citation-renderer.tsx | 4 +- .../editor/plugins/citation-kit.tsx | 10 +- .../layout/providers/LayoutDataProvider.tsx | 2 +- .../layout/ui/sidebar/DocumentsSidebar.tsx | 6 +- surfsense_web/components/markdown-viewer.tsx | 5 +- .../hooks/use-agent-actions-query.ts | 243 ++++++++---------- .../lib/chat/chat-error-classifier.ts | 45 ++-- surfsense_web/lib/chat/chat-request-errors.ts | 8 +- surfsense_web/lib/chat/stream-pipeline.ts | 8 +- surfsense_web/lib/chat/stream-side-effects.ts | 8 +- .../lib/citations/citation-parser.ts | 8 +- surfsense_web/lib/posthog/events.ts | 2 +- 26 files changed, 346 insertions(+), 380 deletions(-) diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py index 1ea549975..3972b84b9 100644 --- a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -47,19 +47,11 @@ def upgrade() -> None: def downgrade() -> None: - op.execute( - "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode" - ) - op.execute( - "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id" - ) + op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode") + op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id") - op.execute( - "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at" - ) - op.execute( - "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode" - ) + op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at") + op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode") op.execute( "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_llm_config_id" ) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 6bdb60f57..6b69c91ea 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -44,7 +44,9 @@ def _is_usable_global_config(cfg: dict) -> bool: def _global_candidates() -> list[dict]: - candidates = [cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg)] + candidates = [ + cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg) + ] return sorted(candidates, key=lambda c: int(c.get("id", 0))) @@ -69,7 +71,9 @@ def _to_uuid(user_id: str | UUID | None) -> UUID | None: return None -async def _is_premium_eligible(session: AsyncSession, user_id: str | UUID | None) -> bool: +async def _is_premium_eligible( + session: AsyncSession, user_id: str | UUID | None +) -> bool: parsed = _to_uuid(user_id) if parsed is None: return False @@ -136,8 +140,7 @@ async def resolve_or_get_pinned_llm_config_id( pinned_id = thread.pinned_llm_config_id if ( not force_repin_free - and - thread.pinned_auto_mode == AUTO_FASTEST_MODE + and thread.pinned_auto_mode == AUTO_FASTEST_MODE and pinned_id is not None and int(pinned_id) in candidate_by_id ): @@ -163,7 +166,9 @@ async def resolve_or_get_pinned_llm_config_id( thread.pinned_auto_mode, ) - premium_eligible = False if force_repin_free else await _is_premium_eligible(session, user_id) + premium_eligible = ( + False if force_repin_free else await _is_premium_eligible(session, user_id) + ) if premium_eligible: eligible = candidates else: diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 63c149771..5abcb63eb 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -2225,9 +2225,7 @@ async def stream_new_chat( # Premium quota reservation for pinned premium model only. _needs_premium_quota = ( - agent_config is not None - and user_id - and agent_config.is_premium + agent_config is not None and user_id and agent_config.is_premium ) if _needs_premium_quota: import uuid as _uuid @@ -2271,7 +2269,9 @@ async def stream_new_chat( yield streaming_service.format_done() return - llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) if llm_load_error: yield _emit_stream_error( message=llm_load_error, @@ -3086,9 +3086,7 @@ async def stream_resume_chat( _resume_premium_reserved = 0 _resume_premium_request_id: str | None = None _resume_needs_premium = ( - agent_config is not None - and user_id - and agent_config.is_premium + agent_config is not None and user_id and agent_config.is_premium ) if _resume_needs_premium: import uuid as _uuid @@ -3132,7 +3130,9 @@ async def stream_resume_chat( yield streaming_service.format_done() return - llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) if llm_load_error: yield _emit_stream_error( message=llm_load_error, diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index f08e50ba2..0a2342e05 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -66,7 +66,13 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): "GLOBAL_LLM_CONFIGS", [ {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, - {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, ], ) @@ -103,12 +109,20 @@ async def test_next_turn_reuses_existing_pin(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, ], ) async def _must_not_call(*_args, **_kwargs): - raise AssertionError("premium_get_usage should not be called for valid pin reuse") + raise AssertionError( + "premium_get_usage should not be called for valid pin reuse" + ) monkeypatch.setattr( "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", @@ -136,7 +150,13 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, ], ) @@ -168,8 +188,20 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, - {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-free", + "api_key": "k1", + "billing_tier": "free", + }, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, ], ) @@ -203,8 +235,20 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, - {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-free", + "api_key": "k1", + "billing_tier": "free", + }, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, ], ) @@ -238,8 +282,20 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, - {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-free", + "api_key": "k1", + "billing_tier": "free", + }, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, ], ) diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index a1345c15c..5e6ad6abd 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -203,7 +203,10 @@ def test_stream_exception_classifies_turn_cancelling_when_cancel_requested(): def test_premium_classification_is_error_code_driven(): - classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" + classifier_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/chat-error-classifier.ts" + ) source = classifier_path.read_text(encoding="utf-8") assert "PREMIUM_KEYWORDS" not in source @@ -229,7 +232,8 @@ def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook(): def test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): user_message_path = ( - Path(__file__).resolve().parents[3] / "surfsense_web/components/assistant-ui/user-message.tsx" + Path(__file__).resolve().parents[3] + / "surfsense_web/components/assistant-ui/user-message.tsx" ) source = user_message_path.read_text(encoding="utf-8") @@ -238,10 +242,14 @@ def test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): def test_network_send_failures_use_unified_retry_toast_message(): - classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" + classifier_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/chat-error-classifier.ts" + ) classifier_source = classifier_path.read_text(encoding="utf-8") request_errors_path = ( - Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-request-errors.ts" + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/chat-request-errors.ts" ) request_errors_source = request_errors_path.read_text(encoding="utf-8") @@ -350,15 +358,17 @@ def test_turn_status_sse_contract_exists(): / "surfsense_backend/app/tasks/chat/stream_new_chat.py" ).read_text(encoding="utf-8") state_source = ( - Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/streaming-state.ts" + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/streaming-state.ts" ).read_text(encoding="utf-8") pipeline_source = ( - Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/stream-pipeline.ts" + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/stream-pipeline.ts" ).read_text(encoding="utf-8") assert '"turn-status"' in stream_source assert '"status": "busy"' in stream_source assert '"status": "idle"' in stream_source - assert "type: \"data-turn-status\"" in state_source - assert "case \"data-turn-status\":" in pipeline_source + assert 'type: "data-turn-status"' in state_source + assert 'case "data-turn-status":' in pipeline_source assert "end_turn(str(chat_id))" in stream_source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 1b25ca431..39201e5cc 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -19,7 +19,6 @@ import { currentThreadAtom, setTargetCommentIdAtom, } from "@/atoms/chat/current-thread.atom"; -import { setPremiumAlertForThreadAtom } from "@/atoms/chat/premium-alert.atom"; import { type MentionedDocumentInfo, mentionedDocumentIdsAtom, @@ -31,6 +30,7 @@ import { clearPlanOwnerRegistry, // extractWriteTodosFromContent, } from "@/atoms/chat/plan-state.atom"; +import { setPremiumAlertForThreadAtom } from "@/atoms/chat/premium-alert.atom"; import { closeReportPanelAtom } from "@/atoms/chat/report-panel.atom"; import { type AgentCreatedDocument, agentCreatedDocumentsAtom } from "@/atoms/documents/ui.atoms"; import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; @@ -60,20 +60,28 @@ import { useMessagesSync } from "@/hooks/use-messages-sync"; import { getAgentFilesystemSelection } from "@/lib/agent-filesystem"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { getBearerToken } from "@/lib/auth-utils"; -import { - classifyChatError, - type ChatFlow, -} from "@/lib/chat/chat-error-classifier"; -import { - tagPreAcceptSendFailure, - toHttpResponseError, -} from "@/lib/chat/chat-request-errors"; +import { type ChatFlow, classifyChatError } from "@/lib/chat/chat-error-classifier"; +import { tagPreAcceptSendFailure, toHttpResponseError } from "@/lib/chat/chat-request-errors"; import { convertToThreadMessage } from "@/lib/chat/message-utils"; import { isPodcastGenerating, looksLikePodcastRequest, setActivePodcastTaskId, } from "@/lib/chat/podcast-state"; +import { createStreamFlushHelpers } from "@/lib/chat/stream-flush"; +import { + consumeSseEvents, + hasPersistableContent, + processSharedStreamEvent, +} from "@/lib/chat/stream-pipeline"; +import { + applyInterruptRequestToContentParts, + applyTurnIdToAssistantMessageList, + markInterruptDecisionOnContentParts, + mergeChatTurnIdIntoMessage, + mergeEditedInterruptAction, + readStreamedChatTurnId, +} from "@/lib/chat/stream-side-effects"; import { buildContentForPersistence, buildContentForUI, @@ -82,20 +90,6 @@ import { type ThinkingStepData, type ToolUIGate, } from "@/lib/chat/streaming-state"; -import { createStreamFlushHelpers } from "@/lib/chat/stream-flush"; -import { - consumeSseEvents, - hasPersistableContent, - processSharedStreamEvent, -} from "@/lib/chat/stream-pipeline"; -import { - applyTurnIdToAssistantMessageList, - applyInterruptRequestToContentParts, - mergeChatTurnIdIntoMessage, - mergeEditedInterruptAction, - markInterruptDecisionOnContentParts, - readStreamedChatTurnId, -} from "@/lib/chat/stream-side-effects"; import { appendMessage, createThread, @@ -112,8 +106,8 @@ import { } from "@/lib/chat/user-turn-api-parts"; import { NotFoundError } from "@/lib/error"; import { - trackChatCreated, trackChatBlocked, + trackChatCreated, trackChatErrorDetailed, trackChatMessageSent, trackChatResponseReceived, @@ -193,7 +187,8 @@ function sleep(ms: number): Promise<void> { function computeFallbackTurnCancellingRetryDelay(attempt: number): number { const safeAttempt = Math.max(1, attempt); - const raw = TURN_CANCELLING_INITIAL_DELAY_MS * TURN_CANCELLING_BACKOFF_FACTOR ** (safeAttempt - 1); + const raw = + TURN_CANCELLING_INITIAL_DELAY_MS * TURN_CANCELLING_BACKOFF_FACTOR ** (safeAttempt - 1); return Math.min(raw, TURN_CANCELLING_MAX_DELAY_MS); } @@ -278,11 +273,9 @@ export default function NewChatPage() { }) => { if (!threadId) return null; try { - const normalizedContent = Array.isArray(content) - ? ([...content] as unknown[]) - : [content]; - const hasMentionedDocumentsPart = normalizedContent.some((part) => - MentionedDocumentsPartSchema.safeParse(part).success + const normalizedContent = Array.isArray(content) ? ([...content] as unknown[]) : [content]; + const hasMentionedDocumentsPart = normalizedContent.some( + (part) => MentionedDocumentsPartSchema.safeParse(part).success ); if (mentionedDocs && mentionedDocs.length > 0 && !hasMentionedDocumentsPart) { normalizedContent.push({ @@ -300,10 +293,7 @@ export default function NewChatPage() { setMessages((prev) => prev.map((m) => m.id === userMsgId - ? mergeChatTurnIdIntoMessage( - { ...m, id: newUserMsgId }, - savedUserMessage.turn_id - ) + ? mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, savedUserMessage.turn_id) : m ) ); @@ -356,10 +346,7 @@ export default function NewChatPage() { setMessages((prev) => prev.map((m) => m.id === assistantMsgId - ? mergeChatTurnIdIntoMessage( - { ...m, id: newMsgId }, - savedMessage.turn_id - ) + ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) : m ) ); @@ -564,12 +551,7 @@ export default function NewChatPage() { toast.error(normalized.userMessage); }, - [ - currentUser?.id, - persistAssistantErrorMessage, - searchSpaceId, - setPremiumAlertForThread, - ] + [currentUser?.id, persistAssistantErrorMessage, searchSpaceId, setPremiumAlertForThread] ); const handleStreamTerminalError = useCallback( @@ -613,35 +595,31 @@ export default function NewChatPage() { [handleChatFailure] ); - const fetchWithTurnCancellingRetry = useCallback( - async (runFetch: () => Promise<Response>) => { - const maxAttempts = 4; - for (let attempt = 1; attempt <= maxAttempts; attempt += 1) { - const response = await runFetch(); - if (response.ok) { - return response; - } - const error = await toHttpResponseError(response); - const withMeta = error as Error & { errorCode?: string; retryAfterMs?: number }; - const isTurnCancelling = withMeta.errorCode === "TURN_CANCELLING"; - const isRecentThreadBusyAfterCancel = - withMeta.errorCode === "THREAD_BUSY" && - Date.now() - recentCancelRequestedAtRef.current <= RECENT_CANCEL_WINDOW_MS; - if ((isTurnCancelling || isRecentThreadBusyAfterCancel) && attempt < maxAttempts) { - const waitMs = - withMeta.retryAfterMs ?? computeFallbackTurnCancellingRetryDelay(attempt); - await sleep(waitMs); - continue; - } - throw error; + const fetchWithTurnCancellingRetry = useCallback(async (runFetch: () => Promise<Response>) => { + const maxAttempts = 4; + for (let attempt = 1; attempt <= maxAttempts; attempt += 1) { + const response = await runFetch(); + if (response.ok) { + return response; } + const error = await toHttpResponseError(response); + const withMeta = error as Error & { errorCode?: string; retryAfterMs?: number }; + const isTurnCancelling = withMeta.errorCode === "TURN_CANCELLING"; + const isRecentThreadBusyAfterCancel = + withMeta.errorCode === "THREAD_BUSY" && + Date.now() - recentCancelRequestedAtRef.current <= RECENT_CANCEL_WINDOW_MS; + if ((isTurnCancelling || isRecentThreadBusyAfterCancel) && attempt < maxAttempts) { + const waitMs = withMeta.retryAfterMs ?? computeFallbackTurnCancellingRetryDelay(attempt); + await sleep(waitMs); + continue; + } + throw error; + } - throw Object.assign(new Error("Turn cancellation retry limit exceeded"), { - errorCode: "TURN_CANCELLING", - }); - }, - [] - ); + throw Object.assign(new Error("Turn cancellation retry limit exceeded"), { + errorCode: "TURN_CANCELLING", + }); + }, []); // Initialize thread and load messages // For new chats (no urlChatId), we use lazy creation - thread is created on first message diff --git a/surfsense_web/app/desktop/permissions/page.tsx b/surfsense_web/app/desktop/permissions/page.tsx index e30a76f83..ca9228272 100644 --- a/surfsense_web/app/desktop/permissions/page.tsx +++ b/surfsense_web/app/desktop/permissions/page.tsx @@ -132,8 +132,8 @@ export default function DesktopPermissionsPage() { <div className="space-y-1"> <h1 className="text-2xl font-semibold tracking-tight">System Permissions</h1> <p className="text-sm text-muted-foreground"> - SurfSense needs two macOS permissions for Screenshot Assist and for desktop features that - require focusing the app or the active application. + SurfSense needs two macOS permissions for Screenshot Assist and for desktop features + that require focusing the app or the active application. </p> </div> </div> diff --git a/surfsense_web/components/agent-action-log/action-log-sheet.tsx b/surfsense_web/components/agent-action-log/action-log-sheet.tsx index 32c25771a..7d27b4019 100644 --- a/surfsense_web/components/agent-action-log/action-log-sheet.tsx +++ b/surfsense_web/components/agent-action-log/action-log-sheet.tsx @@ -17,10 +17,7 @@ import { SheetTitle, } from "@/components/ui/sheet"; import { Skeleton } from "@/components/ui/skeleton"; -import { - agentActionsQueryKey, - useAgentActionsQuery, -} from "@/hooks/use-agent-actions-query"; +import { agentActionsQueryKey, useAgentActionsQuery } from "@/hooks/use-agent-actions-query"; import { ActionLogItem } from "./action-log-item"; function EmptyState() { diff --git a/surfsense_web/components/assistant-ui/inline-citation.tsx b/surfsense_web/components/assistant-ui/inline-citation.tsx index e299f2373..32a29cfc9 100644 --- a/surfsense_web/components/assistant-ui/inline-citation.tsx +++ b/surfsense_web/components/assistant-ui/inline-citation.tsx @@ -182,11 +182,7 @@ const SurfsenseDocCitation: FC<{ chunkId: number }> = ({ chunkId }) => { </p> )} {!isLoading && !error && citedChunk?.content && ( - <MarkdownViewer - content={citedChunk.content} - maxLength={1500} - enableCitations - /> + <MarkdownViewer content={citedChunk.content} maxLength={1500} enableCitations /> )} {!isLoading && !error && !citedChunk?.content && ( <p className="py-4 text-xs text-muted-foreground">No content available.</p> diff --git a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx index d92348080..c585dc80f 100644 --- a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx +++ b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx @@ -1,8 +1,14 @@ "use client"; -import { type FC, forwardRef, useCallback, useImperativeHandle, useMemo, useRef } from "react"; -import { Plate, PlateContent, ParagraphPlugin, createPlatePlugin, usePlateEditor } from "platejs/react"; import type { PlateElementProps } from "platejs/react"; +import { + createPlatePlugin, + ParagraphPlugin, + Plate, + PlateContent, + usePlateEditor, +} from "platejs/react"; +import { type FC, forwardRef, useCallback, useImperativeHandle, useMemo, useRef } from "react"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { Document } from "@/contracts/types/document.types"; import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; @@ -72,7 +78,11 @@ const COMPOSER_TEXT_METRICS_CLASSNAME = "text-sm leading-6"; const EMPTY_VALUE: ComposerValue = [{ type: "p", children: [{ text: "" }] }]; -const MentionElement: FC<PlateElementProps<MentionElementNode>> = ({ attributes, children, element }) => { +const MentionElement: FC<PlateElementProps<MentionElementNode>> = ({ + attributes, + children, + element, +}) => { const statusClass = element.statusKind === "failed" ? "text-destructive" @@ -255,7 +265,10 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent selection?.addRange(range); }, []); - const getCurrentValue = useCallback(() => (editor.children as ComposerValue) ?? EMPTY_VALUE, [editor]); + const getCurrentValue = useCallback( + () => (editor.children as ComposerValue) ?? EMPTY_VALUE, + [editor] + ); const emitState = useCallback( (nextValue: ComposerValue) => { @@ -379,7 +392,8 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent const next = current.map((block) => { const children = block.children.filter((node) => { if (!isMentionNode(node)) return true; - const match = node.id === docId && (node.document_type ?? "UNKNOWN") === (docType ?? "UNKNOWN"); + const match = + node.id === docId && (node.document_type ?? "UNKNOWN") === (docType ?? "UNKNOWN"); if (match) changed = true; return !match; }); @@ -450,7 +464,15 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent removeDocumentChip, setDocumentChipStatus, }), - [clear, getMentionedDocs, getText, insertDocumentChip, removeDocumentChip, setDocumentChipStatus, setText] + [ + clear, + getMentionedDocs, + getText, + insertDocumentChip, + removeDocumentChip, + setDocumentChipStatus, + setText, + ] ); const handleKeyDown = useCallback( @@ -488,14 +510,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent removeDocumentChip(prev.id, prev.document_type); onDocumentRemove?.(prev.id, prev.document_type); }, - [ - editor.selection, - getCurrentValue, - onDocumentRemove, - onKeyDown, - onSubmit, - removeDocumentChip, - ] + [editor.selection, getCurrentValue, onDocumentRemove, onKeyDown, onSubmit, removeDocumentChip] ); const editableProps = useMemo( diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index 2b788e88b..4842e5979 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -12,14 +12,7 @@ import { ExternalLinkIcon } from "lucide-react"; import dynamic from "next/dynamic"; import { useParams } from "next/navigation"; import { useTheme } from "next-themes"; -import { - createContext, - memo, - type ReactNode, - useCallback, - useContext, - useRef, -} from "react"; +import { createContext, memo, type ReactNode, useCallback, useContext, useRef } from "react"; import rehypeKatex from "rehype-katex"; import remarkGfm from "remark-gfm"; import remarkMath from "remark-math"; @@ -28,10 +21,6 @@ import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/im import "katex/dist/katex.min.css"; import { processChildrenWithCitations } from "@/components/citations/citation-renderer"; import { Skeleton } from "@/components/ui/skeleton"; -import { - type CitationUrlMap, - preprocessCitationMarkdown, -} from "@/lib/citations/citation-parser"; import { Table, TableBody, @@ -41,6 +30,7 @@ import { TableRow, } from "@/components/ui/table"; import { useElectronAPI } from "@/hooks/use-platform"; +import { type CitationUrlMap, preprocessCitationMarkdown } from "@/lib/citations/citation-parser"; import { cn } from "@/lib/utils"; function MarkdownCodeBlockSkeleton() { @@ -128,10 +118,7 @@ function preprocessMarkdown(content: string, urlMapRef: CitationUrlMapRef): stri const MarkdownTextImpl = () => { const urlMapRef = useRef<CitationUrlMap>(EMPTY_URL_MAP); - const preprocess = useCallback( - (content: string) => preprocessMarkdown(content, urlMapRef), - [] - ); + const preprocess = useCallback((content: string) => preprocessMarkdown(content, urlMapRef), []); return ( <CitationUrlMapContext.Provider value={urlMapRef}> <MarkdownTextPrimitive @@ -334,10 +321,7 @@ const defaultComponents = memoizeMarkdownComponents({ const urlMap = useCitationUrlMap(); return ( <a - className={cn( - "aui-md-a font-medium text-primary underline underline-offset-4", - className - )} + className={cn("aui-md-a font-medium text-primary underline underline-offset-4", className)} {...props} > {processChildrenWithCitations(children, urlMap)} diff --git a/surfsense_web/components/assistant-ui/nested-scroll.tsx b/surfsense_web/components/assistant-ui/nested-scroll.tsx index 5a4f8d36e..37c4790df 100644 --- a/surfsense_web/components/assistant-ui/nested-scroll.tsx +++ b/surfsense_web/components/assistant-ui/nested-scroll.tsx @@ -1,6 +1,6 @@ "use client"; -import { forwardRef, type ComponentPropsWithoutRef, type WheelEvent } from "react"; +import { type ComponentPropsWithoutRef, forwardRef, type WheelEvent } from "react"; export type NestedScrollProps = ComponentPropsWithoutRef<"div">; diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 6c02a1efa..b4a3b58c6 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -92,8 +92,8 @@ import { useBatchCommentsPreload } from "@/hooks/use-comments"; import { useCommentsSync } from "@/hooks/use-comments-sync"; import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI } from "@/hooks/use-platform"; -import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { captureDisplayToPngDataUrl } from "@/lib/chat/display-media-capture"; +import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/lib/layout-events"; import { cn } from "@/lib/utils"; diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index cf42cf398..06082c9c7 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -1,19 +1,16 @@ -import { - type ToolCallMessagePartComponent, - useAuiState, -} from "@assistant-ui/react"; +import { type ToolCallMessagePartComponent, useAuiState } from "@assistant-ui/react"; import { useQueryClient } from "@tanstack/react-query"; import { useAtomValue } from "jotai"; import { CheckIcon, ChevronDownIcon, RotateCcw, XCircleIcon } from "lucide-react"; import { useEffect, useMemo, useState } from "react"; import { toast } from "sonner"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; +import { NestedScroll } from "@/components/assistant-ui/nested-scroll"; import { DoomLoopApprovalToolUI, isDoomLoopInterrupt, } from "@/components/tool-ui/doom-loop-approval"; import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval"; -import { NestedScroll } from "@/components/assistant-ui/nested-scroll"; import { AlertDialog, AlertDialogAction, @@ -32,10 +29,7 @@ import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/component import { Separator } from "@/components/ui/separator"; import { Spinner } from "@/components/ui/spinner"; import { getToolDisplayName } from "@/contracts/enums/toolIcons"; -import { - markActionRevertedInCache, - useAgentActionsQuery, -} from "@/hooks/use-agent-actions-query"; +import { markActionRevertedInCache, useAgentActionsQuery } from "@/hooks/use-agent-actions-query"; import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; import { AppError } from "@/lib/error"; import { isInterruptResult } from "@/lib/hitl"; @@ -124,8 +118,7 @@ function ToolCardRevertButton({ // Tier 1 + 2: O(1) Map-backed direct id match. Covers // ~all parity_v2 streams and any legacy stream that backfilled // ``langchainToolCallId`` via ``tool-output-available``. - const direct = - findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId); + const direct = findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId); if (direct) return direct; // Tier 3: position-within-turn fallback. Only kicks in when the // card has a synthetic ``call_<run_id>`` id AND no @@ -160,12 +153,7 @@ function ToolCardRevertButton({ setIsReverting(true); try { const response = await agentActionsApiService.revert(threadId, action.id); - markActionRevertedInCache( - queryClient, - threadId, - action.id, - response.new_action_id ?? null - ); + markActionRevertedInCache(queryClient, threadId, action.id, response.new_action_id ?? null); toast.success(response.message || "Action reverted."); } catch (err) { // 503 means revert is gated off on this deployment — hide the diff --git a/surfsense_web/components/citations/citation-renderer.tsx b/surfsense_web/components/citations/citation-renderer.tsx index bf877f03f..f2de4b27d 100644 --- a/surfsense_web/components/citations/citation-renderer.tsx +++ b/surfsense_web/components/citations/citation-renderer.tsx @@ -64,9 +64,7 @@ export function processChildrenWithCitations( return ( <span key={`citation-seg-${childIndex}`}> {segments.map((segment) => - typeof segment === "string" - ? segment - : renderCitationToken(segment, ordinal++) + typeof segment === "string" ? segment : renderCitationToken(segment, ordinal++) )} </span> ); diff --git a/surfsense_web/components/editor/plugins/citation-kit.tsx b/surfsense_web/components/editor/plugins/citation-kit.tsx index c90cb5e28..1908de209 100644 --- a/surfsense_web/components/editor/plugins/citation-kit.tsx +++ b/surfsense_web/components/editor/plugins/citation-kit.tsx @@ -1,8 +1,8 @@ "use client"; -import { type FC } from "react"; -import { KEYS, type Descendant } from "platejs"; +import { type Descendant, KEYS } from "platejs"; import { createPlatePlugin, type PlateElementProps } from "platejs/react"; +import type { FC } from "react"; import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation"; import { CITATION_REGEX, @@ -97,11 +97,7 @@ function asElement(node: Descendant): SlateElement { * swallows the citation click. Mirrors the `<a>` skip in * `MarkdownViewer`. */ -const SKIP_SUBTREE_TYPES = new Set<string>([ - KEYS.codeBlock, - "code_line", - KEYS.link, -]); +const SKIP_SUBTREE_TYPES = new Set<string>([KEYS.codeBlock, "code_line", KEYS.link]); /** * Build the marks portion of a Slate text node so we can preserve formatting diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index 3efdab03b..afd888f48 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -26,9 +26,9 @@ import { type Tab, } from "@/atoms/tabs/tabs.atom"; import { currentUserAtom } from "@/atoms/user/user-query.atoms"; +import { ActionLogSheet } from "@/components/agent-action-log/action-log-sheet"; import { SearchSpaceSettingsDialog } from "@/components/settings/search-space-settings-dialog"; import { TeamDialog } from "@/components/settings/team-dialog"; -import { ActionLogSheet } from "@/components/agent-action-log/action-log-sheet"; import { UserSettingsDialog } from "@/components/settings/user-settings-dialog"; import { AlertDialog, diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index d20aea2cd..bf4de6454 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -23,9 +23,7 @@ import { useTranslations } from "next-intl"; import type React from "react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; -import { - mentionedDocumentsAtom, -} from "@/atoms/chat/mentioned-documents.atom"; +import { mentionedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom"; import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms"; import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; import { deleteDocumentMutationAtom } from "@/atoms/documents/document-mutation.atoms"; @@ -74,12 +72,12 @@ import type { DocumentTypeEnum } from "@/contracts/types/document.types"; import { useDebouncedValue } from "@/hooks/use-debounced-value"; import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI, usePlatform } from "@/hooks/use-platform"; -import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { foldersApiService } from "@/lib/apis/folders-api.service"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; +import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { uploadFolderScan } from "@/lib/folder-sync-upload"; import { getSupportedExtensionsSet } from "@/lib/supported-extensions"; import { queries } from "@/zero/queries/index"; diff --git a/surfsense_web/components/markdown-viewer.tsx b/surfsense_web/components/markdown-viewer.tsx index b2420711a..6caf01917 100644 --- a/surfsense_web/components/markdown-viewer.tsx +++ b/surfsense_web/components/markdown-viewer.tsx @@ -5,10 +5,7 @@ import "katex/dist/katex.min.css"; import Image from "next/image"; import { useMemo } from "react"; import { processChildrenWithCitations } from "@/components/citations/citation-renderer"; -import { - type CitationUrlMap, - preprocessCitationMarkdown, -} from "@/lib/citations/citation-parser"; +import { type CitationUrlMap, preprocessCitationMarkdown } from "@/lib/citations/citation-parser"; import { cn } from "@/lib/utils"; const code = createCodePlugin({ diff --git a/surfsense_web/hooks/use-agent-actions-query.ts b/surfsense_web/hooks/use-agent-actions-query.ts index 9a722fb2e..114c79567 100644 --- a/surfsense_web/hooks/use-agent-actions-query.ts +++ b/surfsense_web/hooks/use-agent-actions-query.ts @@ -88,71 +88,68 @@ export function applyActionLogSse( searchSpaceId, event, }); - queryClient.setQueryData<AgentActionListResponse>( - agentActionsQueryKey(threadId), - (prev) => { - const placeholder: AgentAction = { - id: event.id, - thread_id: threadId, - user_id: null, - search_space_id: searchSpaceId, - tool_name: event.tool_name, - args: null, - result_id: null, - reversible: event.reversible, - reverse_descriptor: event.reverse_descriptor_present ? {} : null, - error: event.error ? {} : null, - reverse_of: null, - reverted_by_action_id: null, - is_revert_action: false, - tool_call_id: event.lc_tool_call_id, - chat_turn_id: event.chat_turn_id, - created_at: event.created_at ?? new Date().toISOString(), - }; - if (!prev) { - return { - items: [placeholder], - total: 1, - page: 0, - page_size: ACTION_LOG_PAGE_SIZE, - has_more: false, - }; - } - const existingIdx = prev.items.findIndex((a) => a.id === event.id); - if (existingIdx >= 0) { - const merged = [...prev.items]; - const existing = merged[existingIdx]; - if (existing) { - merged[existingIdx] = { - ...existing, - reversible: event.reversible, - tool_call_id: event.lc_tool_call_id ?? existing.tool_call_id, - chat_turn_id: event.chat_turn_id ?? existing.chat_turn_id, - }; - } - dbg("applyActionLogSse: merged into existing entry", { - id: event.id, - tool_call_id: merged[existingIdx]?.tool_call_id, - reversible: merged[existingIdx]?.reversible, - }); - return { ...prev, items: merged }; - } - dbg("applyActionLogSse: appended new placeholder", { - id: event.id, - tool_call_id: placeholder.tool_call_id, - tool_name: placeholder.tool_name, - reversible: placeholder.reversible, - cacheSizeAfter: prev.items.length + 1, - }); - // REST returns newest-first — keep that ordering when - // the server eventually refetches by prepending. + queryClient.setQueryData<AgentActionListResponse>(agentActionsQueryKey(threadId), (prev) => { + const placeholder: AgentAction = { + id: event.id, + thread_id: threadId, + user_id: null, + search_space_id: searchSpaceId, + tool_name: event.tool_name, + args: null, + result_id: null, + reversible: event.reversible, + reverse_descriptor: event.reverse_descriptor_present ? {} : null, + error: event.error ? {} : null, + reverse_of: null, + reverted_by_action_id: null, + is_revert_action: false, + tool_call_id: event.lc_tool_call_id, + chat_turn_id: event.chat_turn_id, + created_at: event.created_at ?? new Date().toISOString(), + }; + if (!prev) { return { - ...prev, - items: [placeholder, ...prev.items], - total: prev.total + 1, + items: [placeholder], + total: 1, + page: 0, + page_size: ACTION_LOG_PAGE_SIZE, + has_more: false, }; } - ); + const existingIdx = prev.items.findIndex((a) => a.id === event.id); + if (existingIdx >= 0) { + const merged = [...prev.items]; + const existing = merged[existingIdx]; + if (existing) { + merged[existingIdx] = { + ...existing, + reversible: event.reversible, + tool_call_id: event.lc_tool_call_id ?? existing.tool_call_id, + chat_turn_id: event.chat_turn_id ?? existing.chat_turn_id, + }; + } + dbg("applyActionLogSse: merged into existing entry", { + id: event.id, + tool_call_id: merged[existingIdx]?.tool_call_id, + reversible: merged[existingIdx]?.reversible, + }); + return { ...prev, items: merged }; + } + dbg("applyActionLogSse: appended new placeholder", { + id: event.id, + tool_call_id: placeholder.tool_call_id, + tool_name: placeholder.tool_name, + reversible: placeholder.reversible, + cacheSizeAfter: prev.items.length + 1, + }); + // REST returns newest-first — keep that ordering when + // the server eventually refetches by prepending. + return { + ...prev, + items: [placeholder, ...prev.items], + total: prev.total + 1, + }; + }); } /** @@ -170,33 +167,30 @@ export function applyActionLogUpdatedSse( id, reversible, }); - queryClient.setQueryData<AgentActionListResponse>( - agentActionsQueryKey(threadId), - (prev) => { - if (!prev) { - dbg("applyActionLogUpdatedSse: NO prev cache for thread; flip dropped", { - threadId, - id, - }); - return prev; - } - let mutated = false; - const items = prev.items.map((a) => { - if (a.id !== id) return a; - mutated = true; - return { ...a, reversible }; + queryClient.setQueryData<AgentActionListResponse>(agentActionsQueryKey(threadId), (prev) => { + if (!prev) { + dbg("applyActionLogUpdatedSse: NO prev cache for thread; flip dropped", { + threadId, + id, }); - if (!mutated) { - dbg("applyActionLogUpdatedSse: id not in cache; flip dropped", { - threadId, - id, - cacheSize: prev.items.length, - cacheIds: prev.items.map((a) => a.id), - }); - } - return mutated ? { ...prev, items } : prev; + return prev; } - ); + let mutated = false; + const items = prev.items.map((a) => { + if (a.id !== id) return a; + mutated = true; + return { ...a, reversible }; + }); + if (!mutated) { + dbg("applyActionLogUpdatedSse: id not in cache; flip dropped", { + threadId, + id, + cacheSize: prev.items.length, + cacheIds: prev.items.map((a) => a.id), + }); + } + return mutated ? { ...prev, items } : prev; + }); } /** @@ -214,24 +208,21 @@ export function markActionRevertedInCache( id: number, newActionId: number | null ): void { - queryClient.setQueryData<AgentActionListResponse>( - agentActionsQueryKey(threadId), - (prev) => { - if (!prev) return prev; - let mutated = false; - const items = prev.items.map((a) => { - if (a.id !== id) return a; - mutated = true; - // ``-1`` is a sentinel meaning "we know it was reverted - // but the server didn't tell us the new row's id". - return { - ...a, - reverted_by_action_id: newActionId ?? -1, - }; - }); - return mutated ? { ...prev, items } : prev; - } - ); + queryClient.setQueryData<AgentActionListResponse>(agentActionsQueryKey(threadId), (prev) => { + if (!prev) return prev; + let mutated = false; + const items = prev.items.map((a) => { + if (a.id !== id) return a; + mutated = true; + // ``-1`` is a sentinel meaning "we know it was reverted + // but the server didn't tell us the new row's id". + return { + ...a, + reverted_by_action_id: newActionId ?? -1, + }; + }); + return mutated ? { ...prev, items } : prev; + }); } /** @@ -245,21 +236,18 @@ export function applyRevertTurnResultsToCache( entries: Array<{ id: number; newActionId: number | null }> ): void { if (entries.length === 0) return; - queryClient.setQueryData<AgentActionListResponse>( - agentActionsQueryKey(threadId), - (prev) => { - if (!prev) return prev; - const lookup = new Map(entries.map((e) => [e.id, e.newActionId])); - let mutated = false; - const items = prev.items.map((a) => { - if (!lookup.has(a.id)) return a; - mutated = true; - const newActionId = lookup.get(a.id) ?? null; - return { ...a, reverted_by_action_id: newActionId ?? -1 }; - }); - return mutated ? { ...prev, items } : prev; - } - ); + queryClient.setQueryData<AgentActionListResponse>(agentActionsQueryKey(threadId), (prev) => { + if (!prev) return prev; + const lookup = new Map(entries.map((e) => [e.id, e.newActionId])); + let mutated = false; + const items = prev.items.map((a) => { + if (!lookup.has(a.id)) return a; + mutated = true; + const newActionId = lookup.get(a.id) ?? null; + return { ...a, reverted_by_action_id: newActionId ?? -1 }; + }); + return mutated ? { ...prev, items } : prev; + }); } /** @@ -271,10 +259,7 @@ export function applyRevertTurnResultsToCache( * knob — pass ``false`` to keep the query dormant when the consumer * doesn't yet have a thread id. */ -export function useAgentActionsQuery( - threadId: number | null, - options: { enabled?: boolean } = {} -) { +export function useAgentActionsQuery(threadId: number | null, options: { enabled?: boolean } = {}) { const enabled = (options.enabled ?? true) && threadId !== null; const query = useQuery({ queryKey: agentActionsQueryKey(threadId), @@ -336,10 +321,7 @@ export function useAgentActionsQuery( else m.set(key, [a]); } for (const bucket of m.values()) { - bucket.sort( - (a, b) => - new Date(a.created_at).getTime() - new Date(b.created_at).getTime() - ); + bucket.sort((a, b) => new Date(a.created_at).getTime() - new Date(b.created_at).getTime()); } return m; }, [items]); @@ -396,10 +378,7 @@ export function useAgentActionsQuery( ); const findByChatTurnAndTool = useCallback( - ( - chatTurnId: string | null | undefined, - toolName: string | null | undefined - ): AgentAction[] => { + (chatTurnId: string | null | undefined, toolName: string | null | undefined): AgentAction[] => { if (!chatTurnId || !toolName) return []; return byTurnAndTool.get(`${chatTurnId}::${toolName}`) ?? []; }, diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index 7dfbfc1a1..95d9848f2 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -53,7 +53,10 @@ function getErrorMessage(error: unknown): string { } } -function getErrorCode(error: unknown, parsedJson: Record<string, unknown> | null): string | undefined { +function getErrorCode( + error: unknown, + parsedJson: Record<string, unknown> | null +): string | undefined { if (error instanceof Error) { const withCode = error as Error & { errorCode?: string; code?: string }; if (withCode.errorCode) return withCode.errorCode; @@ -138,8 +141,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError severity: "info", telemetryEvent: "chat_blocked", isExpected: true, - userMessage: - "Buy more tokens to continue with this model, or switch to a free model.", + userMessage: "Buy more tokens to continue with this model, or switch to a free model.", assistantMessage: PREMIUM_QUOTA_ASSISTANT_MESSAGE, rawMessage, errorCode: errorCode ?? "PREMIUM_QUOTA_EXHAUSTED", @@ -147,9 +149,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "TURN_CANCELLING" - ) { + if (errorCode === "TURN_CANCELLING") { return { kind: "thread_busy", channel: "toast", @@ -163,16 +163,15 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "THREAD_BUSY" - ) { + if (errorCode === "THREAD_BUSY") { return { kind: "thread_busy", channel: "toast", severity: "warn", telemetryEvent: "chat_blocked", isExpected: true, - userMessage: "Another response is still finishing for this thread. Please try again in a moment.", + userMessage: + "Another response is still finishing for this thread. Please try again in a moment.", rawMessage, errorCode: errorCode ?? "THREAD_BUSY", details: { flow: input.flow }, @@ -193,10 +192,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "AUTH_EXPIRED" || - errorCode === "UNAUTHORIZED" - ) { + if (errorCode === "AUTH_EXPIRED" || errorCode === "UNAUTHORIZED") { return { kind: "auth_expired", channel: "toast", @@ -210,10 +206,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "RATE_LIMITED" || - providerTypeNormalized === "rate_limit_error" - ) { + if (errorCode === "RATE_LIMITED" || providerTypeNormalized === "rate_limit_error") { return { kind: "rate_limited", channel: "toast", @@ -242,9 +235,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "STREAM_PARSE_ERROR" - ) { + if (errorCode === "STREAM_PARSE_ERROR") { return { kind: "stream_parse_error", channel: "toast", @@ -258,9 +249,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "TOOL_EXECUTION_ERROR" - ) { + if (errorCode === "TOOL_EXECUTION_ERROR") { return { kind: "tool_execution_error", channel: "toast", @@ -274,9 +263,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "PERSIST_MESSAGE_FAILED" - ) { + if (errorCode === "PERSIST_MESSAGE_FAILED") { return { kind: "persist_message_failed", channel: "toast", @@ -290,9 +277,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "SERVER_ERROR" - ) { + if (errorCode === "SERVER_ERROR") { return { kind: "server_error", channel: "toast", diff --git a/surfsense_web/lib/chat/chat-request-errors.ts b/surfsense_web/lib/chat/chat-request-errors.ts index 708831354..e0dfb3cc4 100644 --- a/surfsense_web/lib/chat/chat-request-errors.ts +++ b/surfsense_web/lib/chat/chat-request-errors.ts @@ -74,13 +74,9 @@ export async function toHttpResponseError( : Number.isFinite(retryAfterSeconds) ? Math.max(0, Math.round(retryAfterSeconds * 1000)) : undefined; - const retryAfterMs = - detailRetryAfterMs ?? topRetryAfterMs ?? retryAfterMsFromHeader ?? undefined; + const retryAfterMs = detailRetryAfterMs ?? topRetryAfterMs ?? retryAfterMsFromHeader ?? undefined; const message = - detailNestedMessage ?? - detailMessage ?? - topLevelMessage ?? - `Backend error: ${response.status}`; + detailNestedMessage ?? detailMessage ?? topLevelMessage ?? `Backend error: ${response.status}`; return Object.assign(new Error(message), { errorCode, retryAfterMs }); } diff --git a/surfsense_web/lib/chat/stream-pipeline.ts b/surfsense_web/lib/chat/stream-pipeline.ts index c9118f949..c76781083 100644 --- a/surfsense_web/lib/chat/stream-pipeline.ts +++ b/surfsense_web/lib/chat/stream-pipeline.ts @@ -72,8 +72,12 @@ function toStreamTerminalError( }); } -export function processSharedStreamEvent(parsed: SSEEvent, context: SharedStreamEventContext): boolean { - const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = context; +export function processSharedStreamEvent( + parsed: SSEEvent, + context: SharedStreamEventContext +): boolean { + const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = + context; const { contentParts, toolCallIndices } = contentPartsState; switch (parsed.type) { diff --git a/surfsense_web/lib/chat/stream-side-effects.ts b/surfsense_web/lib/chat/stream-side-effects.ts index 9cb349458..5483ff14b 100644 --- a/surfsense_web/lib/chat/stream-side-effects.ts +++ b/surfsense_web/lib/chat/stream-side-effects.ts @@ -16,9 +16,7 @@ export type EditedInterruptAction = { args: Record<string, unknown>; }; -function readInterruptActions( - interruptData: Record<string, unknown> -): InterruptActionRequest[] { +function readInterruptActions(interruptData: Record<string, unknown>): InterruptActionRequest[] { return (interruptData.action_requests ?? []) as InterruptActionRequest[]; } @@ -121,7 +119,5 @@ export function applyTurnIdToAssistantMessageList( assistantMsgId: string, turnId: string ): ThreadMessageLike[] { - return messages.map((m) => - m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, turnId) : m - ); + return messages.map((m) => (m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, turnId) : m)); } diff --git a/surfsense_web/lib/citations/citation-parser.ts b/surfsense_web/lib/citations/citation-parser.ts index 6333b0f97..533c644c2 100644 --- a/surfsense_web/lib/citations/citation-parser.ts +++ b/surfsense_web/lib/citations/citation-parser.ts @@ -40,8 +40,7 @@ export interface PreprocessedCitations { } /** Pattern matching only URL-form citations (used during preprocessing). */ -const URL_CITATION_REGEX = - /[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+)\s*\u200B?[\]】]/g; +const URL_CITATION_REGEX = /[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+)\s*\u200B?[\]】]/g; /** * Replace `[citation:URL]` tokens with `[citation:urlciteN]` placeholders so @@ -82,10 +81,7 @@ export function preprocessCitationMarkdown(content: string): PreprocessedCitatio * tokens to JSX. Negative chunk IDs are forwarded as-is so the consumer * can decide how to render anonymous documents. */ -export function parseTextWithCitations( - text: string, - urlMap: CitationUrlMap -): ParsedSegment[] { +export function parseTextWithCitations(text: string, urlMap: CitationUrlMap): ParsedSegment[] { const segments: ParsedSegment[] = []; let lastIndex = 0; let match: RegExpExecArray | null; diff --git a/surfsense_web/lib/posthog/events.ts b/surfsense_web/lib/posthog/events.ts index 30e58215a..f9eb6b312 100644 --- a/surfsense_web/lib/posthog/events.ts +++ b/surfsense_web/lib/posthog/events.ts @@ -1,6 +1,6 @@ import posthog from "posthog-js"; import { getConnectorTelemetryMeta } from "@/components/assistant-ui/connector-popup/constants/connector-constants"; -import type { ChatErrorKind, ChatFlow, ChatErrorSeverity } from "@/lib/chat/chat-error-classifier"; +import type { ChatErrorKind, ChatErrorSeverity, ChatFlow } from "@/lib/chat/chat-error-classifier"; /** * PostHog Analytics Event Definitions From 1efed5e489763a655eba5fa3ed86c2d0dd4fa800 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Thu, 30 Apr 2026 20:28:41 -0700 Subject: [PATCH 264/299] chore: add debug environment variables for macOS codesigning troubleshooting --- .github/workflows/desktop-release.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/desktop-release.yml b/.github/workflows/desktop-release.yml index e356bd3e5..ad1c128bc 100644 --- a/.github/workflows/desktop-release.yml +++ b/.github/workflows/desktop-release.yml @@ -144,6 +144,11 @@ jobs: APPLE_ID: ${{ secrets.APPLE_ID }} APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }} APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} + # TEMP DEBUG — remove once the codesign hang on macos-latest is diagnosed. + # Surfaces the exact codesign / notarize commands electron-builder spawns, + # so we can see which subprocess hangs. + DEBUG: electron-builder,electron-osx-sign*,@electron/notarize* + ELECTRON_BUILDER_ALLOW_UNRESOLVED_DEPENDENCIES: "true" # Service principal credentials for Azure.Identity EnvironmentCredential used by the # TrustedSigning PowerShell module. Only populated when signing is enabled. # electron-builder 26 does not yet support OIDC federated tokens for Azure signing, From 360b5f8e3ad8056e6db171a6cb34fe5a7899dee4 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Thu, 30 Apr 2026 20:47:30 -0700 Subject: [PATCH 265/299] chore: update environment variables for improved macOS codesigning debugging --- .github/workflows/notary-status.yml | 60 +++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 .github/workflows/notary-status.yml diff --git a/.github/workflows/notary-status.yml b/.github/workflows/notary-status.yml new file mode 100644 index 000000000..5c7c42038 --- /dev/null +++ b/.github/workflows/notary-status.yml @@ -0,0 +1,60 @@ +name: Notary status check + +# One-off diagnostic workflow. Queries Apple's notary service to see if your +# submissions are queued, in progress, accepted, or rejected. Useful when a +# notarization seems "hung" — most often the queue itself, especially on a +# brand-new Apple Developer account. +# +# Run via: Actions tab -> "Notary status check" -> Run workflow. +# Inputs are optional; if you provide a submission ID, it also fetches that +# submission's full Apple log. +# +# Safe to delete after diagnosis. + +on: + workflow_dispatch: + inputs: + submission_id: + description: 'Optional: submission UUID to fetch full Apple log for' + required: false + default: '' + +jobs: + status: + runs-on: macos-latest + steps: + - name: List recent notarization submissions + env: + APPLE_ID: ${{ secrets.APPLE_ID }} + APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }} + APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} + run: | + set -euo pipefail + echo "::group::Submission history (most recent first)" + xcrun notarytool history \ + --apple-id "$APPLE_ID" \ + --password "$APPLE_APP_SPECIFIC_PASSWORD" \ + --team-id "$APPLE_TEAM_ID" + echo "::endgroup::" + + - name: Inspect specific submission (if id provided) + if: ${{ inputs.submission_id != '' }} + env: + APPLE_ID: ${{ secrets.APPLE_ID }} + APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }} + APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} + SUBMISSION_ID: ${{ inputs.submission_id }} + run: | + set -euo pipefail + echo "::group::Submission info" + xcrun notarytool info "$SUBMISSION_ID" \ + --apple-id "$APPLE_ID" \ + --password "$APPLE_APP_SPECIFIC_PASSWORD" \ + --team-id "$APPLE_TEAM_ID" + echo "::endgroup::" + echo "::group::Apple's processing log for this submission" + xcrun notarytool log "$SUBMISSION_ID" \ + --apple-id "$APPLE_ID" \ + --password "$APPLE_APP_SPECIFIC_PASSWORD" \ + --team-id "$APPLE_TEAM_ID" || true + echo "::endgroup::" From e57c3a7d0c0f4f1fbe29382a97635fb01e5db44a Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Fri, 1 May 2026 05:10:53 -0700 Subject: [PATCH 266/299] feat: prompt caching - Updated `litellm` dependency version from `1.83.4` to `1.83.7`. - Adjusted `aiohttp` version from `3.13.5` to `3.13.4` in the lock file. - Implemented `apply_litellm_prompt_caching` in `chat_deepagent.py` to improve prompt caching. - Added model name resolution logic in `chat_deepagent.py` to ensure correct provider-variant dispatch. - Enhanced `llm_config.py` to configure prompt caching for various LLM providers. - Updated tests to verify correct model name forwarding and prompt caching behavior. --- .../app/agents/new_chat/chat_deepagent.py | 60 ++- .../app/agents/new_chat/llm_config.py | 22 +- .../app/agents/new_chat/prompt_caching.py | 166 +++++++++ .../app/services/llm_router_service.py | 35 +- surfsense_backend/pyproject.toml | 2 +- .../agents/new_chat/prompts/test_composer.py | 25 ++ .../agents/new_chat/test_prompt_caching.py | 350 ++++++++++++++++++ .../test_resolve_prompt_model_name.py | 117 ++++++ .../unit/test_stream_new_chat_contract.py | 36 ++ surfsense_backend/uv.lock | 160 ++++---- .../components/pricing/pricing-section.tsx | 1 - .../settings/more-pages-content.tsx | 59 +-- 12 files changed, 877 insertions(+), 156 deletions(-) create mode 100644 surfsense_backend/app/agents/new_chat/prompt_caching.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index fdd72ea92..c0e9a3b96 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -10,7 +10,9 @@ We use ``create_agent`` (from langchain) rather than ``create_deep_agent`` This lets us swap in ``SurfSenseFilesystemMiddleware`` — a customisable subclass of the default ``FilesystemMiddleware`` — while preserving every other behaviour that ``create_deep_agent`` provides (todo-list, subagents, -summarisation, prompt-caching, etc.). +summarisation, etc.). Prompt caching is configured at LLM-build time via +``apply_litellm_prompt_caching`` (LiteLLM-native, multi-provider) rather +than as a middleware. """ import asyncio @@ -33,7 +35,6 @@ from langchain.agents.middleware import ( TodoListMiddleware, ToolCallLimitMiddleware, ) -from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool from langgraph.types import Checkpointer @@ -74,6 +75,7 @@ from app.agents.new_chat.plugin_loader import ( load_allowed_plugin_names_from_env, load_plugin_middlewares, ) +from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching from app.agents.new_chat.subagents import build_specialized_subagents from app.agents.new_chat.system_prompt import ( build_configurable_system_prompt, @@ -94,6 +96,39 @@ from app.utils.perf import get_perf_logger _perf_log = get_perf_logger() + +def _resolve_prompt_model_name( + agent_config: AgentConfig | None, + llm: BaseChatModel, +) -> str | None: + """Resolve the model id to feed to provider-variant detection. + + Preference order (matches the established idiom in + ``llm_router_service.py`` — see ``params.get("base_model") or + params.get("model", "")`` usages there): + + 1. ``agent_config.litellm_params["base_model"]`` — required for Azure + deployments where ``model_name`` is the deployment slug, not the + underlying family. Without this, a deployment named e.g. + ``"prod-chat-001"`` would silently miss every provider regex. + 2. ``agent_config.model_name`` — the user's configured model id. + 3. ``getattr(llm, "model", None)`` — fallback for direct callers that + don't supply an ``AgentConfig`` (currently a defensive path; all + production callers pass ``agent_config``). + + Returns ``None`` when nothing is available; ``compose_system_prompt`` + treats that as the ``"default"`` variant (no provider block emitted). + """ + if agent_config is not None: + params = agent_config.litellm_params or {} + base_model = params.get("base_model") + if isinstance(base_model, str) and base_model.strip(): + return base_model + if agent_config.model_name: + return agent_config.model_name + return getattr(llm, "model", None) + + # ============================================================================= # Connector Type Mapping # ============================================================================= @@ -279,6 +314,14 @@ async def create_surfsense_deep_agent( ) """ _t_agent_total = time.perf_counter() + + # Layer thread-aware prompt caching onto the LLM. Idempotent with the + # build-time call in ``llm_config.py``; this run merely adds + # ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` for OpenAI-family + # configs now that ``thread_id`` is known. No-op when ``thread_id`` is + # None or the provider is non-OpenAI-family. + apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id) + filesystem_selection = filesystem_selection or FilesystemSelection() backend_resolver = build_backend_resolver( filesystem_selection, @@ -398,6 +441,7 @@ async def create_surfsense_deep_agent( enabled_tool_names=_enabled_tool_names, disabled_tool_names=_user_disabled_tool_names, mcp_connector_tools=_mcp_connector_tools, + model_name=_resolve_prompt_model_name(agent_config, llm), ) else: system_prompt = build_surfsense_system_prompt( @@ -405,6 +449,7 @@ async def create_surfsense_deep_agent( enabled_tool_names=_enabled_tool_names, disabled_tool_names=_user_disabled_tool_names, mcp_connector_tools=_mcp_connector_tools, + model_name=_resolve_prompt_model_name(agent_config, llm), ) _perf_log.info( "[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0 @@ -568,7 +613,6 @@ def _build_compiled_agent_blocking( ), create_surfsense_compaction_middleware(llm, StateBackend), PatchToolCallsMiddleware(), - AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), ] general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key] @@ -1006,12 +1050,12 @@ def _build_compiled_agent_blocking( action_log_mw, PatchToolCallsMiddleware(), DedupHITLToolCallsMiddleware(agent_tools=list(tools)), - # Plugin slot — sits just before AnthropicCache so plugin-side - # transforms see the final tool result and run before any - # caching heuristics. Multiple plugins in declared order; loader - # filtered by the admin allowlist already. + # Plugin slot — sits at the tail so plugin-side transforms see the + # final tool result. Prompt caching is now applied at LLM build time + # via ``apply_litellm_prompt_caching`` (see prompt_caching.py), so no + # caching middleware is needed here. Multiple plugins run in declared + # order; loader filtered by the admin allowlist already. *plugin_middlewares, - AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), ] deepagent_middleware = [m for m in deepagent_middleware if m is not None] diff --git a/surfsense_backend/app/agents/new_chat/llm_config.py b/surfsense_backend/app/agents/new_chat/llm_config.py index 58d8f84d0..99bb719f6 100644 --- a/surfsense_backend/app/agents/new_chat/llm_config.py +++ b/surfsense_backend/app/agents/new_chat/llm_config.py @@ -27,6 +27,7 @@ from litellm import get_model_info from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, @@ -494,6 +495,11 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: llm = SanitizedChatLiteLLM(**litellm_kwargs) _attach_model_profile(llm, model_string) + # Configure LiteLLM-native prompt caching (cache_control_injection_points + # for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.). + # ``agent_config=None`` here — the YAML path doesn't have provider intent + # in a structured form, so we set only the universal injection points. + apply_litellm_prompt_caching(llm) return llm @@ -518,7 +524,16 @@ def create_chat_litellm_from_agent_config( print("Error: Auto mode requested but LLM Router not initialized") return None try: - return get_auto_mode_llm() + router_llm = get_auto_mode_llm() + if router_llm is not None: + # Universal cache_control_injection_points only — auto-mode + # fans out across providers, so OpenAI-only kwargs (e.g. + # ``prompt_cache_key``) are left off here. ``drop_params`` + # would strip them at the provider boundary anyway, but + # there's no point setting them when we don't know the + # destination. + apply_litellm_prompt_caching(router_llm, agent_config=agent_config) + return router_llm except Exception as e: print(f"Error creating ChatLiteLLMRouter: {e}") return None @@ -549,4 +564,9 @@ def create_chat_litellm_from_agent_config( llm = SanitizedChatLiteLLM(**litellm_kwargs) _attach_model_profile(llm, model_string) + # Build-time prompt caching: sets ``cache_control_injection_points`` for + # all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``. + # Per-thread ``prompt_cache_key`` is layered on later in + # ``create_surfsense_deep_agent`` once ``thread_id`` is known. + apply_litellm_prompt_caching(llm, agent_config=agent_config) return llm diff --git a/surfsense_backend/app/agents/new_chat/prompt_caching.py b/surfsense_backend/app/agents/new_chat/prompt_caching.py new file mode 100644 index 000000000..86bc57725 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompt_caching.py @@ -0,0 +1,166 @@ +"""LiteLLM-native prompt caching configuration for SurfSense agents. + +Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never +activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)`` +gate always failed) with LiteLLM's universal caching mechanism. + +Coverage: + +- Marker-based providers (need ``cache_control`` injection, which LiteLLM + performs automatically when ``cache_control_injection_points`` is set): + ``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``, + ``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/`` + (Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM). +- Auto-cached (LiteLLM strips the marker silently): ``openai/``, + ``deepseek/``, ``xai/`` — these caches automatically for prompts ≥1024 + tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``. + +We inject **two** breakpoints per request: + +- ``role: system`` — pins the SurfSense system prompt (provider variant, + citation rules, tool catalog, KB tree, skills metadata) into the cache. +- ``index: -1`` — pins the latest message so multi-turn savings compound: + Anthropic-family providers use longest-matching-prefix lookup, so turn + N+1 still reads turn N's cache up to the shared prefix. + +For OpenAI-family configs we additionally pass: + +- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that + raises hit rate by sending requests with a shared prefix to the same + backend. +- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default + 5-10 min in-memory cache. + +Safety net: ``litellm.drop_params=True`` is set globally in +``app.services.llm_service`` at module-load time. Any kwarg the destination +provider doesn't recognise is auto-stripped at the provider transformer +layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on +``prompt_cache_key`` etc. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from langchain_core.language_models import BaseChatModel + +if TYPE_CHECKING: + from app.agents.new_chat.llm_config import AgentConfig + +logger = logging.getLogger(__name__) + + +# Two-breakpoint policy: system + latest message. See module docstring for +# rationale. Anthropic limits requests to 4 ``cache_control`` blocks; we +# use 2 here, leaving headroom for Phase-2 tool caching. +_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = ( + {"location": "message", "role": "system"}, + {"location": "message", "index": -1}, +) + +# Providers (uppercase ``AgentConfig.provider`` values) that natively expose +# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and +# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers +# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without +# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU, +# MINIMAX), so we can't infer family from the litellm prefix alone. +_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"}) + + +def _is_router_llm(llm: BaseChatModel) -> bool: + """Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import. + + Importing ``app.services.llm_router_service`` at module-load time would + create a cycle via ``llm_config -> prompt_caching -> llm_router_service``. + Class-name comparison is sufficient since the class is defined in a + single place. + """ + return type(llm).__name__ == "ChatLiteLLMRouter" + + +def _is_openai_family_config(agent_config: AgentConfig | None) -> bool: + """Whether the config targets an OpenAI-style prompt-cache surface. + + Strict — only returns True when the user explicitly chose OPENAI, + DEEPSEEK, or XAI as the provider in their ``NewLLMConfig`` / + ``YAMLConfig``. Auto-mode and custom providers return False because + we can't statically know the destination. + """ + if agent_config is None or not agent_config.provider: + return False + if agent_config.is_auto_mode: + return False + if agent_config.custom_provider: + return False + return agent_config.provider.upper() in _OPENAI_FAMILY_PROVIDERS + + +def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None: + """Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail. + + Initialises the field to ``{}`` when present-but-None on a Pydantic v2 + model. Returns ``None`` if the LLM type doesn't expose a writable + ``model_kwargs`` attribute (caller should treat as no-op). + """ + model_kwargs = getattr(llm, "model_kwargs", None) + if isinstance(model_kwargs, dict): + return model_kwargs + try: + llm.model_kwargs = {} # type: ignore[attr-defined] + except Exception: + return None + refreshed = getattr(llm, "model_kwargs", None) + return refreshed if isinstance(refreshed, dict) else None + + +def apply_litellm_prompt_caching( + llm: BaseChatModel, + *, + agent_config: AgentConfig | None = None, + thread_id: int | None = None, +) -> None: + """Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter. + + Idempotent — values already present in ``llm.model_kwargs`` (e.g. from + ``agent_config.litellm_params`` overrides) are preserved. Mutates + ``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion`` + via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge + in our custom ``ChatLiteLLMRouter``. + + Args: + llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance. + agent_config: Optional ``AgentConfig`` driving provider-specific + behaviour. When omitted (or auto-mode), only the universal + ``cache_control_injection_points`` are set. + thread_id: Optional thread id used to construct a per-thread + ``prompt_cache_key`` for OpenAI-family providers. Caching still + works without it (server-side automatic), but the key improves + backend routing affinity and therefore hit rate. + """ + model_kwargs = _get_or_init_model_kwargs(llm) + if model_kwargs is None: + logger.debug( + "apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping", + type(llm).__name__, + ) + return + + if "cache_control_injection_points" not in model_kwargs: + model_kwargs["cache_control_injection_points"] = [ + dict(point) for point in _DEFAULT_INJECTION_POINTS + ] + + # OpenAI-family extras only when we statically know the destination is + # OpenAI / DeepSeek / xAI. Auto-mode router fans out across providers + # so we can't safely set OpenAI-only kwargs there (drop_params would + # strip them but it's wasteful to set them in the first place). + if _is_router_llm(llm): + return + if not _is_openai_family_config(agent_config): + return + + if thread_id is not None and "prompt_cache_key" not in model_kwargs: + model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}" + if "prompt_cache_retention" not in model_kwargs: + model_kwargs["prompt_cache_retention"] = "24h" diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 4bce79a43..fbd42b458 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -28,6 +28,7 @@ from litellm.exceptions import ( BadRequestError as LiteLLMBadRequestError, ContextWindowExceededError, ) +from pydantic import Field from app.utils.perf import get_perf_logger @@ -573,6 +574,11 @@ class ChatLiteLLMRouter(BaseChatModel): # Public attributes that Pydantic will manage model: str = "auto" streaming: bool = True + # Static kwargs that flow through to ``litellm.completion(...)`` on every + # invocation (e.g. ``cache_control_injection_points`` set by + # ``apply_litellm_prompt_caching``). Per-call ``**kwargs`` from + # ``invoke()`` still take precedence — see ``_generate``/``_astream``. + model_kwargs: dict[str, Any] = Field(default_factory=dict) # Bound tools and tool choice for tool calling _bound_tools: list[dict] | None = None @@ -898,13 +904,16 @@ class ChatLiteLLMRouter(BaseChatModel): logger.warning(f"Failed to convert tool {tool}: {e}") continue - # Create a new instance with tools bound + # Create a new instance with tools bound. Carry through ``model_kwargs`` + # so static settings (e.g. cache_control_injection_points) survive the + # bind_tools rebuild. return ChatLiteLLMRouter( router=self._router, bound_tools=formatted_tools if formatted_tools else None, tool_choice=tool_choice, model=self.model, streaming=self.streaming, + model_kwargs=dict(self.model_kwargs), **kwargs, ) @@ -929,8 +938,10 @@ class ChatLiteLLMRouter(BaseChatModel): formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) - # Add tools if bound - call_kwargs = {**kwargs} + # Merge static model_kwargs (e.g. cache_control_injection_points) under + # per-call kwargs so callers can still override per invocation. Then add + # bound tools. + call_kwargs = {**self.model_kwargs, **kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: @@ -997,8 +1008,10 @@ class ChatLiteLLMRouter(BaseChatModel): formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) - # Add tools if bound - call_kwargs = {**kwargs} + # Merge static model_kwargs (e.g. cache_control_injection_points) under + # per-call kwargs so callers can still override per invocation. Then add + # bound tools. + call_kwargs = {**self.model_kwargs, **kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: @@ -1060,8 +1073,10 @@ class ChatLiteLLMRouter(BaseChatModel): formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) - # Add tools if bound - call_kwargs = {**kwargs} + # Merge static model_kwargs (e.g. cache_control_injection_points) under + # per-call kwargs so callers can still override per invocation. Then add + # bound tools. + call_kwargs = {**self.model_kwargs, **kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: @@ -1110,8 +1125,10 @@ class ChatLiteLLMRouter(BaseChatModel): formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) - # Add tools if bound - call_kwargs = {**kwargs} + # Merge static model_kwargs (e.g. cache_control_injection_points) under + # per-call kwargs so callers can still override per invocation. Then add + # bound tools. + call_kwargs = {**self.model_kwargs, **kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index 131627386..cd683e2e1 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -74,7 +74,7 @@ dependencies = [ "deepagents>=0.4.12", "stripe>=15.0.0", "azure-ai-documentintelligence>=1.0.2", - "litellm>=1.83.4", + "litellm>=1.83.7", "langchain-litellm>=0.6.4", ] diff --git a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py index 397b1c787..36fe04aa2 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py +++ b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py @@ -226,6 +226,31 @@ class TestCompose: # Default block should NOT be present assert "<knowledge_base_only_policy>" not in prompt + def test_provider_hints_render_with_custom_system_instructions( + self, fixed_today: datetime + ) -> None: + """Regression guard for the always-append decision: provider hints + append AFTER a custom system prompt. + + Provider hints are stylistic nudges (parallel tool-call rules, + formatting guidance, etc.) that help the model regardless of + what the system instructions say. Suppressing them when a + custom prompt is set would partially defeat the per-family + prompt machinery. + """ + prompt = compose_system_prompt( + today=fixed_today, + custom_system_instructions="You are a custom assistant.", + model_name="anthropic/claude-3-5-sonnet", + ) + assert "You are a custom assistant." in prompt + assert "<provider_hints>" in prompt + # The custom prompt must come BEFORE the provider hints so the + # user's framing isn't drowned out by the stylistic nudges. + assert prompt.index("You are a custom assistant.") < prompt.index( + "<provider_hints>" + ) + def test_use_default_false_with_no_custom_yields_no_system_block( self, fixed_today: datetime ) -> None: diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py new file mode 100644 index 000000000..5b3a03581 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py @@ -0,0 +1,350 @@ +"""Tests for ``apply_litellm_prompt_caching`` in +:mod:`app.agents.new_chat.prompt_caching`. + +The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which +never activated for our LiteLLM stack) with LiteLLM-native multi-provider +prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to +``litellm.completion(...)``. The tests below pin its public contract: + +1. Always sets BOTH ``role: system`` and ``index: -1`` injection points so + savings compound across multi-turn conversations on Anthropic-family + providers. +2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for + single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic + prompt-cache surface is available). +3. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only — no + OpenAI-only kwargs because the router fans out across providers. +4. Idempotent: user-supplied values in ``model_kwargs`` are preserved. +5. Defensive: LLMs without a writable ``model_kwargs`` are silently + skipped rather than raising. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from app.agents.new_chat.llm_config import AgentConfig +from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Test doubles +# --------------------------------------------------------------------------- + + +class _FakeLLM: + """Stand-in for ``ChatLiteLLM``/``SanitizedChatLiteLLM``. + + The helper only inspects ``getattr(llm, "model_kwargs", None)``, + ``getattr(llm, "model", None)``, and ``type(llm).__name__``. A simple + object suffices — we don't need to spin up real LangChain/LiteLLM + machinery for unit tests of the helper's logic. + """ + + def __init__( + self, + model: str = "openai/gpt-4o", + model_kwargs: dict[str, Any] | None = None, + ) -> None: + self.model = model + self.model_kwargs: dict[str, Any] = dict(model_kwargs) if model_kwargs else {} + + +class ChatLiteLLMRouter: + """Class-name-only impostor of the real router. + + The helper's router gate is ``type(llm).__name__ == "ChatLiteLLMRouter"`` + (a deliberate stringly-typed check to avoid an import cycle with + ``app.services.llm_router_service``). Reusing the same class name here + triggers the same code path without instantiating a real ``Router``. + """ + + def __init__(self) -> None: + self.model = "auto" + self.model_kwargs: dict[str, Any] = {} + + +def _make_cfg(**overrides: Any) -> AgentConfig: + """Build an ``AgentConfig`` with sensible defaults for the helper test.""" + defaults: dict[str, Any] = { + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "k", + } + return AgentConfig(**{**defaults, **overrides}) + + +# --------------------------------------------------------------------------- +# (a) Universal injection points +# --------------------------------------------------------------------------- + + +def test_sets_both_cache_control_injection_points_with_no_config() -> None: + """Bare call (no agent_config, no thread_id) still sets the two + universal breakpoints — these cost nothing on providers that don't + consume them and unlock caching on every supported provider.""" + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm) + + points = llm.model_kwargs["cache_control_injection_points"] + assert {"location": "message", "role": "system"} in points + assert {"location": "message", "index": -1} in points + assert len(points) == 2 + + +def test_injection_points_set_for_anthropic_config() -> None: + """Anthropic-family configs need the marker — verify it lands.""" + cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet") + llm = _FakeLLM(model="anthropic/claude-3-5-sonnet") + + apply_litellm_prompt_caching(llm, agent_config=cfg) + + assert "cache_control_injection_points" in llm.model_kwargs + + +# --------------------------------------------------------------------------- +# (b) Idempotency / user override wins +# --------------------------------------------------------------------------- + + +def test_does_not_overwrite_user_supplied_cache_control_injection_points() -> None: + """Users who set their own injection points (e.g. with ``ttl: "1h"`` + via ``litellm_params``) keep them — the helper merges, never + clobbers.""" + user_points = [ + {"location": "message", "role": "system", "ttl": "1h"}, + ] + llm = _FakeLLM( + model_kwargs={"cache_control_injection_points": user_points}, + ) + + apply_litellm_prompt_caching(llm) + + assert llm.model_kwargs["cache_control_injection_points"] is user_points + + +def test_idempotent_when_called_multiple_times() -> None: + """Build-time + thread-time double-call must be a no-op the second time.""" + cfg = _make_cfg(provider="OPENAI") + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1) + snapshot = { + "cache_control_injection_points": list( + llm.model_kwargs["cache_control_injection_points"] + ), + "prompt_cache_key": llm.model_kwargs["prompt_cache_key"], + "prompt_cache_retention": llm.model_kwargs["prompt_cache_retention"], + } + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1) + + assert ( + llm.model_kwargs["cache_control_injection_points"] + == snapshot["cache_control_injection_points"] + ) + assert llm.model_kwargs["prompt_cache_key"] == snapshot["prompt_cache_key"] + assert ( + llm.model_kwargs["prompt_cache_retention"] == snapshot["prompt_cache_retention"] + ) + + +def test_does_not_overwrite_user_supplied_prompt_cache_key() -> None: + """A pre-set ``prompt_cache_key`` (e.g. tenant-aware override via + ``litellm_params``) wins over our default per-thread key.""" + cfg = _make_cfg(provider="OPENAI") + llm = _FakeLLM(model_kwargs={"prompt_cache_key": "tenant-abc"}) + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert llm.model_kwargs["prompt_cache_key"] == "tenant-abc" + + +# --------------------------------------------------------------------------- +# (c) OpenAI-family extras (OPENAI / DEEPSEEK / XAI) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"]) +def test_sets_openai_family_extras(provider: str) -> None: + """OpenAI-style providers gain ``prompt_cache_key`` (raises hit rate + via routing affinity) and ``prompt_cache_retention="24h"`` (extends + cache TTL beyond the default 5-10 min).""" + cfg = _make_cfg(provider=provider) + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert llm.model_kwargs["prompt_cache_key"] == "surfsense-thread-42" + assert llm.model_kwargs["prompt_cache_retention"] == "24h" + + +def test_skips_prompt_cache_key_when_no_thread_id() -> None: + """Without a thread id we can't construct a per-thread key. Retention + is still useful so we set it (it's free).""" + cfg = _make_cfg(provider="OPENAI") + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=None) + + assert "prompt_cache_key" not in llm.model_kwargs + assert llm.model_kwargs["prompt_cache_retention"] == "24h" + + +@pytest.mark.parametrize( + "provider", + ["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"], +) +def test_no_openai_extras_for_other_providers(provider: str) -> None: + """Non-OpenAI-family providers don't expose ``prompt_cache_key`` — + skip it. ``cache_control_injection_points`` is still set (universal).""" + cfg = _make_cfg(provider=provider) + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert "prompt_cache_key" not in llm.model_kwargs + assert "prompt_cache_retention" not in llm.model_kwargs + assert "cache_control_injection_points" in llm.model_kwargs + + +def test_no_openai_extras_in_auto_mode() -> None: + """Auto-mode fans out across mixed providers — we can't statically + target OpenAI-only kwargs.""" + cfg = AgentConfig.from_auto_mode() + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert "prompt_cache_key" not in llm.model_kwargs + assert "prompt_cache_retention" not in llm.model_kwargs + assert "cache_control_injection_points" in llm.model_kwargs + + +def test_no_openai_extras_for_custom_provider() -> None: + """Custom providers route through arbitrary user-supplied prefixes — + we don't try to infer OpenAI-family compatibility.""" + cfg = _make_cfg(provider="OPENAI", custom_provider="my_proxy") + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert "prompt_cache_key" not in llm.model_kwargs + assert "prompt_cache_retention" not in llm.model_kwargs + + +# --------------------------------------------------------------------------- +# (d) ChatLiteLLMRouter — universal injection points only +# --------------------------------------------------------------------------- + + +def test_router_llm_gets_only_universal_injection_points() -> None: + """Even with an OpenAI-flavoured config, a ``ChatLiteLLMRouter`` must + receive only the universal injection points — its requests dispatch + across provider deployments and OpenAI-only kwargs would be wasted + (or stripped by ``drop_params``) on non-OpenAI legs.""" + router = ChatLiteLLMRouter() + cfg = _make_cfg(provider="OPENAI") + + apply_litellm_prompt_caching(router, agent_config=cfg, thread_id=42) + + assert "cache_control_injection_points" in router.model_kwargs + assert "prompt_cache_key" not in router.model_kwargs + assert "prompt_cache_retention" not in router.model_kwargs + + +# --------------------------------------------------------------------------- +# (e) Defensive paths +# --------------------------------------------------------------------------- + + +def test_handles_llm_with_no_writable_model_kwargs() -> None: + """Some LLM implementations (e.g. fakes / minimal subclasses) don't + expose a writable ``model_kwargs``. The helper must skip silently — + raising would crash the entire LLM build path on a non-critical + optimisation.""" + + class _ImmutableLLM: + # ``__slots__`` blocks attribute creation, so ``setattr`` raises. + __slots__ = ("model",) + + def __init__(self) -> None: + self.model = "openai/gpt-4o" + + llm = _ImmutableLLM() + + apply_litellm_prompt_caching(llm) + + +def test_initialises_missing_model_kwargs_dict() -> None: + """When ``model_kwargs`` is present-but-None (Pydantic v2 default + pattern when no factory is set), the helper initialises it to an + empty dict before mutating.""" + + class _LazyLLM: + def __init__(self) -> None: + self.model = "openai/gpt-4o" + self.model_kwargs: dict[str, Any] | None = None + + llm = _LazyLLM() + + apply_litellm_prompt_caching(llm) + + assert isinstance(llm.model_kwargs, dict) + assert "cache_control_injection_points" in llm.model_kwargs + + +def test_falls_back_to_llm_model_prefix_when_no_agent_config() -> None: + """Direct caller path (e.g. ``create_chat_litellm_from_config`` for + YAML configs without a structured ``AgentConfig``): without + ``agent_config`` the helper sets only the universal injection points + — no OpenAI-family extras even if the prefix says ``openai/``. + Conservative: we'd rather miss the speedup than silently misroute.""" + llm = _FakeLLM(model="openai/gpt-4o") + + apply_litellm_prompt_caching(llm, agent_config=None, thread_id=99) + + assert "cache_control_injection_points" in llm.model_kwargs + assert "prompt_cache_key" not in llm.model_kwargs + assert "prompt_cache_retention" not in llm.model_kwargs + + +# --------------------------------------------------------------------------- +# (f) drop_params safety net (regression guard for #19346) +# --------------------------------------------------------------------------- + + +def test_litellm_drop_params_is_globally_enabled() -> None: + """``litellm.drop_params=True`` is set globally in + :mod:`app.services.llm_service` so any ``prompt_cache_key`` / + ``prompt_cache_retention`` we set on an OpenAI-family config is + auto-stripped if the request later routes to a non-supporting + provider (e.g. via auto-mode router fallback). This test pins that + invariant — losing it would mean Bedrock/Vertex 400s on ``prompt_cache_key``. + """ + import litellm + + import app.services.llm_service # noqa: F401 (side-effect: sets globals) + + assert litellm.drop_params is True + + +# --------------------------------------------------------------------------- +# Regression note: LiteLLM #15696 (multi-content-block last message) +# --------------------------------------------------------------------------- +# +# Before LiteLLM 1.81 a list-form last message ``[block_a, block_b]`` +# would get ``cache_control`` applied to *every* content block instead +# of only the last one — wasting cache breakpoints and triggering 400s +# on Anthropic when it exceeded the 4-breakpoint limit. Fixed in +# https://github.com/BerriAI/litellm/pull/15699. +# +# We pin ``litellm>=1.83.7`` in ``pyproject.toml`` (well past the fix). +# An end-to-end behavioural test would need to run ``litellm.completion`` +# through the Anthropic transformer, which is integration territory and +# better covered by LiteLLM's own test suite. The unit guard here is the +# version pin plus the build-time ``model_kwargs`` shape we verify above. diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py b/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py new file mode 100644 index 000000000..ffe3dbaa4 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py @@ -0,0 +1,117 @@ +"""Tests for ``_resolve_prompt_model_name`` in :mod:`app.agents.new_chat.chat_deepagent`. + +The helper picks the model id fed to ``detect_provider_variant`` so the +right ``<provider_hints>`` block lands in the system prompt. The tests +below pin its preference order: + +1. ``agent_config.litellm_params["base_model"]`` (Azure-correct). +2. ``agent_config.model_name``. +3. ``getattr(llm, "model", None)``. + +Without (1) an Azure deployment named e.g. ``"prod-chat-001"`` would +silently miss every provider regex. +""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.chat_deepagent import _resolve_prompt_model_name +from app.agents.new_chat.llm_config import AgentConfig + +pytestmark = pytest.mark.unit + + +def _make_cfg(**overrides) -> AgentConfig: + """Build an ``AgentConfig`` with sensible defaults for the helper test.""" + defaults = { + "provider": "OPENAI", + "model_name": "x", + "api_key": "k", + } + return AgentConfig(**{**defaults, **overrides}) + + +class _FakeLLM: + """Stand-in for a ``ChatLiteLLM`` / ``ChatLiteLLMRouter`` instance. + + The resolver only reads the ``.model`` attribute via ``getattr``, + matching the established idiom in ``knowledge_search.py`` / + ``stream_new_chat.py`` / ``document_summarizer.py``. + """ + + def __init__(self, model: str | None) -> None: + self.model = model + + +def test_prefers_litellm_params_base_model_over_deployment_name() -> None: + """Azure deployment slug must NOT shadow the underlying model family. + + This is the failure mode the helper exists to prevent: a deployment + named ``"azure/prod-chat-001"`` would not match any provider regex + on its own, but the family ``"gpt-4o"`` lives in + ``litellm_params["base_model"]`` and routes to ``openai_classic``. + """ + cfg = _make_cfg( + model_name="azure/prod-chat-001", + litellm_params={"base_model": "gpt-4o"}, + ) + assert _resolve_prompt_model_name(cfg, _FakeLLM("azure/prod-chat-001")) == "gpt-4o" + + +def test_falls_back_to_model_name_when_litellm_params_is_none() -> None: + cfg = _make_cfg( + model_name="anthropic/claude-3-5-sonnet", + litellm_params=None, + ) + got = _resolve_prompt_model_name(cfg, _FakeLLM("anthropic/claude-3-5-sonnet")) + assert got == "anthropic/claude-3-5-sonnet" + + +def test_handles_litellm_params_without_base_model_key() -> None: + cfg = _make_cfg( + model_name="openai/gpt-4o", + litellm_params={"temperature": 0.5}, + ) + assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o" + + +def test_ignores_blank_base_model() -> None: + """Whitespace-only ``base_model`` must not shadow ``model_name``.""" + cfg = _make_cfg( + model_name="openai/gpt-4o", + litellm_params={"base_model": " "}, + ) + assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o" + + +def test_ignores_non_string_base_model() -> None: + """Defensive: a non-string ``base_model`` should not crash the resolver.""" + cfg = _make_cfg( + model_name="openai/gpt-4o", + litellm_params={"base_model": 42}, + ) + assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o" + + +def test_falls_back_to_llm_model_when_no_agent_config() -> None: + """No ``agent_config`` -> use ``llm.model`` directly. Defensive path + for direct callers; production callers always supply a config.""" + assert ( + _resolve_prompt_model_name(None, _FakeLLM("openai/gpt-4o-mini")) + == "openai/gpt-4o-mini" + ) + + +def test_returns_none_when_nothing_available() -> None: + """``compose_system_prompt`` treats ``None`` as the ``"default"`` + variant and emits no provider block.""" + assert _resolve_prompt_model_name(None, _FakeLLM(None)) is None + + +def test_auto_mode_resolves_to_auto_string() -> None: + """Auto mode -> ``"auto"``. ``detect_provider_variant("auto")`` + returns ``"default"``, which is correct: the child model isn't + known until the LiteLLM Router dispatches.""" + cfg = AgentConfig.from_auto_mode() + assert _resolve_prompt_model_name(cfg, _FakeLLM("auto")) == "auto" diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 5e6ad6abd..5935d73ae 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -372,3 +372,39 @@ def test_turn_status_sse_contract_exists(): assert 'type: "data-turn-status"' in state_source assert 'case "data-turn-status":' in pipeline_source assert "end_turn(str(chat_id))" in stream_source + + +def test_chat_deepagent_forwards_resolved_model_name_to_both_builders(): + """Regression guard: both system-prompt builders in chat_deepagent.py + must receive ``model_name=_resolve_prompt_model_name(...)`` so the + provider-variant dispatch can render the right ``<provider_hints>`` + block. Without this the prompt silently falls back to the empty + ``"default"`` variant — the original bug being fixed. + + This test mirrors :func:`test_stream_error_emission_keeps_machine_error_codes` + in style: it inspects module source text + a regex to enforce the + call-site shape, not just the wrapper layer (the wrappers already + forward ``model_name`` correctly, so testing them would not catch + the actual missed plumbing). + """ + import app.agents.new_chat.chat_deepagent as chat_deepagent_module + + source = inspect.getsource(chat_deepagent_module) + + # Helper itself must be defined. + assert "def _resolve_prompt_model_name(" in source + + # Both builder calls must forward the resolved model name. Match + # across newlines + whitespace because the kwargs are split over + # multiple lines. + pattern = re.compile( + r"build_(?:surfsense|configurable)_system_prompt\([^)]*" + r"model_name=_resolve_prompt_model_name\(", + re.DOTALL, + ) + matches = pattern.findall(source) + assert len(matches) == 2, ( + "Expected both system-prompt builder call sites to forward " + "`model_name=_resolve_prompt_model_name(...)`, found " + f"{len(matches)}" + ) diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock index 209c42a9c..efe670d05 100644 --- a/surfsense_backend/uv.lock +++ b/surfsense_backend/uv.lock @@ -62,7 +62,7 @@ wheels = [ [[package]] name = "aiohttp" -version = "3.13.5" +version = "3.13.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohappyeyeballs" }, @@ -73,76 +73,76 @@ dependencies = [ { name = "propcache" }, { name = "yarl" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/77/9a/152096d4808df8e4268befa55fba462f440f14beab85e8ad9bf990516918/aiohttp-3.13.5.tar.gz", hash = "sha256:9d98cc980ecc96be6eb4c1994ce35d28d8b1f5e5208a23b421187d1209dbb7d1", size = 7858271 } +sdist = { url = "https://files.pythonhosted.org/packages/45/4a/064321452809dae953c1ed6e017504e72551a26b6f5708a5a80e4bf556ff/aiohttp-3.13.4.tar.gz", hash = "sha256:d97a6d09c66087890c2ab5d49069e1e570583f7ac0314ecf98294c1b6aaebd38", size = 7859748 } wheels = [ - { url = "https://files.pythonhosted.org/packages/be/6f/353954c29e7dcce7cf00280a02c75f30e133c00793c7a2ed3776d7b2f426/aiohttp-3.13.5-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:023ecba036ddd840b0b19bf195bfae970083fd7024ce1ac22e9bba90464620e9", size = 748876 }, - { url = "https://files.pythonhosted.org/packages/f5/1b/428a7c64687b3b2e9cd293186695affc0e1e54a445d0361743b231f11066/aiohttp-3.13.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:15c933ad7920b7d9a20de151efcd05a6e38302cbf0e10c9b2acb9a42210a2416", size = 499557 }, - { url = "https://files.pythonhosted.org/packages/29/47/7be41556bfbb6917069d6a6634bb7dd5e163ba445b783a90d40f5ac7e3a7/aiohttp-3.13.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ab2899f9fa2f9f741896ebb6fa07c4c883bfa5c7f2ddd8cf2aafa86fa981b2d2", size = 500258 }, - { url = "https://files.pythonhosted.org/packages/67/84/c9ecc5828cb0b3695856c07c0a6817a99d51e2473400f705275a2b3d9239/aiohttp-3.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a60eaa2d440cd4707696b52e40ed3e2b0f73f65be07fd0ef23b6b539c9c0b0b4", size = 1749199 }, - { url = "https://files.pythonhosted.org/packages/f0/d3/3c6d610e66b495657622edb6ae7c7fd31b2e9086b4ec50b47897ad6042a9/aiohttp-3.13.5-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:55b3bdd3292283295774ab585160c4004f4f2f203946997f49aac032c84649e9", size = 1721013 }, - { url = "https://files.pythonhosted.org/packages/49/a0/24409c12217456df0bae7babe3b014e460b0b38a8e60753d6cb339f6556d/aiohttp-3.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2b2355dc094e5f7d45a7bb262fe7207aa0460b37a0d87027dcf21b5d890e7d5", size = 1781501 }, - { url = "https://files.pythonhosted.org/packages/98/9d/b65ec649adc5bccc008b0957a9a9c691070aeac4e41cea18559fef49958b/aiohttp-3.13.5-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b38765950832f7d728297689ad78f5f2cf79ff82487131c4d26fe6ceecdc5f8e", size = 1878981 }, - { url = "https://files.pythonhosted.org/packages/57/d8/8d44036d7eb7b6a8ec4c5494ea0c8c8b94fbc0ed3991c1a7adf230df03bf/aiohttp-3.13.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b18f31b80d5a33661e08c89e202edabf1986e9b49c42b4504371daeaa11b47c1", size = 1767934 }, - { url = "https://files.pythonhosted.org/packages/31/04/d3f8211f273356f158e3464e9e45484d3fb8c4ce5eb2f6fe9405c3273983/aiohttp-3.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:33add2463dde55c4f2d9635c6ab33ce154e5ecf322bd26d09af95c5f81cfa286", size = 1566671 }, - { url = "https://files.pythonhosted.org/packages/41/db/073e4ebe00b78e2dfcacff734291651729a62953b48933d765dc513bf798/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:327cc432fdf1356fb4fbc6fe833ad4e9f6aacb71a8acaa5f1855e4b25910e4a9", size = 1705219 }, - { url = "https://files.pythonhosted.org/packages/48/45/7dfba71a2f9fd97b15c95c06819de7eb38113d2cdb6319669195a7d64270/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7c35b0bf0b48a70b4cb4fc5d7bed9b932532728e124874355de1a0af8ec4bc88", size = 1743049 }, - { url = "https://files.pythonhosted.org/packages/18/71/901db0061e0f717d226386a7f471bb59b19566f2cae5f0d93874b017271f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:df23d57718f24badef8656c49743e11a89fd6f5358fa8a7b96e728fda2abf7d3", size = 1749557 }, - { url = "https://files.pythonhosted.org/packages/08/d5/41eebd16066e59cd43728fe74bce953d7402f2b4ddfdfef2c0e9f17ca274/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:02e048037a6501a5ec1f6fc9736135aec6eb8a004ce48838cb951c515f32c80b", size = 1558931 }, - { url = "https://files.pythonhosted.org/packages/30/e6/4a799798bf05740e66c3a1161079bda7a3dd8e22ca392481d7a7f9af82a6/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:31cebae8b26f8a615d2b546fee45d5ffb76852ae6450e2a03f42c9102260d6fe", size = 1774125 }, - { url = "https://files.pythonhosted.org/packages/84/63/7749337c90f92bc2cb18f9560d67aa6258c7060d1397d21529b8004fcf6f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:888e78eb5ca55a615d285c3c09a7a91b42e9dd6fc699b166ebd5dee87c9ccf14", size = 1732427 }, - { url = "https://files.pythonhosted.org/packages/98/de/cf2f44ff98d307e72fb97d5f5bbae3bfcb442f0ea9790c0bf5c5c2331404/aiohttp-3.13.5-cp312-cp312-win32.whl", hash = "sha256:8bd3ec6376e68a41f9f95f5ed170e2fcf22d4eb27a1f8cb361d0508f6e0557f3", size = 433534 }, - { url = "https://files.pythonhosted.org/packages/aa/ca/eadf6f9c8fa5e31d40993e3db153fb5ed0b11008ad5d9de98a95045bed84/aiohttp-3.13.5-cp312-cp312-win_amd64.whl", hash = "sha256:110e448e02c729bcebb18c60b9214a87ba33bac4a9fa5e9a5f139938b56c6cb1", size = 460446 }, - { url = "https://files.pythonhosted.org/packages/78/e9/d76bf503005709e390122d34e15256b88f7008e246c4bdbe915cd4f1adce/aiohttp-3.13.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a5029cc80718bbd545123cd8fe5d15025eccaaaace5d0eeec6bd556ad6163d61", size = 742930 }, - { url = "https://files.pythonhosted.org/packages/57/00/4b7b70223deaebd9bb85984d01a764b0d7bd6526fcdc73cca83bcbe7243e/aiohttp-3.13.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4bb6bf5811620003614076bdc807ef3b5e38244f9d25ca5fe888eaccea2a9832", size = 496927 }, - { url = "https://files.pythonhosted.org/packages/9c/f5/0fb20fb49f8efdcdce6cd8127604ad2c503e754a8f139f5e02b01626523f/aiohttp-3.13.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a84792f8631bf5a94e52d9cc881c0b824ab42717165a5579c760b830d9392ac9", size = 497141 }, - { url = "https://files.pythonhosted.org/packages/3b/86/b7c870053e36a94e8951b803cb5b909bfbc9b90ca941527f5fcafbf6b0fa/aiohttp-3.13.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:57653eac22c6a4c13eb22ecf4d673d64a12f266e72785ab1c8b8e5940d0e8090", size = 1732476 }, - { url = "https://files.pythonhosted.org/packages/b5/e5/4e161f84f98d80c03a238671b4136e6530453d65262867d989bbe78244d0/aiohttp-3.13.5-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5e5f7debc7a57af53fdf5c5009f9391d9f4c12867049d509bf7bb164a6e295b", size = 1706507 }, - { url = "https://files.pythonhosted.org/packages/d4/56/ea11a9f01518bd5a2a2fcee869d248c4b8a0cfa0bb13401574fa31adf4d4/aiohttp-3.13.5-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c719f65bebcdf6716f10e9eff80d27567f7892d8988c06de12bbbd39307c6e3a", size = 1773465 }, - { url = "https://files.pythonhosted.org/packages/eb/40/333ca27fb74b0383f17c90570c748f7582501507307350a79d9f9f3c6eb1/aiohttp-3.13.5-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d97f93fdae594d886c5a866636397e2bcab146fd7a132fd6bb9ce182224452f8", size = 1873523 }, - { url = "https://files.pythonhosted.org/packages/f0/d2/e2f77eef1acb7111405433c707dc735e63f67a56e176e72e9e7a2cd3f493/aiohttp-3.13.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3df334e39d4c2f899a914f1dba283c1aadc311790733f705182998c6f7cae665", size = 1754113 }, - { url = "https://files.pythonhosted.org/packages/fb/56/3f653d7f53c89669301ec9e42c95233e2a0c0a6dd051269e6e678db4fdb0/aiohttp-3.13.5-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:fe6970addfea9e5e081401bcbadf865d2b6da045472f58af08427e108d618540", size = 1562351 }, - { url = "https://files.pythonhosted.org/packages/ec/a6/9b3e91eb8ae791cce4ee736da02211c85c6f835f1bdfac0594a8a3b7018c/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7becdf835feff2f4f335d7477f121af787e3504b48b449ff737afb35869ba7bb", size = 1693205 }, - { url = "https://files.pythonhosted.org/packages/98/fc/bfb437a99a2fcebd6b6eaec609571954de2ed424f01c352f4b5504371dd3/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:676e5651705ad5d8a70aeb8eb6936c436d8ebbd56e63436cb7dd9bb36d2a9a46", size = 1730618 }, - { url = "https://files.pythonhosted.org/packages/e4/b6/c8534862126191a034f68153194c389addc285a0f1347d85096d349bbc15/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:9b16c653d38eb1a611cc898c41e76859ca27f119d25b53c12875fd0474ae31a8", size = 1745185 }, - { url = "https://files.pythonhosted.org/packages/0b/93/4ca8ee2ef5236e2707e0fd5fecb10ce214aee1ff4ab307af9c558bda3b37/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:999802d5fa0389f58decd24b537c54aa63c01c3219ce17d1214cbda3c2b22d2d", size = 1557311 }, - { url = "https://files.pythonhosted.org/packages/57/ae/76177b15f18c5f5d094f19901d284025db28eccc5ae374d1d254181d33f4/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:ec707059ee75732b1ba130ed5f9580fe10ff75180c812bc267ded039db5128c6", size = 1773147 }, - { url = "https://files.pythonhosted.org/packages/01/a4/62f05a0a98d88af59d93b7fcac564e5f18f513cb7471696ac286db970d6a/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2d6d44a5b48132053c2f6cd5c8cb14bc67e99a63594e336b0f2af81e94d5530c", size = 1730356 }, - { url = "https://files.pythonhosted.org/packages/e4/85/fc8601f59dfa8c9523808281f2da571f8b4699685f9809a228adcc90838d/aiohttp-3.13.5-cp313-cp313-win32.whl", hash = "sha256:329f292ed14d38a6c4c435e465f48bebb47479fd676a0411936cc371643225cc", size = 432637 }, - { url = "https://files.pythonhosted.org/packages/c0/1b/ac685a8882896acf0f6b31d689e3792199cfe7aba37969fa91da63a7fa27/aiohttp-3.13.5-cp313-cp313-win_amd64.whl", hash = "sha256:69f571de7500e0557801c0b51f4780482c0ec5fe2ac851af5a92cfce1af1cb83", size = 458896 }, - { url = "https://files.pythonhosted.org/packages/5d/ce/46572759afc859e867a5bc8ec3487315869013f59281ce61764f76d879de/aiohttp-3.13.5-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:eb4639f32fd4a9904ab8fb45bf3383ba71137f3d9d4ba25b3b3f3109977c5b8c", size = 745721 }, - { url = "https://files.pythonhosted.org/packages/13/fe/8a2efd7626dbe6049b2ef8ace18ffda8a4dfcbe1bcff3ac30c0c7575c20b/aiohttp-3.13.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:7e5dc4311bd5ac493886c63cbf76ab579dbe4641268e7c74e48e774c74b6f2be", size = 497663 }, - { url = "https://files.pythonhosted.org/packages/9b/91/cc8cc78a111826c54743d88651e1687008133c37e5ee615fee9b57990fac/aiohttp-3.13.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:756c3c304d394977519824449600adaf2be0ccee76d206ee339c5e76b70ded25", size = 499094 }, - { url = "https://files.pythonhosted.org/packages/0a/33/a8362cb15cf16a3af7e86ed11962d5cd7d59b449202dc576cdc731310bde/aiohttp-3.13.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecc26751323224cf8186efcf7fbcbc30f4e1d8c7970659daf25ad995e4032a56", size = 1726701 }, - { url = "https://files.pythonhosted.org/packages/45/0c/c091ac5c3a17114bd76cbf85d674650969ddf93387876cf67f754204bd77/aiohttp-3.13.5-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10a75acfcf794edf9d8db50e5a7ec5fc818b2a8d3f591ce93bc7b1210df016d2", size = 1683360 }, - { url = "https://files.pythonhosted.org/packages/23/73/bcee1c2b79bc275e964d1446c55c54441a461938e70267c86afaae6fba27/aiohttp-3.13.5-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0f7a18f258d124cd678c5fe072fe4432a4d5232b0657fca7c1847f599233c83a", size = 1773023 }, - { url = "https://files.pythonhosted.org/packages/c7/ef/720e639df03004fee2d869f771799d8c23046dec47d5b81e396c7cda583a/aiohttp-3.13.5-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:df6104c009713d3a89621096f3e3e88cc323fd269dbd7c20afe18535094320be", size = 1853795 }, - { url = "https://files.pythonhosted.org/packages/bd/c9/989f4034fb46841208de7aeeac2c6d8300745ab4f28c42f629ba77c2d916/aiohttp-3.13.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:241a94f7de7c0c3b616627aaad530fe2cb620084a8b144d3be7b6ecfe95bae3b", size = 1730405 }, - { url = "https://files.pythonhosted.org/packages/ce/75/ee1fd286ca7dc599d824b5651dad7b3be7ff8d9a7e7b3fe9820d9180f7db/aiohttp-3.13.5-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c974fb66180e58709b6fc402846f13791240d180b74de81d23913abe48e96d94", size = 1558082 }, - { url = "https://files.pythonhosted.org/packages/c3/20/1e9e6650dfc436340116b7aa89ff8cb2bbdf0abc11dfaceaad8f74273a10/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:6e27ea05d184afac78aabbac667450c75e54e35f62238d44463131bd3f96753d", size = 1692346 }, - { url = "https://files.pythonhosted.org/packages/d8/40/8ebc6658d48ea630ac7903912fe0dd4e262f0e16825aa4c833c56c9f1f56/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:a79a6d399cef33a11b6f004c67bb07741d91f2be01b8d712d52c75711b1e07c7", size = 1698891 }, - { url = "https://files.pythonhosted.org/packages/d8/78/ea0ae5ec8ba7a5c10bdd6e318f1ba5e76fcde17db8275188772afc7917a4/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:c632ce9c0b534fbe25b52c974515ed674937c5b99f549a92127c85f771a78772", size = 1742113 }, - { url = "https://files.pythonhosted.org/packages/8a/66/9d308ed71e3f2491be1acb8769d96c6f0c47d92099f3bc9119cada27b357/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:fceedde51fbd67ee2bcc8c0b33d0126cc8b51ef3bbde2f86662bd6d5a6f10ec5", size = 1553088 }, - { url = "https://files.pythonhosted.org/packages/da/a6/6cc25ed8dfc6e00c90f5c6d126a98e2cf28957ad06fa1036bd34b6f24a2c/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:f92995dfec9420bb69ae629abf422e516923ba79ba4403bc750d94fb4a6c68c1", size = 1757976 }, - { url = "https://files.pythonhosted.org/packages/c1/2b/cce5b0ffe0de99c83e5e36d8f828e4161e415660a9f3e58339d07cce3006/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:20ae0ff08b1f2c8788d6fb85afcb798654ae6ba0b747575f8562de738078457b", size = 1712444 }, - { url = "https://files.pythonhosted.org/packages/6c/cf/9e1795b4160c58d29421eafd1a69c6ce351e2f7c8d3c6b7e4ca44aea1a5b/aiohttp-3.13.5-cp314-cp314-win32.whl", hash = "sha256:b20df693de16f42b2472a9c485e1c948ee55524786a0a34345511afdd22246f3", size = 438128 }, - { url = "https://files.pythonhosted.org/packages/22/4d/eaedff67fc805aeba4ba746aec891b4b24cebb1a7d078084b6300f79d063/aiohttp-3.13.5-cp314-cp314-win_amd64.whl", hash = "sha256:f85c6f327bf0b8c29da7d93b1cabb6363fb5e4e160a32fa241ed2dce21b73162", size = 464029 }, - { url = "https://files.pythonhosted.org/packages/79/11/c27d9332ee20d68dd164dc12a6ecdef2e2e35ecc97ed6cf0d2442844624b/aiohttp-3.13.5-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:1efb06900858bb618ff5cee184ae2de5828896c448403d51fb633f09e109be0a", size = 778758 }, - { url = "https://files.pythonhosted.org/packages/04/fb/377aead2e0a3ba5f09b7624f702a964bdf4f08b5b6728a9799830c80041e/aiohttp-3.13.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:fee86b7c4bd29bdaf0d53d14739b08a106fdda809ca5fe032a15f52fae5fe254", size = 512883 }, - { url = "https://files.pythonhosted.org/packages/bb/a6/aa109a33671f7a5d3bd78b46da9d852797c5e665bfda7d6b373f56bff2ec/aiohttp-3.13.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:20058e23909b9e65f9da62b396b77dfa95965cbe840f8def6e572538b1d32e36", size = 516668 }, - { url = "https://files.pythonhosted.org/packages/79/b3/ca078f9f2fa9563c36fb8ef89053ea2bb146d6f792c5104574d49d8acb63/aiohttp-3.13.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cf20a8d6868cb15a73cab329ffc07291ba8c22b1b88176026106ae39aa6df0f", size = 1883461 }, - { url = "https://files.pythonhosted.org/packages/b7/e3/a7ad633ca1ca497b852233a3cce6906a56c3225fb6d9217b5e5e60b7419d/aiohttp-3.13.5-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:330f5da04c987f1d5bdb8ae189137c77139f36bd1cb23779ca1a354a4b027800", size = 1747661 }, - { url = "https://files.pythonhosted.org/packages/33/b9/cd6fe579bed34a906d3d783fe60f2fa297ef55b27bb4538438ee49d4dc41/aiohttp-3.13.5-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6f1cbf0c7926d315c3c26c2da41fd2b5d2fe01ac0e157b78caefc51a782196cf", size = 1863800 }, - { url = "https://files.pythonhosted.org/packages/c0/3f/2c1e2f5144cefa889c8afd5cf431994c32f3b29da9961698ff4e3811b79a/aiohttp-3.13.5-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:53fc049ed6390d05423ba33103ded7281fe897cf97878f369a527070bd95795b", size = 1958382 }, - { url = "https://files.pythonhosted.org/packages/66/1d/f31ec3f1013723b3babe3609e7f119c2c2fb6ef33da90061a705ef3e1bc8/aiohttp-3.13.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:898703aa2667e3c5ca4c54ca36cd73f58b7a38ef87a5606414799ebce4d3fd3a", size = 1803724 }, - { url = "https://files.pythonhosted.org/packages/0e/b4/57712dfc6f1542f067daa81eb61da282fab3e6f1966fca25db06c4fc62d5/aiohttp-3.13.5-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:0494a01ca9584eea1e5fbd6d748e61ecff218c51b576ee1999c23db7066417d8", size = 1640027 }, - { url = "https://files.pythonhosted.org/packages/25/3c/734c878fb43ec083d8e31bf029daae1beafeae582d1b35da234739e82ee7/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:6cf81fe010b8c17b09495cbd15c1d35afbc8fb405c0c9cf4738e5ae3af1d65be", size = 1806644 }, - { url = "https://files.pythonhosted.org/packages/20/a5/f671e5cbec1c21d044ff3078223f949748f3a7f86b14e34a365d74a5d21f/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:c564dd5f09ddc9d8f2c2d0a301cd30a79a2cc1b46dd1a73bef8f0038863d016b", size = 1791630 }, - { url = "https://files.pythonhosted.org/packages/0b/63/fb8d0ad63a0b8a99be97deac8c04dacf0785721c158bdf23d679a87aa99e/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:2994be9f6e51046c4f864598fd9abeb4fba6e88f0b2152422c9666dcd4aea9c6", size = 1809403 }, - { url = "https://files.pythonhosted.org/packages/59/0c/bfed7f30662fcf12206481c2aac57dedee43fe1c49275e85b3a1e1742294/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:157826e2fa245d2ef46c83ea8a5faf77ca19355d278d425c29fda0beb3318037", size = 1634924 }, - { url = "https://files.pythonhosted.org/packages/17/d6/fd518d668a09fd5a3319ae5e984d4d80b9a4b3df4e21c52f02251ef5a32e/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:a8aca50daa9493e9e13c0f566201a9006f080e7c50e5e90d0b06f53146a54500", size = 1836119 }, - { url = "https://files.pythonhosted.org/packages/78/b7/15fb7a9d52e112a25b621c67b69c167805cb1f2ab8f1708a5c490d1b52fe/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3b13560160d07e047a93f23aaa30718606493036253d5430887514715b67c9d9", size = 1772072 }, - { url = "https://files.pythonhosted.org/packages/7e/df/57ba7f0c4a553fc2bd8b6321df236870ec6fd64a2a473a8a13d4f733214e/aiohttp-3.13.5-cp314-cp314t-win32.whl", hash = "sha256:9a0f4474b6ea6818b41f82172d799e4b3d29e22c2c520ce4357856fced9af2f8", size = 471819 }, - { url = "https://files.pythonhosted.org/packages/62/29/2f8418269e46454a26171bfdd6a055d74febf32234e474930f2f60a17145/aiohttp-3.13.5-cp314-cp314t-win_amd64.whl", hash = "sha256:18a2f6c1182c51baa1d28d68fea51513cb2a76612f038853c0ad3c145423d3d9", size = 505441 }, + { url = "https://files.pythonhosted.org/packages/1e/bd/ede278648914cabbabfdf95e436679b5d4156e417896a9b9f4587169e376/aiohttp-3.13.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ee62d4471ce86b108b19c3364db4b91180d13fe3510144872d6bad5401957360", size = 752158 }, + { url = "https://files.pythonhosted.org/packages/90/de/581c053253c07b480b03785196ca5335e3c606a37dc73e95f6527f1591fe/aiohttp-3.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c0fd8f41b54b58636402eb493afd512c23580456f022c1ba2db0f810c959ed0d", size = 501037 }, + { url = "https://files.pythonhosted.org/packages/fa/f9/a5ede193c08f13cc42c0a5b50d1e246ecee9115e4cf6e900d8dbd8fd6acb/aiohttp-3.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4baa48ce49efd82d6b1a0be12d6a36b35e5594d1dd42f8bfba96ea9f8678b88c", size = 501556 }, + { url = "https://files.pythonhosted.org/packages/d6/10/88ff67cd48a6ec36335b63a640abe86135791544863e0cfe1f065d6cef7a/aiohttp-3.13.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d738ebab9f71ee652d9dbd0211057690022201b11197f9a7324fd4dba128aa97", size = 1757314 }, + { url = "https://files.pythonhosted.org/packages/8b/15/fdb90a5cf5a1f52845c276e76298c75fbbcc0ac2b4a86551906d54529965/aiohttp-3.13.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0ce692c3468fa831af7dceed52edf51ac348cebfc8d3feb935927b63bd3e8576", size = 1731819 }, + { url = "https://files.pythonhosted.org/packages/ec/df/28146785a007f7820416be05d4f28cc207493efd1e8c6c1068e9bdc29198/aiohttp-3.13.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8e08abcfe752a454d2cb89ff0c08f2d1ecd057ae3e8cc6d84638de853530ebab", size = 1793279 }, + { url = "https://files.pythonhosted.org/packages/10/47/689c743abf62ea7a77774d5722f220e2c912a77d65d368b884d9779ef41b/aiohttp-3.13.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5977f701b3fff36367a11087f30ea73c212e686d41cd363c50c022d48b011d8d", size = 1891082 }, + { url = "https://files.pythonhosted.org/packages/b0/b6/f7f4f318c7e58c23b761c9b13b9a3c9b394e0f9d5d76fbc6622fa98509f6/aiohttp-3.13.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:54203e10405c06f8b6020bd1e076ae0fe6c194adcee12a5a78af3ffa3c57025e", size = 1773938 }, + { url = "https://files.pythonhosted.org/packages/aa/06/f207cb3121852c989586a6fc16ff854c4fcc8651b86c5d3bd1fc83057650/aiohttp-3.13.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:358a6af0145bc4dda037f13167bef3cce54b132087acc4c295c739d05d16b1c3", size = 1579548 }, + { url = "https://files.pythonhosted.org/packages/6c/58/e1289661a32161e24c1fe479711d783067210d266842523752869cc1d9c2/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:898ea1850656d7d61832ef06aa9846ab3ddb1621b74f46de78fbc5e1a586ba83", size = 1714669 }, + { url = "https://files.pythonhosted.org/packages/96/0a/3e86d039438a74a86e6a948a9119b22540bae037d6ba317a042ae3c22711/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7bc30cceb710cf6a44e9617e43eebb6e3e43ad855a34da7b4b6a73537d8a6763", size = 1754175 }, + { url = "https://files.pythonhosted.org/packages/f4/30/e717fc5df83133ba467a560b6d8ef20197037b4bb5d7075b90037de1018e/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4a31c0c587a8a038f19a4c7e60654a6c899c9de9174593a13e7cc6e15ff271f9", size = 1762049 }, + { url = "https://files.pythonhosted.org/packages/e4/28/8f7a2d4492e336e40005151bdd94baf344880a4707573378579f833a64c1/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:2062f675f3fe6e06d6113eb74a157fb9df58953ffed0cdb4182554b116545758", size = 1570861 }, + { url = "https://files.pythonhosted.org/packages/78/45/12e1a3d0645968b1c38de4b23fdf270b8637735ea057d4f84482ff918ad9/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3d1ba8afb847ff80626d5e408c1fdc99f942acc877d0702fe137015903a220a9", size = 1790003 }, + { url = "https://files.pythonhosted.org/packages/eb/0f/60374e18d590de16dcb39d6ff62f39c096c1b958e6f37727b5870026ea30/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b08149419994cdd4d5eecf7fd4bc5986b5a9380285bcd01ab4c0d6bfca47b79d", size = 1737289 }, + { url = "https://files.pythonhosted.org/packages/02/bf/535e58d886cfbc40a8b0013c974afad24ef7632d645bca0b678b70033a60/aiohttp-3.13.4-cp312-cp312-win32.whl", hash = "sha256:fc432f6a2c4f720180959bc19aa37259651c1a4ed8af8afc84dd41c60f15f791", size = 434185 }, + { url = "https://files.pythonhosted.org/packages/1e/1a/d92e3325134ebfff6f4069f270d3aac770d63320bd1fcd0eca023e74d9a8/aiohttp-3.13.4-cp312-cp312-win_amd64.whl", hash = "sha256:6148c9ae97a3e8bff9a1fc9c757fa164116f86c100468339730e717590a3fb77", size = 461285 }, + { url = "https://files.pythonhosted.org/packages/e3/ac/892f4162df9b115b4758d615f32ec63d00f3084c705ff5526630887b9b42/aiohttp-3.13.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:63dd5e5b1e43b8fb1e91b79b7ceba1feba588b317d1edff385084fcc7a0a4538", size = 745744 }, + { url = "https://files.pythonhosted.org/packages/97/a9/c5b87e4443a2f0ea88cb3000c93a8fdad1ee63bffc9ded8d8c8e0d66efc6/aiohttp-3.13.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:746ac3cc00b5baea424dacddea3ec2c2702f9590de27d837aa67004db1eebc6e", size = 498178 }, + { url = "https://files.pythonhosted.org/packages/94/42/07e1b543a61250783650df13da8ddcdc0d0a5538b2bd15cef6e042aefc61/aiohttp-3.13.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bda8f16ea99d6a6705e5946732e48487a448be874e54a4f73d514660ff7c05d3", size = 498331 }, + { url = "https://files.pythonhosted.org/packages/20/d6/492f46bf0328534124772d0cf58570acae5b286ea25006900650f69dae0e/aiohttp-3.13.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4b061e7b5f840391e3f64d0ddf672973e45c4cfff7a0feea425ea24e51530fc2", size = 1744414 }, + { url = "https://files.pythonhosted.org/packages/e2/4d/e02627b2683f68051246215d2d62b2d2f249ff7a285e7a858dc47d6b6a14/aiohttp-3.13.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:b252e8d5cd66184b570d0d010de742736e8a4fab22c58299772b0c5a466d4b21", size = 1719226 }, + { url = "https://files.pythonhosted.org/packages/7b/6c/5d0a3394dd2b9f9aeba6e1b6065d0439e4b75d41f1fb09a3ec010b43552b/aiohttp-3.13.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:20af8aad61d1803ff11152a26146d8d81c266aa8c5aa9b4504432abb965c36a0", size = 1782110 }, + { url = "https://files.pythonhosted.org/packages/0d/2d/c20791e3437700a7441a7edfb59731150322424f5aadf635602d1d326101/aiohttp-3.13.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:13a5cc924b59859ad2adb1478e31f410a7ed46e92a2a619d6d1dd1a63c1a855e", size = 1884809 }, + { url = "https://files.pythonhosted.org/packages/c8/94/d99dbfbd1924a87ef643833932eb2a3d9e5eee87656efea7d78058539eff/aiohttp-3.13.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:534913dfb0a644d537aebb4123e7d466d94e3be5549205e6a31f72368980a81a", size = 1764938 }, + { url = "https://files.pythonhosted.org/packages/49/61/3ce326a1538781deb89f6cf5e094e2029cd308ed1e21b2ba2278b08426f6/aiohttp-3.13.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:320e40192a2dcc1cf4b5576936e9652981ab596bf81eb309535db7e2f5b5672f", size = 1570697 }, + { url = "https://files.pythonhosted.org/packages/b6/77/4ab5a546857bb3028fbaf34d6eea180267bdab022ee8b1168b1fcde4bfdd/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9e587fcfce2bcf06526a43cb705bdee21ac089096f2e271d75de9c339db3100c", size = 1702258 }, + { url = "https://files.pythonhosted.org/packages/79/63/d8f29021e39bc5af8e5d5e9da1b07976fb9846487a784e11e4f4eeda4666/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:9eb9c2eea7278206b5c6c1441fdd9dc420c278ead3f3b2cc87f9b693698cc500", size = 1740287 }, + { url = "https://files.pythonhosted.org/packages/55/3a/cbc6b3b124859a11bc8055d3682c26999b393531ef926754a3445b99dfef/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:29be00c51972b04bf9d5c8f2d7f7314f48f96070ca40a873a53056e652e805f7", size = 1753011 }, + { url = "https://files.pythonhosted.org/packages/e0/30/836278675205d58c1368b21520eab9572457cf19afd23759216c04483048/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:90c06228a6c3a7c9f776fe4fc0b7ff647fffd3bed93779a6913c804ae00c1073", size = 1566359 }, + { url = "https://files.pythonhosted.org/packages/50/b4/8032cc9b82d17e4277704ba30509eaccb39329dc18d6a35f05e424439e32/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:a533ec132f05fd9a1d959e7f34184cd7d5e8511584848dab85faefbaac573069", size = 1785537 }, + { url = "https://files.pythonhosted.org/packages/17/7d/5873e98230bde59f493bf1f7c3e327486a4b5653fa401144704df5d00211/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1c946f10f413836f82ea4cfb90200d2a59578c549f00857e03111cf45ad01ca5", size = 1740752 }, + { url = "https://files.pythonhosted.org/packages/7b/f2/13e46e0df051494d7d3c68b7f72d071f48c384c12716fc294f75d5b1a064/aiohttp-3.13.4-cp313-cp313-win32.whl", hash = "sha256:48708e2706106da6967eff5908c78ca3943f005ed6bcb75da2a7e4da94ef8c70", size = 433187 }, + { url = "https://files.pythonhosted.org/packages/ea/c0/649856ee655a843c8f8664592cfccb73ac80ede6a8c8db33a25d810c12db/aiohttp-3.13.4-cp313-cp313-win_amd64.whl", hash = "sha256:74a2eb058da44fa3a877a49e2095b591d4913308bb424c418b77beb160c55ce3", size = 459778 }, + { url = "https://files.pythonhosted.org/packages/6d/29/6657cc37ae04cacc2dbf53fb730a06b6091cc4cbe745028e047c53e6d840/aiohttp-3.13.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:e0a2c961fc92abeff61d6444f2ce6ad35bb982db9fc8ff8a47455beacf454a57", size = 749363 }, + { url = "https://files.pythonhosted.org/packages/90/7f/30ccdf67ca3d24b610067dc63d64dcb91e5d88e27667811640644aa4a85d/aiohttp-3.13.4-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:153274535985a0ff2bff1fb6c104ed547cec898a09213d21b0f791a44b14d933", size = 499317 }, + { url = "https://files.pythonhosted.org/packages/93/13/e372dd4e68ad04ee25dafb050c7f98b0d91ea643f7352757e87231102555/aiohttp-3.13.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:351f3171e2458da3d731ce83f9e6b9619e325c45cbd534c7759750cabf453ad7", size = 500477 }, + { url = "https://files.pythonhosted.org/packages/e5/fe/ee6298e8e586096fb6f5eddd31393d8544f33ae0792c71ecbb4c2bef98ac/aiohttp-3.13.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f989ac8bc5595ff761a5ccd32bdb0768a117f36dd1504b1c2c074ed5d3f4df9c", size = 1737227 }, + { url = "https://files.pythonhosted.org/packages/b0/b9/a7a0463a09e1a3fe35100f74324f23644bfc3383ac5fd5effe0722a5f0b7/aiohttp-3.13.4-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d36fc1709110ec1e87a229b201dd3ddc32aa01e98e7868083a794609b081c349", size = 1694036 }, + { url = "https://files.pythonhosted.org/packages/57/7c/8972ae3fb7be00a91aee6b644b2a6a909aedb2c425269a3bfd90115e6f8f/aiohttp-3.13.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:42adaeea83cbdf069ab94f5103ce0787c21fb1a0153270da76b59d5578302329", size = 1786814 }, + { url = "https://files.pythonhosted.org/packages/93/01/c81e97e85c774decbaf0d577de7d848934e8166a3a14ad9f8aa5be329d28/aiohttp-3.13.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:92deb95469928cc41fd4b42a95d8012fa6df93f6b1c0a83af0ffbc4a5e218cde", size = 1866676 }, + { url = "https://files.pythonhosted.org/packages/5a/5f/5b46fe8694a639ddea2cd035bf5729e4677ea882cb251396637e2ef1590d/aiohttp-3.13.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0c0c7c07c4257ef3a1df355f840bc62d133bcdef5c1c5ba75add3c08553e2eed", size = 1740842 }, + { url = "https://files.pythonhosted.org/packages/20/a2/0d4b03d011cca6b6b0acba8433193c1e484efa8d705ea58295590fe24203/aiohttp-3.13.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f062c45de8a1098cb137a1898819796a2491aec4e637a06b03f149315dff4d8f", size = 1566508 }, + { url = "https://files.pythonhosted.org/packages/98/17/e689fd500da52488ec5f889effd6404dece6a59de301e380f3c64f167beb/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:76093107c531517001114f0ebdb4f46858ce818590363e3e99a4a2280334454a", size = 1700569 }, + { url = "https://files.pythonhosted.org/packages/d8/0d/66402894dbcf470ef7db99449e436105ea862c24f7ea4c95c683e635af35/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:6f6ec32162d293b82f8b63a16edc80769662fbd5ae6fbd4936d3206a2c2cc63b", size = 1707407 }, + { url = "https://files.pythonhosted.org/packages/2f/eb/af0ab1a3650092cbd8e14ef29e4ab0209e1460e1c299996c3f8288b3f1ff/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:5903e2db3d202a00ad9f0ec35a122c005e85d90c9836ab4cda628f01edf425e2", size = 1752214 }, + { url = "https://files.pythonhosted.org/packages/5a/bf/72326f8a98e4c666f292f03c385545963cc65e358835d2a7375037a97b57/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2d5bea57be7aca98dbbac8da046d99b5557c5cf4e28538c4c786313078aca09e", size = 1562162 }, + { url = "https://files.pythonhosted.org/packages/67/9f/13b72435f99151dd9a5469c96b3b5f86aa29b7e785ca7f35cf5e538f74c0/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:bcf0c9902085976edc0232b75006ef38f89686901249ce14226b6877f88464fb", size = 1768904 }, + { url = "https://files.pythonhosted.org/packages/18/bc/28d4970e7d5452ac7776cdb5431a1164a0d9cf8bd2fffd67b4fb463aa56d/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c3295f98bfeed2e867cab588f2a146a9db37a85e3ae9062abf46ba062bd29165", size = 1723378 }, + { url = "https://files.pythonhosted.org/packages/53/74/b32458ca1a7f34d65bdee7aef2036adbe0438123d3d53e2b083c453c24dd/aiohttp-3.13.4-cp314-cp314-win32.whl", hash = "sha256:a598a5c5767e1369d8f5b08695cab1d8160040f796c4416af76fd773d229b3c9", size = 438711 }, + { url = "https://files.pythonhosted.org/packages/40/b2/54b487316c2df3e03a8f3435e9636f8a81a42a69d942164830d193beb56a/aiohttp-3.13.4-cp314-cp314-win_amd64.whl", hash = "sha256:c555db4bc7a264bead5a7d63d92d41a1122fcd39cc62a4db815f45ad46f9c2c8", size = 464977 }, + { url = "https://files.pythonhosted.org/packages/47/fb/e41b63c6ce71b07a59243bb8f3b457ee0c3402a619acb9d2c0d21ef0e647/aiohttp-3.13.4-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:45abbbf09a129825d13c18c7d3182fecd46d9da3cfc383756145394013604ac1", size = 781549 }, + { url = "https://files.pythonhosted.org/packages/97/53/532b8d28df1e17e44c4d9a9368b78dcb6bf0b51037522136eced13afa9e8/aiohttp-3.13.4-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:74c80b2bc2c2adb7b3d1941b2b60701ee2af8296fc8aad8b8bc48bc25767266c", size = 514383 }, + { url = "https://files.pythonhosted.org/packages/1b/1f/62e5d400603e8468cd635812d99cb81cfdc08127a3dc474c647615f31339/aiohttp-3.13.4-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c97989ae40a9746650fa196894f317dafc12227c808c774929dda0ff873a5954", size = 518304 }, + { url = "https://files.pythonhosted.org/packages/90/57/2326b37b10896447e3c6e0cbef4fe2486d30913639a5cfd1332b5d870f82/aiohttp-3.13.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dae86be9811493f9990ef44fff1685f5c1a3192e9061a71a109d527944eed551", size = 1893433 }, + { url = "https://files.pythonhosted.org/packages/d2/b4/a24d82112c304afdb650167ef2fe190957d81cbddac7460bedd245f765aa/aiohttp-3.13.4-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:1db491abe852ca2fa6cc48a3341985b0174b3741838e1341b82ac82c8bd9e871", size = 1755901 }, + { url = "https://files.pythonhosted.org/packages/9e/2d/0883ef9d878d7846287f036c162a951968f22aabeef3ac97b0bea6f76d5d/aiohttp-3.13.4-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0e5d701c0aad02a7dce72eef6b93226cf3734330f1a31d69ebbf69f33b86666e", size = 1876093 }, + { url = "https://files.pythonhosted.org/packages/ad/52/9204bb59c014869b71971addad6778f005daa72a96eed652c496789d7468/aiohttp-3.13.4-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8ac32a189081ae0a10ba18993f10f338ec94341f0d5df8fff348043962f3c6f8", size = 1970815 }, + { url = "https://files.pythonhosted.org/packages/d6/b5/e4eb20275a866dde0f570f411b36c6b48f7b53edfe4f4071aa1b0728098a/aiohttp-3.13.4-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:98e968cdaba43e45c73c3f306fca418c8009a957733bac85937c9f9cf3f4de27", size = 1816223 }, + { url = "https://files.pythonhosted.org/packages/d8/23/e98075c5bb146aa61a1239ee1ac7714c85e814838d6cebbe37d3fe19214a/aiohttp-3.13.4-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca114790c9144c335d538852612d3e43ea0f075288f4849cf4b05d6cd2238ce7", size = 1649145 }, + { url = "https://files.pythonhosted.org/packages/d6/c1/7bad8be33bb06c2bb224b6468874346026092762cbec388c3bdb65a368ee/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ea2e071661ba9cfe11eabbc81ac5376eaeb3061f6e72ec4cc86d7cdd1ffbdbbb", size = 1816562 }, + { url = "https://files.pythonhosted.org/packages/5c/10/c00323348695e9a5e316825969c88463dcc24c7e9d443244b8a2c9cf2eae/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:34e89912b6c20e0fd80e07fa401fd218a410aa1ce9f1c2f1dad6db1bd0ce0927", size = 1800333 }, + { url = "https://files.pythonhosted.org/packages/84/43/9b2147a1df3559f49bd723e22905b46a46c068a53adb54abdca32c4de180/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:0e217cf9f6a42908c52b46e42c568bd57adc39c9286ced31aaace614b6087965", size = 1820617 }, + { url = "https://files.pythonhosted.org/packages/a9/7f/b3481a81e7a586d02e99387b18c6dafff41285f6efd3daa2124c01f87eae/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:0c296f1221e21ba979f5ac1964c3b78cfde15c5c5f855ffd2caab337e9cd9182", size = 1643417 }, + { url = "https://files.pythonhosted.org/packages/8f/72/07181226bc99ce1124e0f89280f5221a82d3ae6a6d9d1973ce429d48e52b/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:d99a9d168ebaffb74f36d011750e490085ac418f4db926cce3989c8fe6cb6b1b", size = 1849286 }, + { url = "https://files.pythonhosted.org/packages/1a/e6/1b3566e103eca6da5be4ae6713e112a053725c584e96574caf117568ffef/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:cb19177205d93b881f3f89e6081593676043a6828f59c78c17a0fd6c1fbed2ba", size = 1782635 }, + { url = "https://files.pythonhosted.org/packages/37/58/1b11c71904b8d079eb0c39fe664180dd1e14bebe5608e235d8bfbadc8929/aiohttp-3.13.4-cp314-cp314t-win32.whl", hash = "sha256:c606aa5656dab6552e52ca368e43869c916338346bfaf6304e15c58fb113ea30", size = 472537 }, + { url = "https://files.pythonhosted.org/packages/bc/8f/87c56a1a1977d7dddea5b31e12189665a140fdb48a71e9038ff90bb564ec/aiohttp-3.13.4-cp314-cp314t-win_amd64.whl", hash = "sha256:014dcc10ec8ab8db681f0d68e939d1e9286a5aa2b993cbbdb0db130853e02144", size = 506381 }, ] [[package]] @@ -3723,7 +3723,7 @@ wheels = [ [[package]] name = "litellm" -version = "1.83.4" +version = "1.83.14" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -3739,9 +3739,9 @@ dependencies = [ { name = "tiktoken" }, { name = "tokenizers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/03/c4/30469c06ae7437a4406bc11e3c433cfd380a6771068cca15ea918dcd158f/litellm-1.83.4.tar.gz", hash = "sha256:6458d2030a41229460b321adee00517a91dbd8e63213cc953d355cb41d16f2d4", size = 17733899 } +sdist = { url = "https://files.pythonhosted.org/packages/8d/7c/c095649380adc96c8630273c1768c2ad1e74aa2ee1dd8dd05d218a60569f/litellm-1.83.14.tar.gz", hash = "sha256:24aef9b47cdc424c833e32f3727f411741c690832cd1fe4405e0077144fe09c9", size = 14836599 } wheels = [ - { url = "https://files.pythonhosted.org/packages/b8/bd/df19d3f8f6654535ee343a341fd921f81c411abf601a53e3eaef58129b02/litellm-1.83.4-py3-none-any.whl", hash = "sha256:17d7b4d48d47aca988ea4f762ddda5e7bd72cda3270192b22813d0330869d7b4", size = 16015555 }, + { url = "https://files.pythonhosted.org/packages/7f/5c/1b5691575420135e90578543b2bf219497caa33cfd0af64cb38f30288450/litellm-1.83.14-py3-none-any.whl", hash = "sha256:92b11ba2a32cf80707ddf388d18526696c7999a21b418c5e3b6eda1243d2cfdb", size = 16457054 }, ] [[package]] @@ -5124,7 +5124,7 @@ wheels = [ [[package]] name = "openai" -version = "2.30.0" +version = "2.24.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -5136,9 +5136,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/88/15/52580c8fbc16d0675d516e8749806eda679b16de1e4434ea06fb6feaa610/openai-2.30.0.tar.gz", hash = "sha256:92f7661c990bda4b22a941806c83eabe4896c3094465030dd882a71abe80c885", size = 676084 } +sdist = { url = "https://files.pythonhosted.org/packages/55/13/17e87641b89b74552ed408a92b231283786523edddc95f3545809fab673c/openai-2.24.0.tar.gz", hash = "sha256:1e5769f540dbd01cb33bc4716a23e67b9d695161a734aff9c5f925e2bf99a673", size = 658717 } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/9e/5bfa2270f902d5b92ab7d41ce0475b8630572e71e349b2a4996d14bdda93/openai-2.30.0-py3-none-any.whl", hash = "sha256:9a5ae616888eb2748ec5e0c5b955a51592e0b201a11f4262db920f2a78c5231d", size = 1146656 }, + { url = "https://files.pythonhosted.org/packages/c9/30/844dc675ee6902579b8eef01ed23917cc9319a1c9c0c14ec6e39340c96d0/openai-2.24.0-py3-none-any.whl", hash = "sha256:fed30480d7d6c884303287bde864980a4b137b60553ffbcf9ab4a233b7a73d94", size = 1120122 }, ] [[package]] @@ -6780,11 +6780,11 @@ wheels = [ [[package]] name = "python-dotenv" -version = "1.0.1" +version = "1.2.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115 } +sdist = { url = "https://files.pythonhosted.org/packages/82/ed/0301aeeac3e5353ef3d94b6ec08bbcabd04a72018415dcb29e588514bba8/python_dotenv-1.2.2.tar.gz", hash = "sha256:2c371a91fbd7ba082c2c1dc1f8bf89ca22564a087c2c287cd9b662adde799cf3", size = 50135 } wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 }, + { url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101 }, ] [[package]] @@ -8070,7 +8070,7 @@ requires-dist = [ { name = "langgraph", specifier = ">=1.1.3" }, { name = "langgraph-checkpoint-postgres", specifier = ">=3.0.2" }, { name = "linkup-sdk", specifier = ">=0.2.4" }, - { name = "litellm", specifier = ">=1.83.4" }, + { name = "litellm", specifier = ">=1.83.7" }, { name = "llama-cloud-services", specifier = ">=0.6.25" }, { name = "markdown", specifier = ">=3.7" }, { name = "markdownify", specifier = ">=0.14.1" }, diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx index 416fd8633..175cae4ab 100644 --- a/surfsense_web/components/pricing/pricing-section.tsx +++ b/surfsense_web/components/pricing/pricing-section.tsx @@ -17,7 +17,6 @@ const demoPlans = [ "Self Hostable", "500 pages included to start", "3 million premium tokens to start", - "Earn up to 3,000+ bonus pages for free", "Includes access to OpenAI text, audio and image models", "Realtime Collaborative Group Chats with teammates", "Community support on Discord", diff --git a/surfsense_web/components/settings/more-pages-content.tsx b/surfsense_web/components/settings/more-pages-content.tsx index 944f7418f..8de61b0c7 100644 --- a/surfsense_web/components/settings/more-pages-content.tsx +++ b/surfsense_web/components/settings/more-pages-content.tsx @@ -1,21 +1,14 @@ "use client"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; -import { Check, ExternalLink, Mail } from "lucide-react"; +import { Check, ExternalLink } from "lucide-react"; import Link from "next/link"; import { useParams } from "next/navigation"; -import { useEffect, useState } from "react"; +import { useEffect } from "react"; import { toast } from "sonner"; import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms"; import { Button } from "@/components/ui/button"; import { Card, CardContent } from "@/components/ui/card"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogHeader, - DialogTitle, -} from "@/components/ui/dialog"; import { Separator } from "@/components/ui/separator"; import { Skeleton } from "@/components/ui/skeleton"; import { Spinner } from "@/components/ui/spinner"; @@ -33,7 +26,6 @@ export function MorePagesContent() { const params = useParams(); const queryClient = useQueryClient(); const searchSpaceId = params?.search_space_id ?? ""; - const [claimOpen, setClaimOpen] = useState(false); useEffect(() => { trackIncentivePageViewed(); @@ -79,35 +71,10 @@ export function MorePagesContent() { <div className="text-center"> <h2 className="text-xl font-bold tracking-tight">Get Free Pages</h2> <p className="mt-1 text-sm text-muted-foreground"> - Claim your free page offer and earn bonus pages + Earn bonus pages by completing tasks </p> </div> - {/* 3k free offer */} - <Card className="border-emerald-500/30 bg-emerald-500/5"> - <CardContent className="flex items-center gap-3 p-4"> - <div className="flex h-10 w-10 shrink-0 items-center justify-center rounded-full bg-emerald-600 text-white text-xs font-bold"> - 3k - </div> - <div className="min-w-0 flex-1"> - <p className="text-sm font-semibold">Claim 3,000 Free Pages</p> - <p className="text-xs text-muted-foreground"> - Limited offer. Schedule a meeting or email us to claim. - </p> - </div> - <Button - size="sm" - className="bg-emerald-600 text-white hover:bg-emerald-700" - onClick={() => setClaimOpen(true)} - > - Claim - </Button> - </CardContent> - </Card> - - <Separator /> - - {/* Free tasks */} <div className="space-y-2"> <h3 className="text-sm font-semibold">Earn Bonus Pages</h3> {isLoading ? ( @@ -182,7 +149,6 @@ export function MorePagesContent() { <Separator /> - {/* Link to buy pages */} <div className="text-center"> <p className="text-sm text-muted-foreground">Need more?</p> {pageBuyingEnabled ? ( @@ -197,25 +163,6 @@ export function MorePagesContent() { </p> )} </div> - - {/* Claim 3k dialog */} - <Dialog open={claimOpen} onOpenChange={setClaimOpen}> - <DialogContent className="sm:max-w-md"> - <DialogHeader> - <DialogTitle>Claim 3,000 Free Pages</DialogTitle> - <DialogDescription> - Send us an email to claim your free 3,000 pages. Include your account email and - primary usecase for free pages. - </DialogDescription> - </DialogHeader> - <Button asChild className="w-full gap-2"> - <a href="mailto:rohan@surfsense.com?subject=Claim%203%2C000%20Free%20Pages&body=Hi%2C%20I'd%20like%20to%20claim%20the%203%2C000%20free%20pages%20offer.%0A%0AMy%20account%20email%3A%20"> - <Mail className="h-4 w-4" /> - rohan@surfsense.com - </a> - </Button> - </DialogContent> - </Dialog> </div> ); } From 5dd45a5740156a96018ca560f5f0b91886879830 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 17:41:52 +0530 Subject: [PATCH 267/299] refactor(router): add router_pool_eligible filter and rebuild() API --- .../app/services/llm_router_service.py | 47 ++++++++++++++++++- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 4bce79a43..d624ff56c 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -207,6 +207,12 @@ class LLMRouterService: """ Initialize the router with global LLM configurations. + Configs with ``router_pool_eligible=False`` are skipped so that + dynamic OpenRouter entries stay out of the shared router pool used + by title-gen / sub-agent ``model="auto"`` flows. Those dynamic + entries are still available for user-facing Auto-mode thread pinning + via ``auto_model_pin_service``. + Args: global_configs: List of global LLM config dictionaries from YAML router_settings: Optional router settings (routing_strategy, num_retries, etc.) @@ -220,6 +226,8 @@ class LLMRouterService: model_list = [] premium_models: set[str] = set() for config in global_configs: + if config.get("router_pool_eligible") is False: + continue deployment = cls._config_to_deployment(config) if deployment: model_list.append(deployment) @@ -308,10 +316,45 @@ class LLMRouterService: logger.error(f"Failed to initialize LLM Router: {e}") instance._router = None + @classmethod + def rebuild( + cls, + global_configs: list[dict], + router_settings: dict | None = None, + ) -> None: + """Reset the router and re-run ``initialize`` with fresh configs. + + ``initialize`` short-circuits once it has run to avoid re-creating the + LiteLLM Router on every request; ``rebuild`` deliberately clears + ``_initialized`` so a caller (e.g. background OpenRouter refresh) + can force the pool to be rebuilt after catalogue changes. + """ + instance = cls.get_instance() + instance._initialized = False + instance._router = None + instance._model_list = [] + instance._premium_model_strings = set() + cls.initialize(global_configs, router_settings) + @classmethod def is_premium_model(cls, model_string: str) -> bool: - """Return True if *model_string* (as reported by LiteLLM) belongs to a - premium-tier deployment in the router pool.""" + """Return True if *model_string* belongs to a premium-tier deployment + in the LiteLLM router pool. + + Scope: only covers configs with ``router_pool_eligible`` truthy. That + includes static YAML premium configs AND dynamic OpenRouter *premium* + entries (which opt in at generation time). Dynamic OpenRouter *free* + entries and the virtual ``openrouter/free`` router are deliberately + kept out of the router pool — OpenRouter enforces free-tier limits + globally per account, so per-deployment router accounting can't + represent them correctly — and therefore return ``False`` here, which + matches their ``billing_tier="free"`` (no premium quota). + + For per-request premium checks on an arbitrary config (static or + dynamic, pool or non-pool), read ``agent_config.is_premium`` instead; + that reflects the per-config ``billing_tier`` directly and is what + user-facing Auto-mode thread pinning uses to bill correctly. + """ instance = cls.get_instance() return model_string in instance._premium_model_strings From ccd7caf99f14411dffe5067cd3171357ab690808 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 17:42:21 +0530 Subject: [PATCH 268/299] feat(openrouter): derive billing tier per-model and stabilize config IDs --- .../openrouter_integration_service.py | 191 ++++++++++++++++-- 1 file changed, 173 insertions(+), 18 deletions(-) diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 1245f73aa..2d6a42337 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -11,6 +11,7 @@ this service only manages the catalogue, not the inference path. """ import asyncio +import hashlib import logging import threading from typing import Any @@ -25,6 +26,56 @@ OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" # dynamic OpenRouter entries from hand-written YAML entries during refresh. _OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__" +# Fixed negative ID for the virtual ``openrouter/free`` auto-select entry. +# Chosen to sit far below any reasonable ``id_offset`` so it never collides +# with per-model stable IDs. +_FREE_ROUTER_ID = -9_999_999 + +# Width of the hash space used by ``_stable_config_id``. 9_000_000 provides +# enough headroom to avoid frequent collisions for OpenRouter's catalogue +# (~300 models) while keeping IDs comfortably within Postgres INTEGER range. +_STABLE_ID_HASH_WIDTH = 9_000_000 + + +def _stable_config_id(model_id: str, offset: int, taken: set[int]) -> int: + """Derive a deterministic negative config ID from ``model_id``. + + The same ``model_id`` always hashes to the same base value so thread pins + survive catalogue churn (models appearing/disappearing/reordering between + refreshes). On collision we decrement until we find an unused slot; this + keeps the mapping stable for the first config that claimed a slot and + only shifts collisions, which is much less disruptive than the legacy + index-based scheme that reshuffled every ID when the catalogue changed. + """ + digest = hashlib.blake2b(model_id.encode("utf-8"), digest_size=6).digest() + base = offset - (int.from_bytes(digest, "big") % _STABLE_ID_HASH_WIDTH) + cid = base + while cid in taken: + cid -= 1 + taken.add(cid) + return cid + + +def _openrouter_tier(model: dict) -> str: + """Classify an OpenRouter model as ``"free"`` or ``"premium"``. + + Per OpenRouter's API contract, a model is free if: + - Its id ends with ``:free`` (OpenRouter's own free-variant convention), or + - Both ``pricing.prompt`` and ``pricing.completion`` are zero strings. + + Anything else (missing pricing, non-zero pricing) falls through to + ``"premium"`` so we never under-charge users. This derivation runs off the + already-cached /api/v1/models payload, so it adds no network cost. + """ + if model.get("id", "").endswith(":free"): + return "free" + pricing = model.get("pricing") or {} + prompt = str(pricing.get("prompt", "")).strip() + completion = str(pricing.get("completion", "")).strip() + if prompt == "0" and completion == "0": + return "free" + return "premium" + def _is_text_output_model(model: dict) -> bool: """Return True if the model produces text output only (skip image/audio generators).""" @@ -109,24 +160,77 @@ async def _fetch_models_async() -> list[dict] | None: return None +def _build_free_router_config(settings: dict[str, Any]) -> dict[str, Any]: + """Build the virtual ``openrouter/free`` auto-select config entry. + + This exposes OpenRouter's Free Models Router as a single selectable + option. LiteLLM forwards ``openrouter/openrouter/free`` and OpenRouter + picks a capable free model per request (availability varies, account-wide + rate limit is ~20 req/min). + """ + return { + "id": _FREE_ROUTER_ID, + "name": "OpenRouter Free (Auto-Select)", + "description": ( + "OpenRouter picks a capable free model per request. " + "~20 req/min account-wide; availability varies." + ), + "provider": "OPENROUTER", + "model_name": "openrouter/free", + "api_key": settings.get("api_key", ""), + "api_base": "", + "billing_tier": "free", + "rpm": settings.get("free_rpm", 20), + "tpm": settings.get("free_tpm", 100_000), + "anonymous_enabled": settings.get("anonymous_enabled_free", False), + "seo_enabled": False, + "seo_slug": None, + "quota_reserve_tokens": settings.get("quota_reserve_tokens", 4000), + "litellm_params": dict(settings.get("litellm_params") or {}), + "system_instructions": settings.get("system_instructions", ""), + "use_default_system_instructions": settings.get( + "use_default_system_instructions", True + ), + "citations_enabled": settings.get("citations_enabled", True), + "router_pool_eligible": False, + _OPENROUTER_DYNAMIC_MARKER: True, + } + + def _generate_configs( raw_models: list[dict], settings: dict[str, Any], ) -> list[dict]: - """ - Convert raw OpenRouter model entries into global LLM config dicts. + """Convert raw OpenRouter model entries into global LLM config dicts. - Models are sorted by ID for deterministic, stable ID assignment across - restarts and refreshes. + Tier (``billing_tier``) is derived per-model from OpenRouter's own API + signals via ``_openrouter_tier`` — there is no longer a uniform YAML + override. Config IDs are derived via ``_stable_config_id`` so they + survive catalogue churn across refreshes. + + Router-pool membership is tier-aware: + + - Premium OR models join the LiteLLM router pool (``router_pool_eligible=True``) + so sub-agent ``model="auto"`` flows benefit from load balancing and + failover across the curated YAML configs and the OR premium passthrough. + - Free OR models and the virtual ``openrouter/free`` entry stay excluded + (``router_pool_eligible=False``). LiteLLM Router tracks rate limits per + deployment, but OpenRouter enforces a single global free-tier quota + (~20 RPM + 50-1000 daily requests account-wide across every ``:free`` + model), so rotating across many free deployments would only burn the + shared bucket faster. Free OR models remain fully available for user- + facing Auto-mode thread pinning via ``auto_model_pin_service``. """ id_offset: int = settings.get("id_offset", -10000) api_key: str = settings.get("api_key", "") - billing_tier: str = settings.get("billing_tier", "premium") - anonymous_enabled: bool = settings.get("anonymous_enabled", False) seo_enabled: bool = settings.get("seo_enabled", False) quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000) rpm: int = settings.get("rpm", 200) - tpm: int = settings.get("tpm", 1000000) + tpm: int = settings.get("tpm", 1_000_000) + free_rpm: int = settings.get("free_rpm", 20) + free_tpm: int = settings.get("free_tpm", 100_000) + anon_paid: bool = settings.get("anonymous_enabled_paid", False) + anon_free: bool = settings.get("anonymous_enabled_free", False) litellm_params: dict = settings.get("litellm_params") or {} system_instructions: str = settings.get("system_instructions", "") use_default: bool = settings.get("use_default_system_instructions", True) @@ -142,19 +246,27 @@ def _generate_configs( and _is_allowed_model(m) and "/" in m.get("id", "") ] - text_models.sort(key=lambda m: m["id"]) configs: list[dict] = [] - for idx, model in enumerate(text_models): + + if settings.get("free_router_enabled", True) and api_key: + configs.append(_build_free_router_config(settings)) + + taken: set[int] = set() + if configs: + taken.add(_FREE_ROUTER_ID) + + for model in text_models: model_id: str = model["id"] name: str = model.get("name", model_id) + tier = _openrouter_tier(model) cfg: dict[str, Any] = { - "id": id_offset - idx, + "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter", - "billing_tier": billing_tier, - "anonymous_enabled": anonymous_enabled, + "billing_tier": tier, + "anonymous_enabled": anon_free if tier == "free" else anon_paid, "seo_enabled": seo_enabled, "seo_slug": None, "quota_reserve_tokens": quota_reserve_tokens, @@ -162,12 +274,18 @@ def _generate_configs( "model_name": model_id, "api_key": api_key, "api_base": "", - "rpm": rpm, - "tpm": tpm, + "rpm": free_rpm if tier == "free" else rpm, + "tpm": free_tpm if tier == "free" else tpm, "litellm_params": dict(litellm_params), "system_instructions": system_instructions, "use_default_system_instructions": use_default, "citations_enabled": citations_enabled, + # Premium OR deployments join the LiteLLM router pool so sub-agent + # model="auto" flows can load-balance / fail over across them. + # Free OR deployments stay out: OpenRouter's free tier is a single + # account-wide quota, so per-deployment routing can't spread load + # there — it just drains the shared bucket faster. + "router_pool_eligible": tier == "premium", _OPENROUTER_DYNAMIC_MARKER: True, } configs.append(cfg) @@ -220,11 +338,12 @@ class OpenRouterIntegrationService: self._configs_by_id = {c["id"]: c for c in self._configs} self._initialized = True + tier_counts = self._tier_counts(self._configs) logger.info( - "OpenRouter integration: loaded %d models (IDs %d to %d)", + "OpenRouter integration: loaded %d models (free=%d, premium=%d)", len(self._configs), - self._configs[0]["id"] if self._configs else 0, - self._configs[-1]["id"] if self._configs else 0, + tier_counts["free"], + tier_counts["premium"], ) return self._configs @@ -254,7 +373,43 @@ class OpenRouterIntegrationService: self._configs = new_configs self._configs_by_id = new_by_id - logger.info("OpenRouter refresh: updated to %d models", len(new_configs)) + tier_counts = self._tier_counts(new_configs) + logger.info( + "OpenRouter refresh: updated to %d models (free=%d, premium=%d)", + len(new_configs), + tier_counts["free"], + tier_counts["premium"], + ) + + # Rebuild the LiteLLM router so freshly fetched configs flow through + # (the router filters dynamic OR entries out of its pool, but a + # refresh still needs to pick up any static-config edits and reset + # cached context-window profiles). + try: + from app.config import config as _app_config + from app.services.llm_router_service import LLMRouterService + from app.services.llm_router_service import ( + _router_instance_cache as _chat_router_cache, + ) + + LLMRouterService.rebuild( + _app_config.GLOBAL_LLM_CONFIGS, + getattr(_app_config, "ROUTER_SETTINGS", None), + ) + _chat_router_cache.clear() + except Exception as exc: + logger.warning( + "OpenRouter refresh: router rebuild skipped (%s)", exc + ) + + @staticmethod + def _tier_counts(configs: list[dict]) -> dict[str, int]: + counts = {"free": 0, "premium": 0} + for cfg in configs: + tier = str(cfg.get("billing_tier", "")).lower() + if tier in counts: + counts[tier] += 1 + return counts async def _refresh_loop(self, interval_hours: float) -> None: interval_sec = interval_hours * 3600 From 925c33abd18424d5d0837ccea8ca0288fd5a6c44 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 17:42:44 +0530 Subject: [PATCH 269/299] chore(config): deprecate billing_tier / anonymous_enabled, split anon flags --- surfsense_backend/app/config/__init__.py | 50 ++++++++++++++++--- .../app/config/global_llm_config.example.yaml | 50 ++++++++++++++----- 2 files changed, 81 insertions(+), 19 deletions(-) diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index bd97d2bb1..11cbe24a7 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -194,6 +194,9 @@ def load_openrouter_integration_settings() -> dict | None: """ Load OpenRouter integration settings from the YAML config. + Emits startup warnings for deprecated keys (``billing_tier``, + ``anonymous_enabled``) and seeds their replacements for back-compat. + Returns: dict with settings if present and enabled, None otherwise """ @@ -206,9 +209,31 @@ def load_openrouter_integration_settings() -> dict | None: with open(global_config_file, encoding="utf-8") as f: data = yaml.safe_load(f) settings = data.get("openrouter_integration") - if settings and settings.get("enabled"): - return settings - return None + if not settings or not settings.get("enabled"): + return None + + if "billing_tier" in settings: + print( + "Warning: openrouter_integration.billing_tier is deprecated; " + "tier is now derived per model from OpenRouter data " + "(':free' suffix or zero pricing). Remove this key." + ) + + if "anonymous_enabled" in settings: + print( + "Warning: openrouter_integration.anonymous_enabled is " + "deprecated; use anonymous_enabled_paid and/or " + "anonymous_enabled_free instead. Both new flags have been " + "seeded from the legacy value for back-compat." + ) + settings.setdefault( + "anonymous_enabled_paid", settings["anonymous_enabled"] + ) + settings.setdefault( + "anonymous_enabled_free", settings["anonymous_enabled"] + ) + + return settings except Exception as e: print(f"Warning: Failed to load OpenRouter integration settings: {e}") return None @@ -217,9 +242,14 @@ def load_openrouter_integration_settings() -> dict | None: def initialize_openrouter_integration(): """ If enabled, fetch all OpenRouter models and append them to - config.GLOBAL_LLM_CONFIGS as dynamic premium entries. - Should be called BEFORE initialize_llm_router() so the router - correctly excludes premium models from Auto mode. + config.GLOBAL_LLM_CONFIGS as dynamic entries. Each model's ``billing_tier`` + is derived per-model from OpenRouter's API signals (``:free`` suffix or + zero pricing), so free OpenRouter models correctly skip premium quota. + + Should be called BEFORE initialize_llm_router(). Dynamic entries are + tagged ``router_pool_eligible=False`` so the LiteLLM Router pool (used + by title-gen / sub-agent flows) remains scoped to curated YAML configs, + while user-facing Auto-mode thread pinning still considers them. """ settings = load_openrouter_integration_settings() if not settings: @@ -235,9 +265,15 @@ def initialize_openrouter_integration(): if new_configs: config.GLOBAL_LLM_CONFIGS.extend(new_configs) + free_count = sum( + 1 for c in new_configs if c.get("billing_tier") == "free" + ) + premium_count = sum( + 1 for c in new_configs if c.get("billing_tier") == "premium" + ) print( f"Info: OpenRouter integration added {len(new_configs)} models " - f"(billing_tier={settings.get('billing_tier', 'premium')})" + f"(free={free_count}, premium={premium_count})" ) else: print("Info: OpenRouter integration enabled but no models fetched") diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 9aca0f022..d62b4a4a5 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -245,31 +245,57 @@ global_llm_configs: # ============================================================================= # When enabled, dynamically fetches ALL available models from the OpenRouter API # and injects them as global configs. This gives premium users access to any model -# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota. +# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota, +# while free-tier OpenRouter models show up with a green Free badge and do NOT +# consume premium quota. # Models are fetched at startup and refreshed periodically in the background. # All calls go through LiteLLM with the openrouter/ prefix. openrouter_integration: enabled: false api_key: "sk-or-your-openrouter-api-key" - # billing_tier: "premium" or "free". Controls whether users need premium tokens. - billing_tier: "premium" - # anonymous_enabled: set true to also show OpenRouter models to no-login users - anonymous_enabled: false + + # Tier is derived PER MODEL from OpenRouter's own API signals: + # - id ends with ":free" -> billing_tier=free + # - pricing.prompt AND pricing.completion == "0" -> billing_tier=free + # - otherwise -> billing_tier=premium + # No global billing_tier knob is honored; any legacy value emits a startup warning. + + # Anonymous access is split by tier so operators can expose only free + # models to no-login users without leaking paid inference. + anonymous_enabled_paid: false + anonymous_enabled_free: false + seo_enabled: false # quota_reserve_tokens: tokens reserved per call for quota enforcement quota_reserve_tokens: 4000 - # id_offset: starting negative ID for dynamically generated configs. - # Must not overlap with your static global_llm_configs IDs above. + # id_offset: base negative ID for dynamically generated configs. + # Model IDs are derived deterministically via BLAKE2b so they survive + # catalogue churn. Must not overlap with your static global_llm_configs IDs. id_offset: -10000 # refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only) refresh_interval_hours: 24 - # rpm/tpm: Applied uniformly to all OpenRouter models for LiteLLM Router load balancing. - # OpenRouter doesn't expose per-model rate limits via API; actual throttling is handled - # upstream by OpenRouter itself (your account limits are at https://openrouter.ai/settings/limits). - # These values only matter if you set billing_tier to "free" (adding them to Auto mode). - # For premium-only models they are cosmetic. Set conservatively or match your account tier. + + # Rate limits for PAID OpenRouter models. These are used by LiteLLM Router + # for per-deployment accounting when OR premium models participate in the + # shared sub-agent "auto" pool. They do NOT cap OpenRouter itself — your + # real account limits live at https://openrouter.ai/settings/limits. rpm: 200 tpm: 1000000 + + # Rate limits for FREE OpenRouter models. Informational only: free OR + # models and openrouter/free are intentionally kept OUT of the LiteLLM + # Router pool, because OpenRouter enforces free-tier limits globally per + # account (~20 RPM + 50-1000 daily requests across every ":free" model + # combined) — per-deployment router accounting can't represent a shared + # bucket correctly. Free OR models stay fully available in the model + # selector and for user-facing Auto thread pinning. + free_rpm: 20 + free_tpm: 100000 + + # Expose openrouter/free as a single virtual "Free (Auto-Select)" entry. + # Recommended: keep true. OpenRouter picks a capable free model per request. + free_router_enabled: true + litellm_params: max_tokens: 16384 system_instructions: "" From 2019e90a04149cc491f0513d8c14f498792e2104 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 17:42:54 +0530 Subject: [PATCH 270/299] test(openrouter): cover pool filter, per-model tier, legacy config warnings --- .../services/test_llm_router_pool_filter.py | 215 ++++++++++++++++ .../test_openrouter_integration_service.py | 236 ++++++++++++++++++ .../services/test_openrouter_legacy_config.py | 110 ++++++++ 3 files changed, 561 insertions(+) create mode 100644 surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py create mode 100644 surfsense_backend/tests/unit/services/test_openrouter_integration_service.py create mode 100644 surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py diff --git a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py new file mode 100644 index 000000000..0191025ec --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py @@ -0,0 +1,215 @@ +"""LLMRouterService pool-filter / rebuild tests. + +These tests focus on the *config plumbing* (which configs enter the router +pool, rebuild resets state correctly). They stub out the underlying +``litellm.Router`` so we don't need real API keys or network access. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from app.services.llm_router_service import LLMRouterService + +pytestmark = pytest.mark.unit + + +def _fake_yaml_config( + *, + id: int, + model_name: str, + billing_tier: str = "free", +) -> dict: + return { + "id": id, + "name": f"yaml-{id}", + "provider": "OPENAI", + "model_name": model_name, + "api_key": "sk-test", + "api_base": "", + "billing_tier": billing_tier, + "rpm": 100, + "tpm": 100_000, + "litellm_params": {}, + } + + +def _fake_openrouter_config( + *, + id: int, + model_name: str, + billing_tier: str, + router_pool_eligible: bool | None = None, +) -> dict: + """Build a synthetic dynamic-OR config dict for router-pool tests. + + Defaults mirror Strategy 3: premium OR enters the pool, free OR stays + out. Callers can override ``router_pool_eligible`` to simulate legacy + configs or to regression-test the filter mechanics directly. + """ + if router_pool_eligible is None: + router_pool_eligible = billing_tier == "premium" + return { + "id": id, + "name": f"or-{id}", + "provider": "OPENROUTER", + "model_name": model_name, + "api_key": "sk-or-test", + "api_base": "", + "billing_tier": billing_tier, + "rpm": 20 if billing_tier == "free" else 200, + "tpm": 100_000 if billing_tier == "free" else 1_000_000, + "litellm_params": {}, + "router_pool_eligible": router_pool_eligible, + } + + +def _reset_router_singleton() -> None: + instance = LLMRouterService.get_instance() + instance._initialized = False + instance._router = None + instance._model_list = [] + instance._premium_model_strings = set() + + +def test_router_pool_includes_or_premium_excludes_or_free(): + """Strategy 3: premium OR joins the pool, free OR stays out. + + Dynamic OpenRouter premium entries opt into load balancing alongside + curated YAML configs. Dynamic OR free entries are intentionally kept + out because OpenRouter's free tier enforces a single account-global + quota bucket that per-deployment router accounting can't represent. + """ + _reset_router_singleton() + configs = [ + _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"), + _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"), + _fake_openrouter_config( + id=-10_001, model_name="openai/gpt-4o", billing_tier="premium" + ), + _fake_openrouter_config( + id=-10_002, + model_name="meta-llama/llama-3.3-70b:free", + billing_tier="free", + ), + ] + + with patch("app.services.llm_router_service.Router") as mock_router, patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb: + mock_ctx_fb.side_effect = lambda ml: (ml, None) + mock_router.return_value = object() + LLMRouterService.initialize(configs) + + pool_models = { + dep["litellm_params"]["model"] + for dep in LLMRouterService.get_instance()._model_list + } + # YAML premium + YAML free + dynamic OR premium are all in the pool. + # Dynamic OR free is NOT (shared-bucket rate limits can't be load-balanced). + assert pool_models == { + "openai/gpt-4o", + "openai/gpt-4o-mini", + "openrouter/openai/gpt-4o", + } + + prem = LLMRouterService.get_instance()._premium_model_strings + # YAML premium is fingerprinted under both its model_string and its + # ``base_model`` form (existing behavior we don't want to regress). + assert "openai/gpt-4o" in prem + # Dynamic OR premium is now fingerprinted as premium so pool-level + # calls through the router are billed against premium quota. + assert "openrouter/openai/gpt-4o" in prem + assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is True + # Dynamic OR free never enters the pool, so it's never counted as premium. + assert LLMRouterService.is_premium_model( + "openrouter/meta-llama/llama-3.3-70b:free" + ) is False + + +def test_router_pool_filter_mechanics_respect_override(): + """The ``router_pool_eligible`` filter itself works independently of tier. + + Regression guard: if a future refactor ever sets the flag False on a + premium config (e.g. for maintenance), that config MUST be skipped by + ``initialize`` even though its tier is premium. + """ + _reset_router_singleton() + configs = [ + _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"), + _fake_openrouter_config( + id=-10_001, + model_name="openai/gpt-4o", + billing_tier="premium", + router_pool_eligible=False, # opt out despite being premium + ), + ] + + with patch("app.services.llm_router_service.Router") as mock_router, patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb: + mock_ctx_fb.side_effect = lambda ml: (ml, None) + mock_router.return_value = object() + LLMRouterService.initialize(configs) + + pool_models = { + dep["litellm_params"]["model"] + for dep in LLMRouterService.get_instance()._model_list + } + assert pool_models == {"openai/gpt-4o"} + assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is False + + +def test_rebuild_refreshes_pool_after_configs_change(): + _reset_router_singleton() + configs_v1 = [ + _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"), + ] + configs_v2 = configs_v1 + [ + _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"), + ] + + with patch("app.services.llm_router_service.Router") as mock_router, patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb: + mock_ctx_fb.side_effect = lambda ml: (ml, None) + mock_router.return_value = object() + + LLMRouterService.initialize(configs_v1) + assert len(LLMRouterService.get_instance()._model_list) == 1 + + # ``initialize`` should be a no-op here (already initialized). + LLMRouterService.initialize(configs_v2) + assert len(LLMRouterService.get_instance()._model_list) == 1 + + # ``rebuild`` must clear the guard and re-run with the new configs. + LLMRouterService.rebuild(configs_v2) + assert len(LLMRouterService.get_instance()._model_list) == 2 + + +def test_auto_model_pin_candidates_include_dynamic_openrouter(): + """Dynamic OR configs must remain Auto-mode thread-pin candidates. + + Guards against a future regression where someone adds the + ``router_pool_eligible`` filter to ``auto_model_pin_service._global_candidates``. + """ + from app.config import config + from app.services.auto_model_pin_service import _global_candidates + + or_premium = _fake_openrouter_config( + id=-10_001, model_name="openai/gpt-4o", billing_tier="premium" + ) + or_free = _fake_openrouter_config( + id=-10_002, + model_name="meta-llama/llama-3.3-70b:free", + billing_tier="free", + ) + original = config.GLOBAL_LLM_CONFIGS + try: + config.GLOBAL_LLM_CONFIGS = [or_premium, or_free] + candidate_ids = {c["id"] for c in _global_candidates()} + assert candidate_ids == {-10_001, -10_002} + finally: + config.GLOBAL_LLM_CONFIGS = original diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py new file mode 100644 index 000000000..618edc23c --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -0,0 +1,236 @@ +"""Unit tests for the dynamic OpenRouter integration.""" + +from __future__ import annotations + +import pytest + +from app.services.openrouter_integration_service import ( + _FREE_ROUTER_ID, + _OPENROUTER_DYNAMIC_MARKER, + _build_free_router_config, + _generate_configs, + _openrouter_tier, + _stable_config_id, +) + +pytestmark = pytest.mark.unit + + +def _minimal_openrouter_model( + *, + model_id: str, + pricing: dict | None = None, + name: str | None = None, +) -> dict: + """Return a synthetic OpenRouter /api/v1/models entry. + + The real API payload includes a lot of fields; we only populate what + ``_generate_configs`` actually inspects (architecture, tool support, + context, pricing, id). + """ + return { + "id": model_id, + "name": name or model_id, + "architecture": {"output_modalities": ["text"]}, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": pricing or {"prompt": "0.000003", "completion": "0.000015"}, + } + + +# --------------------------------------------------------------------------- +# _openrouter_tier +# --------------------------------------------------------------------------- + + +def test_openrouter_tier_free_suffix(): + assert _openrouter_tier({"id": "foo/bar:free"}) == "free" + + +def test_openrouter_tier_zero_pricing(): + model = { + "id": "foo/bar", + "pricing": {"prompt": "0", "completion": "0"}, + } + assert _openrouter_tier(model) == "free" + + +def test_openrouter_tier_paid(): + model = { + "id": "foo/bar", + "pricing": {"prompt": "0.000003", "completion": "0.000015"}, + } + assert _openrouter_tier(model) == "premium" + + +def test_openrouter_tier_missing_pricing_is_premium(): + assert _openrouter_tier({"id": "foo/bar"}) == "premium" + assert _openrouter_tier({"id": "foo/bar", "pricing": {}}) == "premium" + + +# --------------------------------------------------------------------------- +# _stable_config_id +# --------------------------------------------------------------------------- + + +def test_stable_config_id_deterministic(): + taken1: set[int] = set() + taken2: set[int] = set() + a = _stable_config_id("openai/gpt-4o", -10_000, taken1) + b = _stable_config_id("openai/gpt-4o", -10_000, taken2) + assert a == b + assert a < 0 + + +def test_stable_config_id_collision_decrements(): + """When two model_ids hash to the same slot, the second should decrement.""" + taken: set[int] = set() + a = _stable_config_id("openai/gpt-4o", -10_000, taken) + # Force a collision by pre-populating ``taken`` with a slot we know will be + # picked. + taken_forced = {a} + b = _stable_config_id("openai/gpt-4o", -10_000, taken_forced) + assert b != a + assert b == a - 1 + assert b in taken_forced + + +def test_stable_config_id_different_models_different_ids(): + taken: set[int] = set() + ids = { + _stable_config_id("openai/gpt-4o", -10_000, taken), + _stable_config_id("anthropic/claude-3.5-sonnet", -10_000, taken), + _stable_config_id("google/gemini-2.0-flash", -10_000, taken), + } + assert len(ids) == 3 + + +def test_stable_config_id_survives_catalogue_churn(): + """Removing a model should not shift other models' IDs (the bug we fix).""" + taken1: set[int] = set() + id_a1 = _stable_config_id("openai/gpt-4o", -10_000, taken1) + _ = _stable_config_id("anthropic/claude-3-haiku", -10_000, taken1) + id_c1 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken1) + + taken2: set[int] = set() + id_a2 = _stable_config_id("openai/gpt-4o", -10_000, taken2) + id_c2 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken2) + + assert id_a1 == id_a2 + assert id_c1 == id_c2 + + +# --------------------------------------------------------------------------- +# _generate_configs +# --------------------------------------------------------------------------- + + +_SETTINGS_BASE: dict = { + "api_key": "sk-or-test", + "id_offset": -10_000, + "rpm": 200, + "tpm": 1_000_000, + "free_rpm": 20, + "free_tpm": 100_000, + "anonymous_enabled_paid": False, + "anonymous_enabled_free": True, + "quota_reserve_tokens": 4000, + "free_router_enabled": False, +} + + +def test_generate_configs_respects_tier(): + """Premium OR models opt into the router pool; free OR models stay out. + + Strategy-3 split: premium participates in LiteLLM Router load balancing, + free stays excluded because OpenRouter enforces a shared global free-tier + bucket that per-deployment router accounting can't represent. + """ + raw = [ + _minimal_openrouter_model(model_id="openai/gpt-4o"), + _minimal_openrouter_model( + model_id="meta-llama/llama-3.3-70b-instruct:free", + pricing={"prompt": "0", "completion": "0"}, + ), + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + by_model = {c["model_name"]: c for c in cfgs} + + paid = by_model["openai/gpt-4o"] + assert paid["billing_tier"] == "premium" + assert paid["rpm"] == 200 + assert paid["tpm"] == 1_000_000 + assert paid["anonymous_enabled"] is False + assert paid["router_pool_eligible"] is True + assert paid[_OPENROUTER_DYNAMIC_MARKER] is True + + free = by_model["meta-llama/llama-3.3-70b-instruct:free"] + assert free["billing_tier"] == "free" + assert free["rpm"] == 20 + assert free["tpm"] == 100_000 + assert free["anonymous_enabled"] is True + assert free["router_pool_eligible"] is False + + +def test_generate_configs_includes_free_router_when_enabled(): + raw = [_minimal_openrouter_model(model_id="openai/gpt-4o")] + settings = {**_SETTINGS_BASE, "free_router_enabled": True} + cfgs = _generate_configs(raw, settings) + free_router = next( + (c for c in cfgs if c["model_name"] == "openrouter/free"), None + ) + assert free_router is not None + assert free_router["id"] == _FREE_ROUTER_ID + assert free_router["billing_tier"] == "free" + assert free_router["router_pool_eligible"] is False + assert free_router["anonymous_enabled"] is True + + +def test_generate_configs_excludes_free_router_when_disabled(): + raw = [_minimal_openrouter_model(model_id="openai/gpt-4o")] + settings = {**_SETTINGS_BASE, "free_router_enabled": False} + cfgs = _generate_configs(raw, settings) + assert not any(c["model_name"] == "openrouter/free" for c in cfgs) + + +def test_generate_configs_excludes_free_router_without_api_key(): + """Without an API key the free-router entry is useless; skip it.""" + raw = [_minimal_openrouter_model(model_id="openai/gpt-4o")] + settings = {**_SETTINGS_BASE, "free_router_enabled": True, "api_key": ""} + cfgs = _generate_configs(raw, settings) + assert not any(c["model_name"] == "openrouter/free" for c in cfgs) + + +def test_generate_configs_drops_non_text_and_non_tool_models(): + raw = [ + _minimal_openrouter_model(model_id="openai/gpt-4o"), + { # image-output model + "id": "openai/dall-e", + "architecture": {"output_modalities": ["image"]}, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": {"prompt": "0.01", "completion": "0.01"}, + }, + { # text but no tool calling + "id": "openai/completion-only", + "architecture": {"output_modalities": ["text"]}, + "supported_parameters": [], + "context_length": 200_000, + "pricing": {"prompt": "0.01", "completion": "0.01"}, + }, + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + model_names = [c["model_name"] for c in cfgs] + assert "openai/gpt-4o" in model_names + assert "openai/dall-e" not in model_names + assert "openai/completion-only" not in model_names + + +def test_build_free_router_config_shape(): + cfg = _build_free_router_config(dict(_SETTINGS_BASE)) + assert cfg["provider"] == "OPENROUTER" + assert cfg["model_name"] == "openrouter/free" + assert cfg["id"] == _FREE_ROUTER_ID + assert cfg["billing_tier"] == "free" + assert cfg["router_pool_eligible"] is False + assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True diff --git a/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py b/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py new file mode 100644 index 000000000..b3dd2bf18 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py @@ -0,0 +1,110 @@ +"""Tests for deprecated-key warnings and back-compat in +``load_openrouter_integration_settings``. +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +pytestmark = pytest.mark.unit + + +def _write_yaml(tmp_path: Path, body: str) -> Path: + cfg_dir = tmp_path / "app" / "config" + cfg_dir.mkdir(parents=True) + cfg_path = cfg_dir / "global_llm_config.yaml" + cfg_path.write_text(body, encoding="utf-8") + return cfg_path + + +def _patch_base_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + +def test_legacy_billing_tier_emits_warning(monkeypatch, tmp_path, capsys): + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: true + api_key: "sk-or-test" + billing_tier: "premium" +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + settings = load_openrouter_integration_settings() + captured = capsys.readouterr().out + assert settings is not None + assert "billing_tier is deprecated" in captured + + +def test_legacy_anonymous_enabled_back_compat(monkeypatch, tmp_path, capsys): + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: true + api_key: "sk-or-test" + anonymous_enabled: true +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + settings = load_openrouter_integration_settings() + captured = capsys.readouterr().out + assert settings is not None + assert settings["anonymous_enabled_paid"] is True + assert settings["anonymous_enabled_free"] is True + assert "anonymous_enabled is" in captured + assert "deprecated" in captured + + +def test_new_keys_take_priority_over_legacy_back_compat( + monkeypatch, tmp_path, capsys +): + """If both legacy and new keys are present, new keys win (setdefault).""" + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: true + api_key: "sk-or-test" + anonymous_enabled: true + anonymous_enabled_paid: false + anonymous_enabled_free: false +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + settings = load_openrouter_integration_settings() + capsys.readouterr() + assert settings is not None + assert settings["anonymous_enabled_paid"] is False + assert settings["anonymous_enabled_free"] is False + + +def test_disabled_integration_returns_none(monkeypatch, tmp_path): + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: false + api_key: "sk-or-test" +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + assert load_openrouter_integration_settings() is None From 4d34b56c4da4e3a935eaaa1b6cb6321597088802 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 18:09:50 +0530 Subject: [PATCH 271/299] docs(router): drop reference to virtual openrouter/free in is_premium_model --- surfsense_backend/app/services/llm_router_service.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index d624ff56c..060e01675 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -344,11 +344,11 @@ class LLMRouterService: Scope: only covers configs with ``router_pool_eligible`` truthy. That includes static YAML premium configs AND dynamic OpenRouter *premium* entries (which opt in at generation time). Dynamic OpenRouter *free* - entries and the virtual ``openrouter/free`` router are deliberately - kept out of the router pool — OpenRouter enforces free-tier limits - globally per account, so per-deployment router accounting can't - represent them correctly — and therefore return ``False`` here, which - matches their ``billing_tier="free"`` (no premium quota). + entries are deliberately kept out of the router pool — OpenRouter + enforces free-tier limits globally per account, so per-deployment + router accounting can't represent them correctly — and therefore + return ``False`` here, which matches their ``billing_tier="free"`` + (no premium quota). For per-request premium checks on an arbitrary config (static or dynamic, pool or non-pool), read ``agent_config.is_premium`` instead; From 680a1c1c38d090c54f790adbdf35e6beed5d7566 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 18:16:47 +0530 Subject: [PATCH 272/299] refactor(openrouter): remove virtual openrouter/free auto-select entry --- .../app/config/global_llm_config.example.yaml | 16 ++-- .../openrouter_integration_service.py | 78 +++++-------------- .../test_openrouter_integration_service.py | 56 +++++-------- 3 files changed, 45 insertions(+), 105 deletions(-) diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index d62b4a4a5..79cbe1e51 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -283,19 +283,15 @@ openrouter_integration: tpm: 1000000 # Rate limits for FREE OpenRouter models. Informational only: free OR - # models and openrouter/free are intentionally kept OUT of the LiteLLM - # Router pool, because OpenRouter enforces free-tier limits globally per - # account (~20 RPM + 50-1000 daily requests across every ":free" model - # combined) — per-deployment router accounting can't represent a shared - # bucket correctly. Free OR models stay fully available in the model - # selector and for user-facing Auto thread pinning. + # models are intentionally kept OUT of the LiteLLM Router pool, because + # OpenRouter enforces free-tier limits globally per account (~20 RPM + + # 50-1000 daily requests across every ":free" model combined) — + # per-deployment router accounting can't represent a shared bucket + # correctly. Free OR models stay fully available in the model selector + # and for user-facing Auto thread pinning. free_rpm: 20 free_tpm: 100000 - # Expose openrouter/free as a single virtual "Free (Auto-Select)" entry. - # Recommended: keep true. OpenRouter picks a capable free model per request. - free_router_enabled: true - litellm_params: max_tokens: 16384 system_instructions: "" diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 2d6a42337..06b7becdc 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -26,11 +26,6 @@ OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" # dynamic OpenRouter entries from hand-written YAML entries during refresh. _OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__" -# Fixed negative ID for the virtual ``openrouter/free`` auto-select entry. -# Chosen to sit far below any reasonable ``id_offset`` so it never collides -# with per-model stable IDs. -_FREE_ROUTER_ID = -9_999_999 - # Width of the hash space used by ``_stable_config_id``. 9_000_000 provides # enough headroom to avoid frequent collisions for OpenRouter's catalogue # (~300 models) while keeping IDs comfortably within Postgres INTEGER range. @@ -107,6 +102,11 @@ _EXCLUDED_MODEL_IDS: set[str] = { # Deep-research models reject standard params (temperature, etc.) "openai/o3-deep-research", "openai/o4-mini-deep-research", + # OpenRouter's own meta-router over free models. We already enumerate every + # concrete ``:free`` model into GLOBAL_LLM_CONFIGS and Auto-mode thread + # pinning handles churn via the repair path, so exposing an additional + # indirection layer would only duplicate the capability with an opaque slug. + "openrouter/free", } _EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",) @@ -160,43 +160,6 @@ async def _fetch_models_async() -> list[dict] | None: return None -def _build_free_router_config(settings: dict[str, Any]) -> dict[str, Any]: - """Build the virtual ``openrouter/free`` auto-select config entry. - - This exposes OpenRouter's Free Models Router as a single selectable - option. LiteLLM forwards ``openrouter/openrouter/free`` and OpenRouter - picks a capable free model per request (availability varies, account-wide - rate limit is ~20 req/min). - """ - return { - "id": _FREE_ROUTER_ID, - "name": "OpenRouter Free (Auto-Select)", - "description": ( - "OpenRouter picks a capable free model per request. " - "~20 req/min account-wide; availability varies." - ), - "provider": "OPENROUTER", - "model_name": "openrouter/free", - "api_key": settings.get("api_key", ""), - "api_base": "", - "billing_tier": "free", - "rpm": settings.get("free_rpm", 20), - "tpm": settings.get("free_tpm", 100_000), - "anonymous_enabled": settings.get("anonymous_enabled_free", False), - "seo_enabled": False, - "seo_slug": None, - "quota_reserve_tokens": settings.get("quota_reserve_tokens", 4000), - "litellm_params": dict(settings.get("litellm_params") or {}), - "system_instructions": settings.get("system_instructions", ""), - "use_default_system_instructions": settings.get( - "use_default_system_instructions", True - ), - "citations_enabled": settings.get("citations_enabled", True), - "router_pool_eligible": False, - _OPENROUTER_DYNAMIC_MARKER: True, - } - - def _generate_configs( raw_models: list[dict], settings: dict[str, Any], @@ -213,13 +176,18 @@ def _generate_configs( - Premium OR models join the LiteLLM router pool (``router_pool_eligible=True``) so sub-agent ``model="auto"`` flows benefit from load balancing and failover across the curated YAML configs and the OR premium passthrough. - - Free OR models and the virtual ``openrouter/free`` entry stay excluded - (``router_pool_eligible=False``). LiteLLM Router tracks rate limits per - deployment, but OpenRouter enforces a single global free-tier quota - (~20 RPM + 50-1000 daily requests account-wide across every ``:free`` - model), so rotating across many free deployments would only burn the - shared bucket faster. Free OR models remain fully available for user- - facing Auto-mode thread pinning via ``auto_model_pin_service``. + - Free OR models stay excluded (``router_pool_eligible=False``). LiteLLM + Router tracks rate limits per deployment, but OpenRouter enforces a + single global free-tier quota (~20 RPM + 50-1000 daily requests + account-wide across every ``:free`` model), so rotating across many + free deployments would only burn the shared bucket faster. Free OR + models remain fully available for user-facing Auto-mode thread pinning + via ``auto_model_pin_service``. + + OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream + via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer + because our own Auto (Fastest) pin + 24 h refresh + repair logic already + cover the catalogue-churn case. """ id_offset: int = settings.get("id_offset", -10000) api_key: str = settings.get("api_key", "") @@ -248,13 +216,7 @@ def _generate_configs( ] configs: list[dict] = [] - - if settings.get("free_router_enabled", True) and api_key: - configs.append(_build_free_router_config(settings)) - taken: set[int] = set() - if configs: - taken.add(_FREE_ROUTER_ID) for model in text_models: model_id: str = model["id"] @@ -382,9 +344,9 @@ class OpenRouterIntegrationService: ) # Rebuild the LiteLLM router so freshly fetched configs flow through - # (the router filters dynamic OR entries out of its pool, but a - # refresh still needs to pick up any static-config edits and reset - # cached context-window profiles). + # (dynamic OR premium entries now opt into the pool, free ones stay + # out; a refresh also needs to pick up any static-config edits and + # reset cached context-window profiles). try: from app.config import config as _app_config from app.services.llm_router_service import LLMRouterService diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index 618edc23c..d3921729d 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -5,9 +5,7 @@ from __future__ import annotations import pytest from app.services.openrouter_integration_service import ( - _FREE_ROUTER_ID, _OPENROUTER_DYNAMIC_MARKER, - _build_free_router_config, _generate_configs, _openrouter_tier, _stable_config_id, @@ -135,7 +133,6 @@ _SETTINGS_BASE: dict = { "anonymous_enabled_paid": False, "anonymous_enabled_free": True, "quota_reserve_tokens": 4000, - "free_router_enabled": False, } @@ -172,33 +169,26 @@ def test_generate_configs_respects_tier(): assert free["router_pool_eligible"] is False -def test_generate_configs_includes_free_router_when_enabled(): - raw = [_minimal_openrouter_model(model_id="openai/gpt-4o")] - settings = {**_SETTINGS_BASE, "free_router_enabled": True} - cfgs = _generate_configs(raw, settings) - free_router = next( - (c for c in cfgs if c["model_name"] == "openrouter/free"), None - ) - assert free_router is not None - assert free_router["id"] == _FREE_ROUTER_ID - assert free_router["billing_tier"] == "free" - assert free_router["router_pool_eligible"] is False - assert free_router["anonymous_enabled"] is True +def test_generate_configs_excludes_upstream_openrouter_free_router(): + """OpenRouter's own ``openrouter/free`` meta-router must never become a card. - -def test_generate_configs_excludes_free_router_when_disabled(): - raw = [_minimal_openrouter_model(model_id="openai/gpt-4o")] - settings = {**_SETTINGS_BASE, "free_router_enabled": False} - cfgs = _generate_configs(raw, settings) - assert not any(c["model_name"] == "openrouter/free" for c in cfgs) - - -def test_generate_configs_excludes_free_router_without_api_key(): - """Without an API key the free-router entry is useless; skip it.""" - raw = [_minimal_openrouter_model(model_id="openai/gpt-4o")] - settings = {**_SETTINGS_BASE, "free_router_enabled": True, "api_key": ""} - cfgs = _generate_configs(raw, settings) - assert not any(c["model_name"] == "openrouter/free" for c in cfgs) + The upstream API returns this as a first-class zero-priced model, so + without an explicit blocklist entry it would slip through every other + filter (text output, tool calling, 200k context, non-Amazon) and land + in the selector as a duplicate of the concrete ``:free`` cards. The + exclusion in ``_EXCLUDED_MODEL_IDS`` prevents that. + """ + raw = [ + _minimal_openrouter_model(model_id="openai/gpt-4o"), + _minimal_openrouter_model( + model_id="openrouter/free", + pricing={"prompt": "0", "completion": "0"}, + ), + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + model_names = {c["model_name"] for c in cfgs} + assert "openrouter/free" not in model_names + assert "openai/gpt-4o" in model_names def test_generate_configs_drops_non_text_and_non_tool_models(): @@ -226,11 +216,3 @@ def test_generate_configs_drops_non_text_and_non_tool_models(): assert "openai/completion-only" not in model_names -def test_build_free_router_config_shape(): - cfg = _build_free_router_config(dict(_SETTINGS_BASE)) - assert cfg["provider"] == "OPENROUTER" - assert cfg["model_name"] == "openrouter/free" - assert cfg["id"] == _FREE_ROUTER_ID - assert cfg["billing_tier"] == "free" - assert cfg["router_pool_eligible"] is False - assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True From 1863f2832b203d101a159653ff2198d59b93ddfc Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 18:43:45 +0530 Subject: [PATCH 273/299] fix(LayoutShell): add 'isolate' class to main content panel --- surfsense_web/components/layout/ui/shell/LayoutShell.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/surfsense_web/components/layout/ui/shell/LayoutShell.tsx b/surfsense_web/components/layout/ui/shell/LayoutShell.tsx index d41dd9e6d..207d27f7b 100644 --- a/surfsense_web/components/layout/ui/shell/LayoutShell.tsx +++ b/surfsense_web/components/layout/ui/shell/LayoutShell.tsx @@ -132,7 +132,7 @@ function MainContentPanel({ const isDocumentTab = activeTab?.type === "document"; return ( - <div className="relative flex flex-1 flex-col min-w-0"> + <div className="relative isolate flex flex-1 flex-col min-w-0"> <TabBar onTabSwitch={onTabSwitch} onNewChat={onNewChat} From 421a4d7d0807f17da7c29290e96e96ff73cc5e72 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 19:32:42 +0530 Subject: [PATCH 274/299] refactor(auto_model_pin): simplify thread-level pinning by removing unused fields and indexes --- ...38_add_thread_auto_model_pinning_fields.py | 31 +++++----------- surfsense_backend/app/db.py | 13 +++---- .../app/routes/search_spaces_routes.py | 6 +-- .../app/services/auto_model_pin_service.py | 37 ++++++++----------- .../services/test_auto_model_pin_service.py | 28 +++----------- 5 files changed, 37 insertions(+), 78 deletions(-) diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py index 3972b84b9..fba621a0c 100644 --- a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -4,10 +4,12 @@ Revision ID: 138 Revises: 137 Create Date: 2026-04-30 -Add thread-level fields to persist Auto (Fastest) model pinning metadata: -- pinned_llm_config_id: concrete resolved config id used for this thread -- pinned_auto_mode: auto policy identifier (currently "auto_fastest") -- pinned_at: timestamp when the pin was created/refreshed +Add a single thread-level column to persist the Auto (Fastest) model pin: +- pinned_llm_config_id: concrete resolved global LLM config id used for this + thread. NULL means "no pin; Auto will resolve on next turn". + +The column is unindexed: all reads are by new_chat_threads.id (primary key), +so a secondary index would be dead write amplification. """ from __future__ import annotations @@ -27,29 +29,14 @@ def upgrade() -> None: "ALTER TABLE new_chat_threads " "ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER" ) - op.execute( - "ALTER TABLE new_chat_threads " - "ADD COLUMN IF NOT EXISTS pinned_auto_mode VARCHAR(32)" - ) - op.execute( - "ALTER TABLE new_chat_threads " - "ADD COLUMN IF NOT EXISTS pinned_at TIMESTAMP WITH TIME ZONE" - ) - - op.execute( - "CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_llm_config_id " - "ON new_chat_threads (pinned_llm_config_id)" - ) - op.execute( - "CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_auto_mode " - "ON new_chat_threads (pinned_auto_mode)" - ) def downgrade() -> None: + # Drop any shape the thread row may be carrying. The extra columns and + # indexes only exist on dev DBs that ran an earlier draft of 138; IF EXISTS + # makes each statement a safe no-op on the lean shape. op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode") op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id") - op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at") op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode") op.execute( diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index ca3334f8b..2fe478d9b 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -638,13 +638,12 @@ class NewChatThread(BaseModel, TimestampMixin): default=False, server_default="false", ) - # Auto model pinning metadata: - # - pinned_llm_config_id stores the concrete resolved model config id. - # - pinned_auto_mode indicates which auto policy produced the pin. - # This allows Auto (Fastest) to resolve once per thread and stay stable. - pinned_llm_config_id = Column(Integer, nullable=True, index=True) - pinned_auto_mode = Column(String(32), nullable=True, index=True) - pinned_at = Column(TIMESTAMP(timezone=True), nullable=True) + # Auto (Fastest) model pin for this thread: concrete resolved global LLM + # config id. NULL means no pin; Auto will resolve on the next turn. + # Single-writer invariant: only app.services.auto_model_pin_service sets + # or clears this column (plus bulk clears when a search space's + # agent_llm_id changes). Unindexed: all reads are by primary key. + pinned_llm_config_id = Column(Integer, nullable=True) # Relationships search_space = relationship("SearchSpace", back_populates="new_chat_threads") diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 7944e7d66..72715ea5b 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -803,11 +803,7 @@ async def update_llm_preferences( await session.execute( update(NewChatThread) .where(NewChatThread.search_space_id == search_space_id) - .values( - pinned_llm_config_id=None, - pinned_auto_mode=None, - pinned_at=None, - ) + .values(pinned_llm_config_id=None) ) logger.info( "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)", diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 6b69c91ea..1a2061492 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -2,8 +2,14 @@ Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we resolve that virtual mode to one concrete global LLM config exactly once and -persist the chosen config id on ``new_chat_threads`` so subsequent turns are -stable. +persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so +subsequent turns are stable. + +Single-writer invariant: this module is the only writer of +``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in +``search_spaces_routes`` when a search space's ``agent_llm_id`` changes). +Therefore a non-NULL value unambiguously means "this thread has an +Auto-resolved pin"; no separate source/policy column is needed. """ from __future__ import annotations @@ -11,7 +17,6 @@ from __future__ import annotations import hashlib import logging from dataclasses import dataclass -from datetime import UTC, datetime from uuid import UUID from sqlalchemy import select @@ -90,10 +95,10 @@ async def resolve_or_get_pinned_llm_config_id( selected_llm_config_id: int, force_repin_free: bool = False, ) -> AutoPinResolution: - """Resolve Auto (Fastest) to one concrete config id and persist pin metadata. + """Resolve Auto (Fastest) to one concrete config id and persist the pin. - For non-auto selections, this function clears existing auto pin metadata and - returns the selected id as-is. + For non-auto selections, this function clears any existing pin and returns + the selected id as-is. """ thread = ( ( @@ -113,16 +118,10 @@ async def resolve_or_get_pinned_llm_config_id( f"Thread {thread_id} does not belong to search space {search_space_id}" ) - # Explicit model selected: clear stale auto pin metadata. + # Explicit model selected: clear any stale pin. if selected_llm_config_id != AUTO_FASTEST_ID: - if ( - thread.pinned_llm_config_id is not None - or thread.pinned_auto_mode is not None - or thread.pinned_at is not None - ): + if thread.pinned_llm_config_id is not None: thread.pinned_llm_config_id = None - thread.pinned_auto_mode = None - thread.pinned_at = None await session.commit() return AutoPinResolution( resolved_llm_config_id=selected_llm_config_id, @@ -135,12 +134,11 @@ async def resolve_or_get_pinned_llm_config_id( raise ValueError("No usable global LLM configs are available for Auto mode") candidate_by_id = {int(c["id"]): c for c in candidates} - # Reuse existing valid pin without re-checking current quota (no silent tier switch), - # unless the caller explicitly requests a forced repin to free. + # Reuse an existing valid pin without re-checking current quota (no silent + # tier switch), unless the caller explicitly requests a forced repin to free. pinned_id = thread.pinned_llm_config_id if ( not force_repin_free - and thread.pinned_auto_mode == AUTO_FASTEST_MODE and pinned_id is not None and int(pinned_id) in candidate_by_id ): @@ -159,11 +157,10 @@ async def resolve_or_get_pinned_llm_config_id( ) if pinned_id is not None: logger.info( - "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s pinned_auto_mode=%s", + "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s", thread_id, search_space_id, pinned_id, - thread.pinned_auto_mode, ) premium_eligible = ( @@ -184,8 +181,6 @@ async def resolve_or_get_pinned_llm_config_id( selected_tier = _tier_of(selected_cfg) thread.pinned_llm_config_id = selected_id - thread.pinned_auto_mode = AUTO_FASTEST_MODE - thread.pinned_at = datetime.now(UTC) await session.commit() if force_repin_free: diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index 0a2342e05..2094ea6dd 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -6,7 +6,6 @@ from types import SimpleNamespace import pytest from app.services.auto_model_pin_service import ( - AUTO_FASTEST_MODE, resolve_or_get_pinned_llm_config_id, ) @@ -45,14 +44,11 @@ def _thread( *, search_space_id: int = 10, pinned_llm_config_id: int | None = None, - pinned_auto_mode: str | None = None, ): return SimpleNamespace( id=1, search_space_id=search_space_id, pinned_llm_config_id=pinned_llm_config_id, - pinned_auto_mode=pinned_auto_mode, - pinned_at=None, ) @@ -93,8 +89,6 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): ) assert result.resolved_llm_config_id in {-1, -2} assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id - assert session.thread.pinned_auto_mode == AUTO_FASTEST_MODE - assert session.thread.pinned_at is not None assert session.commit_count == 1 @@ -102,9 +96,7 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): async def test_next_turn_reuses_existing_pin(monkeypatch): from app.config import config - session = _FakeSession( - _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-1)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", @@ -228,9 +220,7 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): from app.config import config - session = _FakeSession( - _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-1)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", @@ -275,9 +265,7 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): from app.config import config - session = _FakeSession( - _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-1)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", @@ -325,9 +313,7 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): async def test_explicit_user_model_change_clears_pin(monkeypatch): from app.config import config - session = _FakeSession( - _thread(pinned_llm_config_id=-2, pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-2)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", @@ -345,8 +331,6 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch): ) assert result.resolved_llm_config_id == 7 assert session.thread.pinned_llm_config_id is None - assert session.thread.pinned_auto_mode is None - assert session.thread.pinned_at is None assert session.commit_count == 1 @@ -354,9 +338,7 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch): async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): from app.config import config - session = _FakeSession( - _thread(pinned_llm_config_id=-999, pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-999)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", From d9058b73f5306f6dc40ba553cec92cf659246d1a Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 23:37:49 +0530 Subject: [PATCH 275/299] feat(auto_pin): add pure-function quality scoring module --- .../app/services/quality_score.py | 382 ++++++++++++++++++ .../tests/unit/services/test_quality_score.py | 342 ++++++++++++++++ 2 files changed, 724 insertions(+) create mode 100644 surfsense_backend/app/services/quality_score.py create mode 100644 surfsense_backend/tests/unit/services/test_quality_score.py diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py new file mode 100644 index 000000000..8f6c75d56 --- /dev/null +++ b/surfsense_backend/app/services/quality_score.py @@ -0,0 +1,382 @@ +"""Pure-function quality scoring for Auto (Fastest) model selection. + +This module is import-free of any service / request-path dependencies. All +numbers are computed once during the OpenRouter refresh tick (or YAML load) +and cached on the cfg dict, so the chat hot path only does a precomputed +sort and a SHA256 pick. + +Score components (0-100 scale, higher is better): + +* ``static_score_or`` – derived from the bulk ``/api/v1/models`` payload + (provider prestige + ``created`` recency + pricing band + context window + + capabilities + narrow tiny/legacy slug penalty). +* ``static_score_yaml`` – same shape for hand-curated YAML configs, plus + an operator-trust bonus (the operator deliberately picked this model). +* ``aggregate_health`` – run on per-model ``/api/v1/models/{id}/endpoints`` + responses; returns ``(gated, score_or_none)``. + +The blended ``quality_score`` (0.5 * static + 0.5 * health) is computed in +:mod:`app.services.openrouter_integration_service` because that's the only +caller that sees both halves. +""" + +from __future__ import annotations + +# --------------------------------------------------------------------------- +# Tunables (constants, not flags) +# --------------------------------------------------------------------------- + +# Top-K size for deterministic spread inside the locked tier. +_QUALITY_TOP_K: int = 5 + +# Hard health gate: any cfg whose best non-null uptime is below this % +# is excluded from Auto-mode selection entirely. +_HEALTH_GATE_UPTIME_PCT: float = 90.0 + +# Health/static blend weight when a cfg has fresh /endpoints data. +_HEALTH_BLEND_WEIGHT: float = 0.5 + +# Static bonus applied to YAML cfgs because the operator hand-picked them. +_OPERATOR_TRUST_BONUS: int = 20 + +# /endpoints fan-out is bounded per refresh tick. +_HEALTH_ENRICH_TOP_N_PREMIUM: int = 50 +_HEALTH_ENRICH_TOP_N_FREE: int = 30 +_HEALTH_ENRICH_CONCURRENCY: int = 15 +_HEALTH_FETCH_TIMEOUT_SEC: float = 5.0 + +# If at least this fraction of /endpoints fetches fail in a refresh cycle, +# fall back to the previous cycle's last-good cache instead of writing +# partial / stale health values. +_HEALTH_FAIL_RATIO_FALLBACK: float = 0.25 + +# Narrow tiny/legacy slug penalties only. We deliberately do NOT penalise +# ``-nano`` / ``-mini`` / ``-lite`` because modern frontier models ship with +# those naming patterns (``gpt-5-mini``, ``gemini-2.5-flash-lite`` etc.) and +# blanket-penalising them suppresses high-quality picks. +_TINY_LEGACY_PENALTY_PATTERNS: tuple[str, ...] = ( + "-1b-", + "-1.2b-", + "-1.5b-", + "-2b-", + "-3b-", + "gemma-3n", + "lfm-", + "-base", + "-distill", + ":nitro", + "-preview", +) + + +# --------------------------------------------------------------------------- +# Provider prestige tables +# --------------------------------------------------------------------------- + +# OpenRouter-side provider slug (the prefix before ``/`` in the model id). +# Tiers are coarse: frontier labs > strong open / fast-moving labs > +# specialist labs > everything else. +PROVIDER_PRESTIGE_OR: dict[str, int] = { + # Frontier labs + "openai": 50, + "anthropic": 50, + "google": 50, + "x-ai": 50, + # Strong open / fast-moving labs + "deepseek": 38, + "qwen": 38, + "meta-llama": 38, + "mistralai": 38, + "cohere": 38, + "nvidia": 38, + "alibaba": 38, + # Specialist / regional / strong second-tier + "microsoft": 28, + "01-ai": 28, + "minimax": 28, + "moonshot": 28, + "z-ai": 28, + "nousresearch": 28, + "ai21": 28, + "perplexity": 28, + # Smaller / niche providers + "liquid": 18, + "cognitivecomputations": 18, + "venice": 18, + "inflection": 18, +} + +# YAML provider field (the upstream API shape the operator selected). +PROVIDER_PRESTIGE_YAML: dict[str, int] = { + "AZURE_OPENAI": 50, + "OPENAI": 50, + "ANTHROPIC": 50, + "GOOGLE": 50, + "VERTEX_AI": 50, + "GEMINI": 50, + "XAI": 50, + "MISTRAL": 38, + "DEEPSEEK": 38, + "COHERE": 38, + "GROQ": 30, + "TOGETHER_AI": 28, + "FIREWORKS_AI": 28, + "PERPLEXITY": 28, + "MINIMAX": 28, + "BEDROCK": 28, + "OPENROUTER": 25, + "OLLAMA": 12, + "CUSTOM": 12, +} + + +# --------------------------------------------------------------------------- +# Pure scoring helpers +# --------------------------------------------------------------------------- + +# Calibrated against the live /api/v1/models bulk dump. Frontier models +# released in the last ~6 months (GPT-5 family, Claude 4.x, Gemini 2.5, +# Grok 4) score in the 18-20 band; mid-2024 models in the 8-12 band; +# anything older trails off. +_RECENCY_BANDS_DAYS: tuple[tuple[int, int], ...] = ( + (60, 20), + (180, 16), + (365, 12), + (540, 9), + (730, 6), + (1095, 3), +) + + +def created_recency_signal(created_ts: int | None, now_ts: int) -> int: + """Return 0-20 based on how recently the model was published. + + Uses the OpenRouter ``created`` Unix timestamp (or any equivalent for + YAML cfgs). Models without a usable timestamp get 0 (we don't penalise, + we just don't reward). + """ + if created_ts is None or created_ts <= 0 or now_ts <= 0: + return 0 + age_days = max(0, (now_ts - int(created_ts)) // 86_400) + for cutoff, score in _RECENCY_BANDS_DAYS: + if age_days <= cutoff: + return score + return 0 + + +def pricing_band( + prompt: str | float | int | None, + completion: str | float | int | None, +) -> int: + """Return 0-15 based on combined prompt+completion cost per 1M tokens. + + Higher-priced models tend to be the larger / more capable ones. A free + model returns 0 (we use other signals to rank free-vs-free instead). + Uncoercible inputs are treated as 0 rather than raising. + """ + + def _to_float(value) -> float: + if value is None: + return 0.0 + try: + return float(value) + except (TypeError, ValueError): + return 0.0 + + p = _to_float(prompt) + c = _to_float(completion) + total_per_million = (p + c) * 1_000_000 + + if total_per_million >= 20.0: + return 15 + if total_per_million >= 5.0: + return 12 + if total_per_million >= 1.0: + return 9 + if total_per_million >= 0.3: + return 6 + if total_per_million >= 0.05: + return 4 + if total_per_million > 0.0: + return 2 + return 0 + + +def context_signal(ctx: int | None) -> int: + """Return 0-10 based on the model's context window.""" + if not ctx or ctx <= 0: + return 0 + if ctx >= 1_000_000: + return 10 + if ctx >= 400_000: + return 8 + if ctx >= 200_000: + return 6 + if ctx >= 128_000: + return 4 + if ctx >= 100_000: + return 2 + return 0 + + +def capabilities_signal(supported_parameters: list[str] | None) -> int: + """Return 0-5 for capabilities that matter for our agent flows.""" + if not supported_parameters: + return 0 + params = set(supported_parameters) + score = 0 + if "tools" in params: + score += 2 + if "structured_outputs" in params or "response_format" in params: + score += 2 + if "reasoning" in params or "include_reasoning" in params: + score += 1 + return min(score, 5) + + +def slug_penalty(model_id: str) -> int: + """Return a non-positive number; matches the narrow tiny/legacy patterns.""" + if not model_id: + return 0 + needle = model_id.lower() + for pattern in _TINY_LEGACY_PENALTY_PATTERNS: + if pattern in needle: + return -10 + return 0 + + +def _provider_prestige_or(model_id: str) -> int: + if "/" not in model_id: + return 0 + slug = model_id.split("/", 1)[0].lower() + return PROVIDER_PRESTIGE_OR.get(slug, 15) + + +def static_score_or(or_model: dict, *, now_ts: int) -> int: + """Score a raw OpenRouter ``/api/v1/models`` entry on a 0-100 scale.""" + model_id = str(or_model.get("id", "")) + pricing = or_model.get("pricing") or {} + + score = ( + _provider_prestige_or(model_id) + + created_recency_signal(or_model.get("created"), now_ts) + + pricing_band(pricing.get("prompt"), pricing.get("completion")) + + context_signal(or_model.get("context_length")) + + capabilities_signal(or_model.get("supported_parameters")) + + slug_penalty(model_id) + ) + return max(0, min(100, int(score))) + + +def static_score_yaml(cfg: dict) -> int: + """Score a YAML-curated cfg on a 0-100 scale. + + Includes ``_OPERATOR_TRUST_BONUS`` because the operator deliberately + listed this model. Pricing / context fall through to lazy ``litellm`` + lookups; failures are silent (we just lose those sub-points). + """ + provider = str(cfg.get("provider", "")).upper() + base = PROVIDER_PRESTIGE_YAML.get(provider, 15) + + model_name = cfg.get("model_name") or "" + litellm_params = cfg.get("litellm_params") or {} + lookup_name = ( + litellm_params.get("base_model") + or litellm_params.get("model") + or model_name + ) + + ctx = 0 + p_cost: float = 0.0 + c_cost: float = 0.0 + try: + from litellm import get_model_info # lazy: avoid cold-import cost + + info = get_model_info(lookup_name) or {} + ctx = int(info.get("max_input_tokens") or info.get("max_tokens") or 0) + p_cost = float(info.get("input_cost_per_token") or 0.0) + c_cost = float(info.get("output_cost_per_token") or 0.0) + except Exception: + # Unknown to litellm — that's fine for prestige+operator-bonus weighting. + pass + + score = ( + base + + _OPERATOR_TRUST_BONUS + + pricing_band(p_cost, c_cost) + + context_signal(ctx) + + slug_penalty(str(model_name)) + ) + return max(0, min(100, int(score))) + + +# --------------------------------------------------------------------------- +# Health aggregation +# --------------------------------------------------------------------------- + + +def _coerce_pct(value) -> float | None: + try: + if value is None: + return None + f = float(value) + except (TypeError, ValueError): + return None + if f < 0: + return None + # OpenRouter reports uptime as a 0-1 fraction; some endpoints surface it + # as a 0-100 percentage. Normalise. + return f * 100.0 if f <= 1.0 else f + + +def _best_uptime(endpoints: list[dict]) -> tuple[float | None, str | None]: + """Pick the best (highest) non-null uptime across all endpoints. + + Window preference: ``uptime_last_30m`` > ``uptime_last_1d`` > + ``uptime_last_5m``. Returns ``(uptime_pct, window_used)``. + """ + for window in ("uptime_last_30m", "uptime_last_1d", "uptime_last_5m"): + values = [_coerce_pct(ep.get(window)) for ep in endpoints] + values = [v for v in values if v is not None] + if values: + return max(values), window + return None, None + + +def aggregate_health(endpoints: list[dict]) -> tuple[bool, float | None]: + """Aggregate a model's per-endpoint health into ``(gated, score_or_none)``. + + Hard gate (returns ``(True, None)``): + * ``endpoints`` empty, + * no endpoint reports ``status == 0`` (OK), or + * best non-null uptime below ``_HEALTH_GATE_UPTIME_PCT``. + + On a pass, returns a 0-100 health score blending uptime, status, and a + freshness-weighted recent uptime sample. + """ + if not endpoints: + return True, None + + any_ok = any(int(ep.get("status", 1)) == 0 for ep in endpoints) + if not any_ok: + return True, None + + best_uptime, _ = _best_uptime(endpoints) + if best_uptime is None or best_uptime < _HEALTH_GATE_UPTIME_PCT: + return True, None + + # Freshness term: prefer 5m, fall through to 30m / 1d if 5m is missing. + freshness = None + for window in ("uptime_last_5m", "uptime_last_30m", "uptime_last_1d"): + values = [_coerce_pct(ep.get(window)) for ep in endpoints] + values = [v for v in values if v is not None] + if values: + freshness = max(values) + break + + uptime_term = best_uptime + status_term = 100.0 if any_ok else 0.0 + freshness_term = freshness if freshness is not None else best_uptime + + score = 0.50 * uptime_term + 0.30 * status_term + 0.20 * freshness_term + return False, max(0.0, min(100.0, score)) diff --git a/surfsense_backend/tests/unit/services/test_quality_score.py b/surfsense_backend/tests/unit/services/test_quality_score.py new file mode 100644 index 000000000..fbc91521d --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_quality_score.py @@ -0,0 +1,342 @@ +"""Unit tests for the Auto (Fastest) quality scoring module.""" + +from __future__ import annotations + +import time + +import pytest + +from app.services.quality_score import ( + _HEALTH_GATE_UPTIME_PCT, + _OPERATOR_TRUST_BONUS, + aggregate_health, + capabilities_signal, + context_signal, + created_recency_signal, + pricing_band, + slug_penalty, + static_score_or, + static_score_yaml, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# created_recency_signal +# --------------------------------------------------------------------------- + + +def test_created_recency_signal_recent_model_scores_high(): + now = 1_750_000_000 # ~mid-2025 + one_month_ago = now - (30 * 86_400) + assert created_recency_signal(one_month_ago, now) == 20 + + +def test_created_recency_signal_old_model_scores_zero(): + now = 1_750_000_000 + five_years_ago = now - (5 * 365 * 86_400) + assert created_recency_signal(five_years_ago, now) == 0 + + +def test_created_recency_signal_missing_timestamp_is_neutral(): + now = 1_750_000_000 + assert created_recency_signal(None, now) == 0 + assert created_recency_signal(0, now) == 0 + + +def test_created_recency_signal_monotonic_decay(): + now = 1_750_000_000 + scores = [ + created_recency_signal(now - days * 86_400, now) + for days in (30, 120, 300, 500, 700, 1000, 1500) + ] + assert scores == sorted(scores, reverse=True) + + +# --------------------------------------------------------------------------- +# pricing_band +# --------------------------------------------------------------------------- + + +def test_pricing_band_free_returns_zero(): + assert pricing_band("0", "0") == 0 + assert pricing_band(0.0, 0.0) == 0 + assert pricing_band(None, None) == 0 + + +def test_pricing_band_handles_unparseable(): + assert pricing_band("not-a-number", "0") == 0 + assert pricing_band({}, []) == 0 # type: ignore[arg-type] + + +def test_pricing_band_premium_tiers_increase_with_price(): + cheap = pricing_band("0.0000003", "0.0000005") + mid = pricing_band("0.000003", "0.000015") + flagship = pricing_band("0.00001", "0.00005") + assert 0 < cheap < mid < flagship + + +# --------------------------------------------------------------------------- +# context_signal +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "ctx,expected", + [ + (1_500_000, 10), + (1_000_000, 10), + (500_000, 8), + (200_000, 6), + (128_000, 4), + (100_000, 2), + (50_000, 0), + (0, 0), + (None, 0), + ], +) +def test_context_signal_bands(ctx, expected): + assert context_signal(ctx) == expected + + +# --------------------------------------------------------------------------- +# capabilities_signal +# --------------------------------------------------------------------------- + + +def test_capabilities_signal_caps_at_five(): + assert capabilities_signal( + ["tools", "structured_outputs", "reasoning", "include_reasoning"] + ) <= 5 + + +def test_capabilities_signal_tools_only(): + assert capabilities_signal(["tools"]) == 2 + + +def test_capabilities_signal_empty(): + assert capabilities_signal(None) == 0 + assert capabilities_signal([]) == 0 + + +# --------------------------------------------------------------------------- +# slug_penalty +# --------------------------------------------------------------------------- + + +def test_slug_penalty_demotes_tiny_models(): + assert slug_penalty("meta-llama/llama-3.2-1b-instruct") < 0 + assert slug_penalty("liquid/lfm-7b") < 0 + assert slug_penalty("google/gemma-3n-e4b-it") < 0 + + +def test_slug_penalty_skips_capable_mini_nano_lite_models(): + """Critical Option C+ regression: don't penalise modern frontier + models named ``-nano`` / ``-mini`` / ``-lite`` (gpt-5-mini, etc.).""" + assert slug_penalty("openai/gpt-5-mini") == 0 + assert slug_penalty("openai/gpt-5-nano") == 0 + assert slug_penalty("google/gemini-2.5-flash-lite") == 0 + assert slug_penalty("anthropic/claude-haiku-4.5") == 0 + + +def test_slug_penalty_demotes_legacy_variants(): + assert slug_penalty("openai/o1-preview") < 0 + assert slug_penalty("foo/bar-base") < 0 + assert slug_penalty("foo/bar-distill") < 0 + + +def test_slug_penalty_empty_input(): + assert slug_penalty("") == 0 + + +# --------------------------------------------------------------------------- +# static_score_or +# --------------------------------------------------------------------------- + + +def _or_model( + *, + model_id: str, + created: int | None = None, + prompt: str = "0.000003", + completion: str = "0.000015", + context: int = 200_000, + params: list[str] | None = None, +) -> dict: + return { + "id": model_id, + "created": created, + "pricing": {"prompt": prompt, "completion": completion}, + "context_length": context, + "supported_parameters": params if params is not None else ["tools"], + } + + +def test_static_score_or_frontier_premium_beats_free_tiny(): + now = 1_750_000_000 + frontier = _or_model( + model_id="openai/gpt-5", + created=now - (60 * 86_400), + prompt="0.000005", + completion="0.000020", + context=400_000, + params=["tools", "structured_outputs", "reasoning"], + ) + tiny_free = _or_model( + model_id="meta-llama/llama-3.2-1b-instruct:free", + created=now - (5 * 365 * 86_400), + prompt="0", + completion="0", + context=128_000, + params=["tools"], + ) + assert static_score_or(frontier, now_ts=now) > static_score_or( + tiny_free, now_ts=now + ) + + +def test_static_score_or_score_is_clamped_0_to_100(): + now = int(time.time()) + score = static_score_or(_or_model(model_id="openai/gpt-4o"), now_ts=now) + assert 0 <= score <= 100 + + +def test_static_score_or_unknown_provider_is_neutral_not_zero(): + now = int(time.time()) + score = static_score_or( + _or_model(model_id="some-new-lab/some-model"), + now_ts=now, + ) + assert score > 0 + + +def test_static_score_or_recent_release_beats_year_old_same_provider(): + now = 1_750_000_000 + fresh = _or_model(model_id="openai/gpt-5", created=now - (60 * 86_400)) + old = _or_model(model_id="openai/gpt-4-turbo", created=now - (700 * 86_400)) + assert static_score_or(fresh, now_ts=now) > static_score_or(old, now_ts=now) + + +# --------------------------------------------------------------------------- +# static_score_yaml +# --------------------------------------------------------------------------- + + +def test_static_score_yaml_includes_operator_bonus(): + cfg = { + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "litellm_params": {"base_model": "azure/gpt-5"}, + } + score = static_score_yaml(cfg) + assert score >= _OPERATOR_TRUST_BONUS + + +def test_static_score_yaml_unknown_provider_still_carries_bonus(): + cfg = { + "provider": "SOME_NEW_PROVIDER", + "model_name": "weird-model", + } + score = static_score_yaml(cfg) + assert score >= _OPERATOR_TRUST_BONUS + + +def test_static_score_yaml_clamped_0_to_100(): + cfg = { + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "litellm_params": {"base_model": "azure/gpt-5"}, + } + assert 0 <= static_score_yaml(cfg) <= 100 + + +# --------------------------------------------------------------------------- +# aggregate_health +# --------------------------------------------------------------------------- + + +def test_aggregate_health_gates_when_uptime_below_threshold(): + """Live data showed Venice-routed cfgs at 53-68%; this guards that the + 90% gate excludes them.""" + venice_endpoints = [ + { + "status": 0, + "uptime_last_30m": 0.55, + "uptime_last_1d": 0.60, + "uptime_last_5m": 0.50, + }, + { + "status": 0, + "uptime_last_30m": 0.65, + "uptime_last_1d": 0.68, + "uptime_last_5m": 0.62, + }, + ] + gated, score = aggregate_health(venice_endpoints) + assert gated is True + assert score is None + + +def test_aggregate_health_passes_for_healthy_provider(): + healthy = [ + { + "status": 0, + "uptime_last_30m": 0.99, + "uptime_last_1d": 0.995, + "uptime_last_5m": 0.99, + }, + ] + gated, score = aggregate_health(healthy) + assert gated is False + assert score is not None + assert score >= _HEALTH_GATE_UPTIME_PCT + + +def test_aggregate_health_picks_best_endpoint_across_multiple(): + """Multi-endpoint aggregation should reward the best non-null uptime.""" + mixed = [ + {"status": 0, "uptime_last_30m": 0.55}, + {"status": 0, "uptime_last_30m": 0.97}, # this one passes the gate + ] + gated, score = aggregate_health(mixed) + assert gated is False + assert score is not None + + +def test_aggregate_health_empty_endpoints_gated(): + gated, score = aggregate_health([]) + assert gated is True + assert score is None + + +def test_aggregate_health_no_status_zero_gated(): + """Even with high uptime, no OK status means the cfg is broken upstream.""" + endpoints = [ + {"status": 1, "uptime_last_30m": 0.99}, + {"status": 2, "uptime_last_30m": 0.98}, + ] + gated, score = aggregate_health(endpoints) + assert gated is True + assert score is None + + +def test_aggregate_health_all_uptime_null_gated(): + endpoints = [ + {"status": 0, "uptime_last_30m": None, "uptime_last_1d": None}, + ] + gated, score = aggregate_health(endpoints) + assert gated is True + assert score is None + + +def test_aggregate_health_pct_normalisation(): + """OpenRouter returns 0-1 fractions; some endpoints surface 0-100% + percentages. Both should reach the same gate decision.""" + fraction_form = [{"status": 0, "uptime_last_30m": 0.95}] + pct_form = [{"status": 0, "uptime_last_30m": 95.0}] + g1, s1 = aggregate_health(fraction_form) + g2, s2 = aggregate_health(pct_form) + assert g1 == g2 == False # noqa: E712 + assert s1 is not None and s2 is not None + assert abs(s1 - s2) < 0.5 From c229b4356ac7112576e98397b5eb304b3ca8eefa Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 23:38:21 +0530 Subject: [PATCH 276/299] feat(config): stamp Auto (Fastest) ranking metadata on YAML configs --- surfsense_backend/app/config/__init__.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 11cbe24a7..b3eff571e 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -63,6 +63,27 @@ def load_global_llm_configs(): else: seen_slugs[slug] = cfg.get("id", 0) + # Stamp Auto (Fastest) ranking metadata. YAML configs are always + # Tier A — operator-curated, locked first when premium-eligible. + # The OpenRouter refresh tick later re-stamps health for any cfg + # whose provider == "OPENROUTER" via _enrich_health. + try: + from app.services.quality_score import static_score_yaml + + for cfg in configs: + cfg["auto_pin_tier"] = "A" + static_q = static_score_yaml(cfg) + cfg["quality_score_static"] = static_q + cfg["quality_score"] = static_q + cfg["quality_score_health"] = None + # YAML cfgs whose provider is OPENROUTER are also subject + # to health gating against their own /endpoints data — a + # hand-picked dead OR model is still dead. _enrich_health + # re-stamps health_gated for them on the next refresh tick. + cfg["health_gated"] = False + except Exception as e: + print(f"Warning: Failed to score global LLM configs: {e}") + return configs except Exception as e: print(f"Warning: Failed to load global LLM configs: {e}") From 1eedcaa55178134ce9c7f45c11707a7406bdb291 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 23:38:40 +0530 Subject: [PATCH 277/299] feat(openrouter): blend per-model /endpoints health into quality score --- .../openrouter_integration_service.py | 231 ++++++++++++ .../services/test_or_health_enrichment.py | 331 ++++++++++++++++++ 2 files changed, 562 insertions(+) create mode 100644 surfsense_backend/tests/unit/services/test_or_health_enrichment.py diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 06b7becdc..9c3eaa5ea 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -14,13 +14,28 @@ import asyncio import hashlib import logging import threading +import time from typing import Any import httpx +from app.services.quality_score import ( + _HEALTH_BLEND_WEIGHT, + _HEALTH_ENRICH_CONCURRENCY, + _HEALTH_ENRICH_TOP_N_FREE, + _HEALTH_ENRICH_TOP_N_PREMIUM, + _HEALTH_FAIL_RATIO_FALLBACK, + _HEALTH_FETCH_TIMEOUT_SEC, + aggregate_health, + static_score_or, +) + logger = logging.getLogger(__name__) OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" +OPENROUTER_ENDPOINTS_URL_TEMPLATE = ( + "https://openrouter.ai/api/v1/models/{model_id}/endpoints" +) # Sentinel value stored on each generated config so we can distinguish # dynamic OpenRouter entries from hand-written YAML entries during refresh. @@ -217,12 +232,15 @@ def _generate_configs( configs: list[dict] = [] taken: set[int] = set() + now_ts = int(time.time()) for model in text_models: model_id: str = model["id"] name: str = model.get("name", model_id) tier = _openrouter_tier(model) + static_q = static_score_or(model, now_ts=now_ts) + cfg: dict[str, Any] = { "id": _stable_config_id(model_id, id_offset, taken), "name": name, @@ -249,6 +267,15 @@ def _generate_configs( # there — it just drains the shared bucket faster. "router_pool_eligible": tier == "premium", _OPENROUTER_DYNAMIC_MARKER: True, + # Auto (Fastest) ranking metadata. ``quality_score`` is initialised + # to the static score and gets re-blended with health on the next + # ``_enrich_health`` pass (synchronous on refresh, deferred on cold + # start so startup latency is unchanged). + "auto_pin_tier": "B" if tier == "premium" else "C", + "quality_score_static": static_q, + "quality_score_health": None, + "quality_score": static_q, + "health_gated": False, } configs.append(cfg) @@ -267,6 +294,12 @@ class OpenRouterIntegrationService: self._configs_by_id: dict[int, dict] = {} self._initialized = False self._refresh_task: asyncio.Task | None = None + # Last-good per-model health snapshot. Survives across refresh + # cycles so a transient OpenRouter /endpoints outage doesn't drop + # every cfg back to static-only scoring. + # Shape: {model_name: {"gated": bool, "score": float | None}} + self._health_cache: dict[str, dict[str, Any]] = {} + self._enrich_task: asyncio.Task | None = None @classmethod def get_instance(cls) -> "OpenRouterIntegrationService": @@ -307,6 +340,20 @@ class OpenRouterIntegrationService: tier_counts["free"], tier_counts["premium"], ) + + # Schedule the first health-enrichment pass as a deferred task so + # cold-start latency is unchanged. Only valid when an event loop is + # already running (e.g. FastAPI lifespan); Celery worker init is + # fully sync so we silently skip — its first refresh tick (or the + # next refresh from the web process) will populate health data. + try: + loop = asyncio.get_running_loop() + self._enrich_task = loop.create_task( + self._enrich_health_safely(self._configs) + ) + except RuntimeError: + pass + return self._configs # ------------------------------------------------------------------ @@ -343,6 +390,13 @@ class OpenRouterIntegrationService: tier_counts["premium"], ) + # Re-blend health scores against the freshly fetched catalogue. Also + # re-stamps health for any YAML-curated cfg with provider==OPENROUTER + # so a hand-picked dead OR model is gated like a dynamic one. + await self._enrich_health_safely( + static_configs + new_configs, log_summary=True + ) + # Rebuild the LiteLLM router so freshly fetched configs flow through # (dynamic OR premium entries now opt into the pool, free ones stay # out; a refresh also needs to pick up any static-config edits and @@ -373,6 +427,183 @@ class OpenRouterIntegrationService: counts[tier] += 1 return counts + # ------------------------------------------------------------------ + # Auto (Fastest) health enrichment + # ------------------------------------------------------------------ + + async def _enrich_health_safely( + self, configs: list[dict], *, log_summary: bool = True + ) -> None: + """Wrapper around ``_enrich_health`` that swallows all errors. + + Health enrichment is best-effort: any failure must leave cfgs in + their static-only state and never break refresh / startup. + """ + try: + await self._enrich_health(configs, log_summary=log_summary) + except Exception: + logger.exception("OpenRouter health enrichment failed") + + async def _enrich_health( + self, configs: list[dict], *, log_summary: bool = True + ) -> None: + """Fetch per-model ``/endpoints`` data for the top OR cfgs and blend + the resulting health score into ``cfg["quality_score"]``. + + Bounded fan-out: top-N per tier by ``quality_score_static`` only, + with ``asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)`` guarding the + outbound HTTP. Misses fall back to a per-model last-good cache; if + the failure ratio crosses ``_HEALTH_FAIL_RATIO_FALLBACK`` we keep + the entire previous cycle's cache for this run. + """ + or_cfgs = [ + c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER" + ] + if not or_cfgs: + return + + premium_pool = sorted( + [ + c + for c in or_cfgs + if str(c.get("billing_tier", "")).lower() == "premium" + ], + key=lambda c: -int(c.get("quality_score_static") or 0), + )[:_HEALTH_ENRICH_TOP_N_PREMIUM] + free_pool = sorted( + [ + c + for c in or_cfgs + if str(c.get("billing_tier", "")).lower() == "free" + ], + key=lambda c: -int(c.get("quality_score_static") or 0), + )[:_HEALTH_ENRICH_TOP_N_FREE] + # De-duplicate while preserving order: a cfg shouldn't fall in both + # tiers, but defensive code is cheap here. + seen_ids: set[int] = set() + selected: list[dict] = [] + for cfg in premium_pool + free_pool: + cid = int(cfg.get("id", 0)) + if cid in seen_ids: + continue + seen_ids.add(cid) + selected.append(cfg) + + if not selected: + return + + api_key = str(self._settings.get("api_key") or "") + semaphore = asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY) + + async with httpx.AsyncClient( + timeout=_HEALTH_FETCH_TIMEOUT_SEC + ) as client: + results = await asyncio.gather( + *( + self._fetch_endpoints(client, semaphore, api_key, cfg) + for cfg in selected + ) + ) + + fail_count = sum(1 for _, _, err in results if err is not None) + fail_ratio = fail_count / len(results) if results else 0.0 + degraded = fail_ratio >= _HEALTH_FAIL_RATIO_FALLBACK + if degraded: + logger.warning( + "auto_pin_health_enrich_degraded fail_ratio=%.2f total=%d " + "using_last_good_cache=true", + fail_ratio, + len(results), + ) + + # Per-cfg health update. + for cfg, endpoints, err in results: + model_name = str(cfg.get("model_name", "")) + if not degraded and err is None and endpoints is not None: + gated, h_score = aggregate_health(endpoints) + cfg["health_gated"] = bool(gated) + cfg["quality_score_health"] = h_score + self._health_cache[model_name] = { + "gated": bool(gated), + "score": h_score, + } + else: + cached = self._health_cache.get(model_name) + if cached is not None: + cfg["health_gated"] = bool(cached.get("gated", False)) + cfg["quality_score_health"] = cached.get("score") + # else: keep current values (initial defaults from + # _generate_configs / load_global_llm_configs). + + # Blend health into the final score for every OR cfg, including + # those outside the enriched top-N (they fall through to static). + gated_count = 0 + by_provider: dict[str, int] = {} + for cfg in or_cfgs: + static_q = int(cfg.get("quality_score_static") or 0) + h = cfg.get("quality_score_health") + if h is not None and not cfg.get("health_gated"): + blended = ( + _HEALTH_BLEND_WEIGHT * float(h) + + (1 - _HEALTH_BLEND_WEIGHT) * static_q + ) + cfg["quality_score"] = round(blended) + else: + cfg["quality_score"] = static_q + + if cfg.get("health_gated"): + gated_count += 1 + model_id = str(cfg.get("model_name", "")) + provider_slug = ( + model_id.split("/", 1)[0] if "/" in model_id else "unknown" + ) + by_provider[provider_slug] = by_provider.get(provider_slug, 0) + 1 + + if log_summary: + logger.info( + "auto_pin_health_gated count=%d by_provider=%s fail_ratio=%.2f " + "total_enriched=%d", + gated_count, + dict(sorted(by_provider.items(), key=lambda kv: -kv[1])), + fail_ratio, + len(selected), + ) + + @staticmethod + async def _fetch_endpoints( + client: httpx.AsyncClient, + semaphore: asyncio.Semaphore, + api_key: str, + cfg: dict, + ) -> tuple[dict, list[dict] | None, Exception | None]: + """Fetch ``/api/v1/models/{id}/endpoints`` for one cfg. + + Returns ``(cfg, endpoints, err)`` so the caller can keep batched + results aligned with their cfgs without raising. + """ + model_id = str(cfg.get("model_name", "")) + if not model_id: + return cfg, None, ValueError("missing model_name") + + url = OPENROUTER_ENDPOINTS_URL_TEMPLATE.format(model_id=model_id) + headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} + + async with semaphore: + try: + resp = await client.get(url, headers=headers) + resp.raise_for_status() + data = resp.json() + except Exception as exc: + return cfg, None, exc + + payload = data.get("data") if isinstance(data, dict) else None + if not isinstance(payload, dict): + return cfg, None, ValueError("malformed endpoints payload") + endpoints = payload.get("endpoints") + if not isinstance(endpoints, list): + return cfg, [], None + return cfg, endpoints, None + async def _refresh_loop(self, interval_hours: float) -> None: interval_sec = interval_hours * 3600 while True: diff --git a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py new file mode 100644 index 000000000..1c74aa928 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py @@ -0,0 +1,331 @@ +"""Unit tests for the OpenRouter ``_enrich_health`` background task.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from app.services.openrouter_integration_service import ( + OpenRouterIntegrationService, +) +from app.services.quality_score import ( + _HEALTH_FAIL_RATIO_FALLBACK, +) + +pytestmark = pytest.mark.unit + + +def _or_cfg( + *, + cid: int, + model_name: str, + tier: str = "premium", + static_score: int = 50, +) -> dict: + return { + "id": cid, + "provider": "OPENROUTER", + "model_name": model_name, + "billing_tier": tier, + "auto_pin_tier": "B" if tier == "premium" else "C", + "quality_score_static": static_score, + "quality_score_health": None, + "quality_score": static_score, + "health_gated": False, + } + + +class _StubResponse: + def __init__(self, *, payload: dict, status_code: int = 200): + self._payload = payload + self.status_code = status_code + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + def json(self) -> dict: + return self._payload + + +class _StubAsyncClient: + """Minimal drop-in for ``httpx.AsyncClient`` used by ``_fetch_endpoints``.""" + + def __init__(self, responder): + self._responder = responder + self.requests: list[str] = [] + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def get(self, url: str, headers: dict | None = None) -> _StubResponse: + self.requests.append(url) + return self._responder(url) + + +def _patch_async_client(monkeypatch, responder) -> _StubAsyncClient: + """Replace ``httpx.AsyncClient`` for the duration of the test.""" + client = _StubAsyncClient(responder) + monkeypatch.setattr( + "app.services.openrouter_integration_service.httpx.AsyncClient", + lambda *_args, **_kwargs: client, + ) + return client + + +def _healthy_payload() -> dict: + return { + "data": { + "endpoints": [ + { + "status": 0, + "uptime_last_30m": 0.99, + "uptime_last_1d": 0.995, + "uptime_last_5m": 0.99, + } + ] + } + } + + +def _unhealthy_payload() -> dict: + return { + "data": { + "endpoints": [ + { + "status": 0, + "uptime_last_30m": 0.55, + "uptime_last_1d": 0.62, + "uptime_last_5m": 0.50, + } + ] + } + } + + +# --------------------------------------------------------------------------- +# Bounded fan-out + happy path +# --------------------------------------------------------------------------- + + +async def test_enrich_health_marks_healthy_and_gates_unhealthy(monkeypatch): + cfgs = [ + _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70), + _or_cfg(cid=-2, model_name="venice/dead-model", static_score=60), + ] + + def responder(url: str) -> _StubResponse: + if "anthropic" in url: + return _StubResponse(payload=_healthy_payload()) + return _StubResponse(payload=_unhealthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {"api_key": ""} + await service._enrich_health(cfgs) + + healthy = next(c for c in cfgs if c["id"] == -1) + gated = next(c for c in cfgs if c["id"] == -2) + + assert healthy["health_gated"] is False + assert healthy["quality_score_health"] is not None + assert healthy["quality_score"] >= healthy["quality_score_static"] + + assert gated["health_gated"] is True + assert gated["quality_score"] == gated["quality_score_static"] + + +async def test_enrich_health_only_touches_or_provider(monkeypatch): + """YAML cfgs that aren't OPENROUTER must be skipped entirely.""" + yaml_cfg = { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score_static": 80, + "quality_score": 80, + "health_gated": False, + } + or_cfg = _or_cfg(cid=-2, model_name="anthropic/claude-haiku") + + requests: list[str] = [] + + def responder(url: str) -> _StubResponse: + requests.append(url) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health([yaml_cfg, or_cfg]) + + assert all("anthropic/claude-haiku" in r for r in requests) + # YAML cfg is untouched. + assert yaml_cfg["quality_score"] == 80 + assert yaml_cfg["health_gated"] is False + + +# --------------------------------------------------------------------------- +# Failure ratio fallback +# --------------------------------------------------------------------------- + + +async def test_enrich_health_falls_back_to_last_good_when_failure_ratio_high( + monkeypatch, +): + """If >= 25% of fetches fail, keep last-good cache instead of writing + partial data.""" + cfgs = [ + _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70), + _or_cfg(cid=-2, model_name="openai/gpt-5", static_score=80), + _or_cfg(cid=-3, model_name="google/gemini-flash", static_score=65), + _or_cfg(cid=-4, model_name="venice/something", static_score=50), + ] + + service = OpenRouterIntegrationService() + service._settings = {} + # Pre-seed last-good cache with a known-healthy snapshot. + service._health_cache = { + "anthropic/claude-haiku": {"gated": False, "score": 95.0}, + } + + def all_fail(_url: str) -> _StubResponse: + return _StubResponse(payload={}, status_code=500) + + _patch_async_client(monkeypatch, all_fail) + await service._enrich_health(cfgs) + + # Above threshold ⇒ degraded; last-good cache wins for the cached cfg. + cached_hit = next(c for c in cfgs if c["model_name"] == "anthropic/claude-haiku") + assert cached_hit["quality_score_health"] == 95.0 + assert cached_hit["health_gated"] is False + # Confirm the threshold constant we're testing against is real. + assert _HEALTH_FAIL_RATIO_FALLBACK <= 1.0 + + +async def test_enrich_health_keeps_static_only_with_no_cache_and_failures( + monkeypatch, +): + """If a fetch fails and there's no last-good cache, the cfg keeps its + static-only ``quality_score`` and is *not* gated by default.""" + cfgs = [ + _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70), + ] + + def fail(_url: str) -> _StubResponse: + return _StubResponse(payload={}, status_code=500) + + _patch_async_client(monkeypatch, fail) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health(cfgs) + + cfg = cfgs[0] + assert cfg["health_gated"] is False + assert cfg["quality_score"] == cfg["quality_score_static"] + assert cfg["quality_score_health"] is None + + +# --------------------------------------------------------------------------- +# Last-good cache: success populates, next failure reuses +# --------------------------------------------------------------------------- + + +async def test_enrich_health_populates_cache_on_success_then_reuses_on_failure( + monkeypatch, +): + cfg = _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70) + + service = OpenRouterIntegrationService() + service._settings = {} + + def healthy(_url: str) -> _StubResponse: + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, healthy) + await service._enrich_health([cfg]) + + assert "anthropic/claude-haiku" in service._health_cache + cached_score = service._health_cache["anthropic/claude-haiku"]["score"] + assert cached_score is not None + + # Next cycle: enough other healthy cfgs so failure ratio stays below + # the 25% threshold even when this one fails individually. + other_cfgs = [ + _or_cfg(cid=-2 - i, model_name=f"healthy/m-{i}", static_score=60) + for i in range(10) + ] + cfg["quality_score_health"] = None + cfg["quality_score"] = cfg["quality_score_static"] + + def mixed(url: str) -> _StubResponse: + if "anthropic" in url: + return _StubResponse(payload={}, status_code=500) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, mixed) + await service._enrich_health([cfg, *other_cfgs]) + + assert cfg["quality_score_health"] == cached_score + assert cfg["health_gated"] is False + + +# --------------------------------------------------------------------------- +# Bounded fan-out: respects top-N caps +# --------------------------------------------------------------------------- + + +async def test_enrich_health_bounds_premium_fanout(monkeypatch): + """Top-N premium cap is honoured even when many cfgs are present.""" + from app.services.quality_score import _HEALTH_ENRICH_TOP_N_PREMIUM + + cfgs = [ + _or_cfg( + cid=-i, model_name=f"openai/m-{i}", tier="premium", static_score=100 - i + ) + for i in range(1, _HEALTH_ENRICH_TOP_N_PREMIUM + 20) + ] + + seen: list[str] = [] + + def responder(url: str) -> _StubResponse: + seen.append(url) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health(cfgs) + + assert len(seen) == _HEALTH_ENRICH_TOP_N_PREMIUM + + +async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch): + """When the catalogue has no OR cfgs at all, no HTTP calls fire.""" + yaml_cfg: dict[str, Any] = { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "billing_tier": "premium", + } + requests: list[str] = [] + + def responder(url: str) -> _StubResponse: + requests.append(url) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health([yaml_cfg]) + assert requests == [] From 4bef75d2986b0c46d79b3104dfa4f71dbba5c7fa Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 23:38:53 +0530 Subject: [PATCH 278/299] feat(auto_pin): quality-aware tier-locked selection with health gate --- .../app/services/auto_model_pin_service.py | 56 ++- .../services/test_auto_model_pin_service.py | 336 ++++++++++++++++++ 2 files changed, 387 insertions(+), 5 deletions(-) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 1a2061492..94aa6b734 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -24,6 +24,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.config import config from app.db import NewChatThread +from app.services.quality_score import _QUALITY_TOP_K from app.services.token_quota_service import TokenQuotaService logger = logging.getLogger(__name__) @@ -49,8 +50,16 @@ def _is_usable_global_config(cfg: dict) -> bool: def _global_candidates() -> list[dict]: + """Return Auto-eligible global cfgs. + + Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime + below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers + can't be picked as the thread's pin. + """ candidates = [ - cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg) + cfg + for cfg in config.GLOBAL_LLM_CONFIGS + if _is_usable_global_config(cfg) and not cfg.get("health_gated") ] return sorted(candidates, key=lambda c: int(c.get("id", 0))) @@ -59,10 +68,26 @@ def _tier_of(cfg: dict) -> str: return str(cfg.get("billing_tier", "free")).lower() -def _deterministic_pick(candidates: list[dict], thread_id: int) -> dict: +def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: + """Pick a config with quality-first ranking + deterministic spread. + + Tier policy is lock-first: prefer Tier A (operator-curated YAML) + cfgs and only fall through to Tier B/C (dynamic OpenRouter) if no + Tier A cfg is eligible after upstream filters. Within the locked + pool, sort by ``quality_score`` and pick from the top-K via + ``SHA256(thread_id)`` so different new threads spread across the + best models without ever picking a low-ranked one. + + Returns ``(chosen_cfg, top_k_size)``. ``top_k_size`` is exposed for + structured logging in the caller. + """ + tier_a = [c for c in eligible if c.get("auto_pin_tier") in (None, "A")] + pool = tier_a if tier_a else eligible + pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0)) + top_k = pool[:_QUALITY_TOP_K] digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest() - idx = int.from_bytes(digest[:8], "big") % len(candidates) - return candidates[idx] + idx = int.from_bytes(digest[:8], "big") % len(top_k) + return top_k[idx], len(top_k) def _to_uuid(user_id: str | UUID | None) -> UUID | None: @@ -150,6 +175,15 @@ async def resolve_or_get_pinned_llm_config_id( pinned_id, _tier_of(pinned_cfg), ) + logger.info( + "auto_pin_resolved thread_id=%s config_id=%s tier=%s " + "auto_pin_tier=%s score=%s top_k_size=0 from_existing_pin=True", + thread_id, + pinned_id, + _tier_of(pinned_cfg), + pinned_cfg.get("auto_pin_tier", "?"), + int(pinned_cfg.get("quality_score") or 0), + ) return AutoPinResolution( resolved_llm_config_id=int(pinned_id), resolved_tier=_tier_of(pinned_cfg), @@ -176,7 +210,7 @@ async def resolve_or_get_pinned_llm_config_id( "Auto mode could not find an eligible LLM config for this user and quota state" ) - selected_cfg = _deterministic_pick(eligible, thread_id) + selected_cfg, top_k_size = _select_pin(eligible, thread_id) selected_id = int(selected_cfg["id"]) selected_tier = _tier_of(selected_cfg) @@ -211,6 +245,18 @@ async def resolve_or_get_pinned_llm_config_id( selected_tier, premium_eligible, ) + + logger.info( + "auto_pin_resolved thread_id=%s config_id=%s tier=%s " + "auto_pin_tier=%s score=%s top_k_size=%d from_existing_pin=False", + thread_id, + selected_id, + selected_tier, + selected_cfg.get("auto_pin_tier", "?"), + int(selected_cfg.get("quality_score") or 0), + top_k_size, + ) + return AutoPinResolution( resolved_llm_config_id=selected_id, resolved_tier=selected_tier, diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index 2094ea6dd..be9d7f721 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -365,3 +365,339 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): assert result.resolved_llm_config_id == -2 assert session.thread.pinned_llm_config_id == -2 assert session.commit_count == 1 + + +# --------------------------------------------------------------------------- +# Quality-aware pin selection (Auto Fastest upgrade) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_health_gated_config_is_excluded_from_selection(monkeypatch): + """A cfg flagged ``health_gated`` must never be picked even if it has + the highest score among eligible cfgs.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "venice/dead-model", + "api_key": "k1", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 95, + "health_gated": True, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-flash", + "api_key": "k1", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 60, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): + """Premium-eligible users with Tier A available should never spill to + Tier B even if a B cfg ranks higher by ``quality_score``.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k-yaml", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 70, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "openai/gpt-5", + "api_key": "k-or", + "billing_tier": "premium", + "auto_pin_tier": "B", + "quality_score": 95, + "health_gated": False, + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.resolved_tier == "premium" + + +@pytest.mark.asyncio +async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch): + """Free-only user with no Tier A free cfg should pick from Tier C.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k-yaml", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 100, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-flash:free", + "api_key": "k-or", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 60, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_top_k_picks_only_high_score_models(monkeypatch): + """Different thread IDs should spread across top-K, never pick the + obvious low-quality cfg even when it sits in the candidate list.""" + from app.config import config + + high_score_cfgs = [ + { + "id": -i, + "provider": "AZURE_OPENAI", + "model_name": f"gpt-x-{i}", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 90, + "health_gated": False, + } + for i in range(1, 6) # 5 high-quality Tier A cfgs + ] + low_score_trap = { + "id": -99, + "provider": "AZURE_OPENAI", + "model_name": "tiny-legacy", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 10, + "health_gated": False, + } + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + high_score_cfgs + [low_score_trap], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + high_score_ids = {c["id"] for c in high_score_cfgs} + seen = set() + for thread_id in range(1, 50): + session = _FakeSession(_thread()) + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=thread_id, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + seen.add(result.resolved_llm_config_id) + assert result.resolved_llm_config_id != -99, ( + "low-score trap cfg should never be picked" + ) + assert result.resolved_llm_config_id in high_score_ids + + # Spread across at least a couple of top-K cfgs. + assert len(seen) > 1 + + +@pytest.mark.asyncio +async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): + """An *already* pinned cfg that later flips to ``health_gated`` should + still not be reused — gated cfgs are filtered out of the candidate + pool, which forces a repair to a healthy cfg. + + This guards the no-silent-tier-switch invariant: we don't keep using + a known-broken model just because the thread happened to be pinned + to it before the gate fired.""" + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "venice/dead-model", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "B", + "quality_score": 50, + "health_gated": True, + }, + { + "id": -2, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 90, + "health_gated": False, + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + + +@pytest.mark.asyncio +async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): + """Existing pin reuse must short-circuit the new tier/score logic.""" + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 50, # lower than -2 + "health_gated": False, + }, + { + "id": -2, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5-pro", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 99, + "health_gated": False, + }, + ], + ) + + async def _must_not_call(*_args, **_kwargs): + raise AssertionError("premium_get_usage should not run on pin reuse") + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _must_not_call, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True + assert session.commit_count == 0 From f65b3be1ce72e311dffd03de2d60e0fe73f2aef8 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 00:57:52 +0530 Subject: [PATCH 279/299] feat(auto_model_pin): implement runtime cooldown for error handling and enhance candidate selection --- .../app/services/auto_model_pin_service.py | 64 ++- .../app/tasks/chat/stream_new_chat.py | 380 ++++++++++++++---- .../services/test_auto_model_pin_service.py | 112 ++++++ .../unit/test_stream_new_chat_contract.py | 16 + 4 files changed, 486 insertions(+), 86 deletions(-) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 94aa6b734..05a54b257 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -16,6 +16,8 @@ from __future__ import annotations import hashlib import logging +import threading +import time from dataclasses import dataclass from uuid import UUID @@ -31,6 +33,13 @@ logger = logging.getLogger(__name__) AUTO_FASTEST_ID = 0 AUTO_FASTEST_MODE = "auto_fastest" +_RUNTIME_COOLDOWN_SECONDS = 600 + +# In-memory runtime cooldown map for configs that recently hard-failed at +# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps +# the same unhealthy config from being reselected immediately during repair. +_runtime_cooldown_until: dict[int, float] = {} +_runtime_cooldown_lock = threading.Lock() @dataclass @@ -49,17 +58,68 @@ def _is_usable_global_config(cfg: dict) -> bool: ) +def _prune_runtime_cooldowns(now_ts: float | None = None) -> None: + now = time.time() if now_ts is None else now_ts + stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now] + for cid in stale: + _runtime_cooldown_until.pop(cid, None) + + +def _is_runtime_cooled_down(config_id: int) -> bool: + with _runtime_cooldown_lock: + _prune_runtime_cooldowns() + return config_id in _runtime_cooldown_until + + +def mark_runtime_cooldown( + config_id: int, + *, + reason: str = "rate_limited", + cooldown_seconds: int = _RUNTIME_COOLDOWN_SECONDS, +) -> None: + """Temporarily suppress a config from Auto selection. + + Used by runtime error handlers (e.g. OpenRouter 429) so an already pinned + config that is currently unhealthy does not get immediately reused on the + same thread during repair. + """ + if cooldown_seconds <= 0: + cooldown_seconds = _RUNTIME_COOLDOWN_SECONDS + until = time.time() + int(cooldown_seconds) + with _runtime_cooldown_lock: + _runtime_cooldown_until[int(config_id)] = until + _prune_runtime_cooldowns() + logger.info( + "auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s", + config_id, + reason, + cooldown_seconds, + ) + + +def clear_runtime_cooldown(config_id: int | None = None) -> None: + """Test/ops helper to clear runtime cooldown entries.""" + with _runtime_cooldown_lock: + if config_id is None: + _runtime_cooldown_until.clear() + return + _runtime_cooldown_until.pop(int(config_id), None) + + def _global_candidates() -> list[dict]: """Return Auto-eligible global cfgs. Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers - can't be picked as the thread's pin. + can't be picked as the thread's pin. Also excludes configs currently + in runtime cooldown (e.g. temporary 429 bursts). """ candidates = [ cfg for cfg in config.GLOBAL_LLM_CONFIGS - if _is_usable_global_config(cfg) and not cfg.get("health_gated") + if _is_usable_global_config(cfg) + and not cfg.get("health_gated") + and not _is_runtime_cooled_down(int(cfg.get("id", 0))) ] return sorted(candidates, key=lambda c: int(c.get("id", 0))) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 5abcb63eb..8f596927d 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -64,7 +64,10 @@ from app.db import ( shielded_async_session, ) from app.prompts import TITLE_GENERATION_PROMPT -from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id +from app.services.auto_model_pin_service import ( + mark_runtime_cooldown, + resolve_or_get_pinned_llm_config_id, +) from app.services.chat_session_state_service import ( clear_ai_responding, set_ai_responding, @@ -414,6 +417,60 @@ def _parse_error_payload(message: str) -> dict[str, Any] | None: return None +def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None: + if not isinstance(parsed, dict): + return None + candidates: list[Any] = [parsed.get("code")] + nested = parsed.get("error") + if isinstance(nested, dict): + candidates.append(nested.get("code")) + for value in candidates: + try: + if value is None: + continue + return int(value) + except Exception: + continue + return None + + +def _is_provider_rate_limited(exc: BaseException) -> bool: + """Best-effort detection for provider-side runtime throttling. + + Covers LiteLLM/OpenRouter shapes like: + - class name contains ``RateLimit`` + - nested payload ``{"error": {"code": 429}}`` + - nested payload ``{"error": {"type": "rate_limit_error"}}`` + """ + raw = str(exc) + lowered = raw.lower() + if "ratelimit" in type(exc).__name__.lower(): + return True + parsed = _parse_error_payload(raw) + provider_code = _extract_provider_error_code(parsed) + if provider_code == 429: + return True + + provider_error_type = "" + if parsed: + top_type = parsed.get("type") + if isinstance(top_type, str): + provider_error_type = top_type.lower() + nested = parsed.get("error") + if isinstance(nested, dict): + nested_type = nested.get("type") + if isinstance(nested_type, str): + provider_error_type = nested_type.lower() + if provider_error_type == "rate_limit_error": + return True + + return ( + "rate limited" in lowered + or "rate-limited" in lowered + or "temporarily rate-limited upstream" in lowered + ) + + def _classify_stream_exception( exc: Exception, *, @@ -449,19 +506,7 @@ def _classify_stream_exception( None, ) - parsed = _parse_error_payload(raw) - provider_error_type = "" - if parsed: - top_type = parsed.get("type") - if isinstance(top_type, str): - provider_error_type = top_type.lower() - nested = parsed.get("error") - if isinstance(nested, dict): - nested_type = nested.get("type") - if isinstance(nested_type, str): - provider_error_type = nested_type.lower() - - if provider_error_type == "rate_limit_error": + if _is_provider_rate_limited(exc): return ( "rate_limited", "RATE_LIMITED", @@ -2671,54 +2716,144 @@ async def stream_new_chat( _t_stream_start = time.perf_counter() _first_event_logged = False - async for sse in _stream_agent_events( - agent=agent, - config=config, - input_data=input_state, - streaming_service=streaming_service, - result=stream_result, - step_prefix="thinking", - initial_step_id=initial_step_id, - initial_step_title=initial_title, - initial_step_items=initial_items, - fallback_commit_search_space_id=search_space_id, - fallback_commit_created_by_id=user_id, - fallback_commit_filesystem_mode=( - filesystem_selection.mode - if filesystem_selection - else FilesystemMode.CLOUD - ), - fallback_commit_thread_id=chat_id, - ): - if not _first_event_logged: - _perf_log.info( - "[stream_new_chat] First agent event in %.3fs (time since stream start), " - "%.3fs (total since request start) (chat_id=%s)", - time.perf_counter() - _t_stream_start, - time.perf_counter() - _t_total, - chat_id, - ) - _first_event_logged = True - yield sse - - # Inject title update mid-stream as soon as the background task finishes - if title_task is not None and title_task.done() and not title_emitted: - generated_title, title_usage = title_task.result() - if title_usage: - accumulator.add(**title_usage) - if generated_title: - async with shielded_async_session() as title_session: - title_thread_result = await title_session.execute( - select(NewChatThread).filter(NewChatThread.id == chat_id) + runtime_rate_limit_recovered = False + while True: + try: + async for sse in _stream_agent_events( + agent=agent, + config=config, + input_data=input_state, + streaming_service=streaming_service, + result=stream_result, + step_prefix="thinking", + initial_step_id=initial_step_id, + initial_step_title=initial_title, + initial_step_items=initial_items, + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), + fallback_commit_thread_id=chat_id, + ): + if not _first_event_logged: + _perf_log.info( + "[stream_new_chat] First agent event in %.3fs (time since stream start), " + "%.3fs (total since request start) (chat_id=%s)", + time.perf_counter() - _t_stream_start, + time.perf_counter() - _t_total, + chat_id, ) - title_thread = title_thread_result.scalars().first() - if title_thread: - title_thread.title = generated_title - await title_session.commit() - yield streaming_service.format_thread_title_update( - chat_id, generated_title + _first_event_logged = True + yield sse + + # Inject title update mid-stream as soon as the background + # task finishes. + if title_task is not None and title_task.done() and not title_emitted: + generated_title, title_usage = title_task.result() + if title_usage: + accumulator.add(**title_usage) + if generated_title: + async with shielded_async_session() as title_session: + title_thread_result = await title_session.execute( + select(NewChatThread).filter( + NewChatThread.id == chat_id + ) + ) + title_thread = title_thread_result.scalars().first() + if title_thread: + title_thread.title = generated_title + await title_session.commit() + yield streaming_service.format_thread_title_update( + chat_id, generated_title + ) + title_emitted = True + break + except Exception as stream_exc: + can_runtime_recover = ( + not runtime_rate_limit_recovered + and requested_llm_config_id == 0 + and llm_config_id < 0 + and not _first_event_logged + and _is_provider_rate_limited(stream_exc) + ) + if not can_runtime_recover: + raise + + runtime_rate_limit_recovered = True + previous_config_id = llm_config_id + mark_runtime_cooldown( + previous_config_id, + reason="provider_rate_limited", + ) + + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, ) - title_emitted = True + ).resolved_llm_config_id + + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + raise stream_exc + + # Title generation uses the initial llm object. After a runtime + # repin we keep the stream focused on response recovery and skip + # title generation for this turn. + if title_task is not None and not title_task.done(): + title_task.cancel() + title_task = None + + _t0 = time.perf_counter() + agent = await create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + disabled_tools=disabled_tools, + mentioned_document_ids=mentioned_document_ids, + filesystem_selection=filesystem_selection, + ) + _perf_log.info( + "[stream_new_chat] Runtime rate-limit recovery repinned " + "config_id=%s -> %s and rebuilt agent in %.3fs", + previous_config_id, + llm_config_id, + time.perf_counter() - _t0, + ) + _log_chat_stream_error( + flow=flow, + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model hit runtime rate limit; switched to " + "another eligible model and retried." + ), + extra={ + "auto_runtime_recover": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + continue _perf_log.info( "[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)", @@ -3265,31 +3400,108 @@ async def stream_resume_chat( _t_stream_start = time.perf_counter() _first_event_logged = False - async for sse in _stream_agent_events( - agent=agent, - config=config, - input_data=Command(resume={"decisions": decisions}), - streaming_service=streaming_service, - result=stream_result, - step_prefix="thinking-resume", - fallback_commit_search_space_id=search_space_id, - fallback_commit_created_by_id=user_id, - fallback_commit_filesystem_mode=( - filesystem_selection.mode - if filesystem_selection - else FilesystemMode.CLOUD - ), - fallback_commit_thread_id=chat_id, - ): - if not _first_event_logged: - _perf_log.info( - "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)", - time.perf_counter() - _t_stream_start, - time.perf_counter() - _t_total, - chat_id, + runtime_rate_limit_recovered = False + while True: + try: + async for sse in _stream_agent_events( + agent=agent, + config=config, + input_data=Command(resume={"decisions": decisions}), + streaming_service=streaming_service, + result=stream_result, + step_prefix="thinking-resume", + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), + fallback_commit_thread_id=chat_id, + ): + if not _first_event_logged: + _perf_log.info( + "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)", + time.perf_counter() - _t_stream_start, + time.perf_counter() - _t_total, + chat_id, + ) + _first_event_logged = True + yield sse + break + except Exception as stream_exc: + can_runtime_recover = ( + not runtime_rate_limit_recovered + and requested_llm_config_id == 0 + and llm_config_id < 0 + and not _first_event_logged + and _is_provider_rate_limited(stream_exc) ) - _first_event_logged = True - yield sse + if not can_runtime_recover: + raise + + runtime_rate_limit_recovered = True + previous_config_id = llm_config_id + mark_runtime_cooldown( + previous_config_id, + reason="provider_rate_limited", + ) + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + ) + ).resolved_llm_config_id + + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + raise stream_exc + + _t0 = time.perf_counter() + agent = await create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + filesystem_selection=filesystem_selection, + ) + _perf_log.info( + "[stream_resume] Runtime rate-limit recovery repinned " + "config_id=%s -> %s and rebuilt agent in %.3fs", + previous_config_id, + llm_config_id, + time.perf_counter() - _t0, + ) + _log_chat_stream_error( + flow="resume", + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model hit runtime rate limit; switched to " + "another eligible model and retried." + ), + extra={ + "auto_runtime_recover": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + continue _perf_log.info( "[stream_resume] Agent stream completed in %.3fs (chat_id=%s)", time.perf_counter() - _t_stream_start, diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index be9d7f721..8261fdfe0 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -6,12 +6,21 @@ from types import SimpleNamespace import pytest from app.services.auto_model_pin_service import ( + clear_runtime_cooldown, + mark_runtime_cooldown, resolve_or_get_pinned_llm_config_id, ) pytestmark = pytest.mark.unit +@pytest.fixture(autouse=True) +def _clear_runtime_cooldown_map(): + clear_runtime_cooldown() + yield + clear_runtime_cooldown() + + @dataclass class _FakeQuotaResult: allowed: bool @@ -701,3 +710,106 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): assert result.resolved_llm_config_id == -1 assert result.from_existing_pin is True assert session.commit_count == 0 + + +@pytest.mark.asyncio +async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): + """A runtime-cooled config should be excluded from candidate reuse. + + This enables one-shot recovery from transient provider 429 bursts: we can + mark the pinned cfg as cooled down and force a repair to another eligible + cfg on the next resolution. + """ + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 80, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + + +@pytest.mark.asyncio +async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + ], + ) + + async def _must_not_call(*_args, **_kwargs): + raise AssertionError("premium_get_usage should not run on healthy pin reuse") + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _must_not_call, + ) + + mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600) + clear_runtime_cooldown(-1) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 5e6ad6abd..ed69ca348 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -159,6 +159,22 @@ def test_stream_exception_classifies_rate_limited(): assert extra is None +def test_stream_exception_classifies_openrouter_429_payload(): + exc = Exception( + 'OpenrouterException - {"error":{"message":"Provider returned error","code":429,' + '"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}' + ) + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "rate_limited" + assert code == "RATE_LIMITED" + assert severity == "warn" + assert is_expected is True + assert "temporarily rate-limited" in user_message + assert extra is None + + def test_stream_exception_classifies_thread_busy(): exc = BusyError(request_id="thread-123") kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( From 25ccc959cf59018c3937be22b23ffc7a35fb7391 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 01:35:30 +0530 Subject: [PATCH 280/299] feat(busy_mutex): enhance thread lock management to prevent stale middleware interference --- .../agents/new_chat/middleware/busy_mutex.py | 37 ++++++++++--- .../app/services/auto_model_pin_service.py | 10 +++- .../app/tasks/chat/stream_new_chat.py | 9 ++++ .../unit/agents/new_chat/test_busy_mutex.py | 34 ++++++++++++ .../services/test_auto_model_pin_service.py | 53 +++++++++++++++++++ 5 files changed, 134 insertions(+), 9 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py index d61a56533..06a27bc96 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py +++ b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py @@ -61,6 +61,9 @@ class _ThreadLockManager: self._cancel_events: dict[str, asyncio.Event] = {} self._cancel_requested_at_ms: dict[str, int] = {} self._cancel_attempt_count: dict[str, int] = {} + # Monotonic per-thread epoch used to prevent stale middleware + # teardown from releasing a newer turn's lock. + self._turn_epoch: dict[str, int] = {} def lock_for(self, thread_id: str) -> asyncio.Lock: lock = self._locks.get(thread_id) @@ -107,6 +110,14 @@ class _ThreadLockManager: self._cancel_requested_at_ms.pop(thread_id, None) self._cancel_attempt_count.pop(thread_id, None) + def bump_turn_epoch(self, thread_id: str) -> int: + epoch = self._turn_epoch.get(thread_id, 0) + 1 + self._turn_epoch[thread_id] = epoch + return epoch + + def current_turn_epoch(self, thread_id: str) -> int: + return self._turn_epoch.get(thread_id, 0) + def end_turn(self, thread_id: str) -> None: """Best-effort terminal cleanup for a thread turn. @@ -114,6 +125,10 @@ class _ThreadLockManager: finally-blocks where middleware teardown might be skipped due to abort or disconnect edge-cases. """ + # Invalidate any in-flight middleware holder first. This guarantees a + # stale ``aafter_agent`` from an older attempt cannot unlock a newer + # retry that already acquired the lock for the same thread. + self.bump_turn_epoch(thread_id) lock = self._locks.get(thread_id) if lock is not None and lock.locked(): lock.release() @@ -178,10 +193,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo super().__init__() self._require_thread_id = require_thread_id self.tools = [] - # Per-call locks owned by this middleware. We track them as - # an instance attribute so ``aafter_agent`` knows which lock - # to release. - self._held_locks: dict[str, asyncio.Lock] = {} + # Per-call lock ownership tracked as (lock, epoch). ``aafter_agent`` + # only releases when its epoch still matches the manager's current + # epoch for the thread, preventing stale unlock races. + self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {} @staticmethod def _thread_id(runtime: Runtime[ContextT]) -> str | None: @@ -232,7 +247,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo if lock.locked(): raise BusyError(request_id=thread_id) await lock.acquire() - self._held_locks[thread_id] = lock + epoch = manager.bump_turn_epoch(thread_id) + self._held_locks[thread_id] = (lock, epoch) # Reset the cancel event so this turn starts fresh reset_cancel(thread_id) return None @@ -246,8 +262,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo thread_id = self._thread_id(runtime) if thread_id is None: return None - lock = self._held_locks.pop(thread_id, None) - if lock is not None and lock.locked(): + held = self._held_locks.pop(thread_id, None) + if held is None: + return None + lock, held_epoch = held + if held_epoch != manager.current_turn_epoch(thread_id): + # Stale teardown from an older attempt (e.g. runtime-recovery path + # already advanced epoch). Do not touch current lock/cancel state. + return None + if lock.locked(): lock.release() # Always clear cancel event between turns so a stale signal # doesn't leak into the next request. diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 05a54b257..f6a223866 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -179,6 +179,7 @@ async def resolve_or_get_pinned_llm_config_id( user_id: str | UUID | None, selected_llm_config_id: int, force_repin_free: bool = False, + exclude_config_ids: set[int] | None = None, ) -> AutoPinResolution: """Resolve Auto (Fastest) to one concrete config id and persist the pin. @@ -214,9 +215,14 @@ async def resolve_or_get_pinned_llm_config_id( from_existing_pin=False, ) - candidates = _global_candidates() + excluded_ids = {int(cid) for cid in (exclude_config_ids or set())} + candidates = [ + c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids + ] if not candidates: - raise ValueError("No usable global LLM configs are available for Auto mode") + raise ValueError( + "No usable global LLM configs are available for Auto mode" + ) candidate_by_id = {int(c["id"]): c for c in candidates} # Reuse an existing valid pin without re-checking current quota (no silent diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 8f596927d..dbfd5e2ea 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -2784,6 +2784,10 @@ async def stream_new_chat( runtime_rate_limit_recovered = True previous_config_id = llm_config_id + # The failed attempt may still hold the per-thread busy mutex + # (middleware teardown can lag behind raised provider errors). + # Force release before we retry within the same request. + end_turn(str(chat_id)) mark_runtime_cooldown( previous_config_id, reason="provider_rate_limited", @@ -2796,6 +2800,7 @@ async def stream_new_chat( search_space_id=search_space_id, user_id=user_id, selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, ) ).resolved_llm_config_id @@ -3442,6 +3447,9 @@ async def stream_resume_chat( runtime_rate_limit_recovered = True previous_config_id = llm_config_id + # Ensure the same-request recovery retry does not trip the + # BusyMutex lock retained by the failed attempt. + end_turn(str(chat_id)) mark_runtime_cooldown( previous_config_id, reason="provider_rate_limited", @@ -3453,6 +3461,7 @@ async def stream_resume_chat( search_space_id=search_space_id, user_id=user_id, selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, ) ).resolved_llm_config_id diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py index c923dc499..f0161f605 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py @@ -118,3 +118,37 @@ async def test_end_turn_force_clears_lock_and_cancel_state() -> None: assert not manager.lock_for(thread_id).locked() assert not get_cancel_event(thread_id).is_set() assert is_cancel_requested(thread_id) is False + + +@pytest.mark.asyncio +async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None: + """A stale aafter call from attempt A must not unlock attempt B. + + Repro flow: + 1) attempt A acquires thread lock + 2) forced end_turn clears A so retry can proceed + 3) attempt B acquires same thread lock + 4) stale attempt-A aafter runs late + + Expected: B lock remains held. + """ + thread_id = "stale-aafter-lock" + runtime = _Runtime(thread_id) + attempt_a = BusyMutexMiddleware() + attempt_b = BusyMutexMiddleware() + + await attempt_a.abefore_agent({}, runtime) + lock = manager.lock_for(thread_id) + assert lock.locked() + + end_turn(thread_id) + assert not lock.locked() + + await attempt_b.abefore_agent({}, runtime) + assert lock.locked() + + # Stale cleanup from attempt A must not release attempt B's lock. + await attempt_a.aafter_agent({}, runtime) + assert lock.locked() + + await attempt_b.aafter_agent({}, runtime) diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index 8261fdfe0..8696a8829 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -813,3 +813,56 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): ) assert result.resolved_llm_config_id == -1 assert result.from_existing_pin is True + + +@pytest.mark.asyncio +async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypatch): + """Runtime retry should never repin the just-failed config.""" + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 80, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + exclude_config_ids={-1}, + ) + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False From 14686cdf829e62b4a5b62f088faf462948aaa416 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 02:07:16 +0530 Subject: [PATCH 281/299] feat(auto_pin): add short-TTL healthy-status cache for preflight reuse --- .../app/services/auto_model_pin_service.py | 57 +++++++++++++++++++ .../services/test_auto_model_pin_service.py | 53 +++++++++++++++++ 2 files changed, 110 insertions(+) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index f6a223866..b2acd6f56 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -34,6 +34,7 @@ logger = logging.getLogger(__name__) AUTO_FASTEST_ID = 0 AUTO_FASTEST_MODE = "auto_fastest" _RUNTIME_COOLDOWN_SECONDS = 600 +_HEALTHY_TTL_SECONDS = 45 # In-memory runtime cooldown map for configs that recently hard-failed at # provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps @@ -41,6 +42,13 @@ _RUNTIME_COOLDOWN_SECONDS = 600 _runtime_cooldown_until: dict[int, float] = {} _runtime_cooldown_lock = threading.Lock() +# Short-TTL "recently healthy" cache for configs that just passed a runtime +# preflight ping. Lets back-to-back turns on the same model skip the probe +# without eroding correctness — entries auto-expire and are wiped any time +# the same config is cooled down or the OR catalogue is refreshed. +_healthy_until: dict[int, float] = {} +_healthy_lock = threading.Lock() + @dataclass class AutoPinResolution: @@ -89,6 +97,9 @@ def mark_runtime_cooldown( with _runtime_cooldown_lock: _runtime_cooldown_until[int(config_id)] = until _prune_runtime_cooldowns() + # A cooled cfg can never be "recently healthy"; drop any stale credit so + # the next turn that resolves to it (after cooldown) re-runs preflight. + clear_healthy(int(config_id)) logger.info( "auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s", config_id, @@ -106,6 +117,52 @@ def clear_runtime_cooldown(config_id: int | None = None) -> None: _runtime_cooldown_until.pop(int(config_id), None) +def _prune_healthy(now_ts: float | None = None) -> None: + now = time.time() if now_ts is None else now_ts + stale = [cid for cid, until in _healthy_until.items() if until <= now] + for cid in stale: + _healthy_until.pop(cid, None) + + +def is_recently_healthy(config_id: int) -> bool: + """Return True if ``config_id`` passed preflight within the TTL window.""" + with _healthy_lock: + _prune_healthy() + return int(config_id) in _healthy_until + + +def mark_healthy( + config_id: int, + *, + ttl_seconds: int = _HEALTHY_TTL_SECONDS, +) -> None: + """Record that ``config_id`` just passed a preflight probe. + + Subsequent calls within ``ttl_seconds`` can skip the preflight ping. The + healthy state is intentionally process-local — it's a latency hint, not a + correctness primitive — so multi-worker drift is acceptable. + """ + if ttl_seconds <= 0: + ttl_seconds = _HEALTHY_TTL_SECONDS + until = time.time() + int(ttl_seconds) + with _healthy_lock: + _healthy_until[int(config_id)] = until + _prune_healthy() + + +def clear_healthy(config_id: int | None = None) -> None: + """Drop one (or all) healthy-cache entries. + + Called from runtime cooldown and OR catalogue refresh so a freshly cooled + or replaced config never carries stale "healthy" credit. + """ + with _healthy_lock: + if config_id is None: + _healthy_until.clear() + return + _healthy_until.pop(int(config_id), None) + + def _global_candidates() -> list[dict]: """Return Auto-eligible global cfgs. diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index 8696a8829..d333f0b7a 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -6,7 +6,10 @@ from types import SimpleNamespace import pytest from app.services.auto_model_pin_service import ( + clear_healthy, clear_runtime_cooldown, + is_recently_healthy, + mark_healthy, mark_runtime_cooldown, resolve_or_get_pinned_llm_config_id, ) @@ -17,8 +20,10 @@ pytestmark = pytest.mark.unit @pytest.fixture(autouse=True) def _clear_runtime_cooldown_map(): clear_runtime_cooldown() + clear_healthy() yield clear_runtime_cooldown() + clear_healthy() @dataclass @@ -866,3 +871,51 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa ) assert result.resolved_llm_config_id == -2 assert result.from_existing_pin is False + + +# --------------------------------------------------------------------------- +# Healthy-status cache (preflight TTL companion) +# --------------------------------------------------------------------------- + + +def test_mark_healthy_then_is_recently_healthy_true_within_ttl(): + mark_healthy(-42, ttl_seconds=60) + assert is_recently_healthy(-42) is True + + +def test_healthy_expires_after_ttl(monkeypatch): + import app.services.auto_model_pin_service as svc + + real_time = svc.time.time + base = real_time() + + monkeypatch.setattr(svc.time, "time", lambda: base) + mark_healthy(-7, ttl_seconds=10) + assert is_recently_healthy(-7) is True + + monkeypatch.setattr(svc.time, "time", lambda: base + 11) + assert is_recently_healthy(-7) is False + + +def test_mark_runtime_cooldown_invalidates_healthy_cache(): + mark_healthy(-9, ttl_seconds=60) + assert is_recently_healthy(-9) is True + + mark_runtime_cooldown(-9, reason="test", cooldown_seconds=60) + assert is_recently_healthy(-9) is False + + +def test_clear_healthy_removes_single_entry(): + mark_healthy(-11, ttl_seconds=60) + mark_healthy(-12, ttl_seconds=60) + clear_healthy(-11) + assert is_recently_healthy(-11) is False + assert is_recently_healthy(-12) is True + + +def test_clear_healthy_no_args_drops_all_entries(): + mark_healthy(-21, ttl_seconds=60) + mark_healthy(-22, ttl_seconds=60) + clear_healthy() + assert is_recently_healthy(-21) is False + assert is_recently_healthy(-22) is False From 2764fa5e30185c3e22f59a10df19c7db7d0a25bd Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 02:07:30 +0530 Subject: [PATCH 282/299] feat(openrouter): clear healthy-status cache on catalogue refresh --- .../app/services/openrouter_integration_service.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 9c3eaa5ea..67dbb6690 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -382,6 +382,18 @@ class OpenRouterIntegrationService: self._configs = new_configs self._configs_by_id = new_by_id + # Catalogue churn invalidates per-config "recently healthy" credit + # earned by the previous turn's preflight. Drop the whole table so + # the next turn re-probes against the freshly loaded configs. + try: + from app.services.auto_model_pin_service import clear_healthy + + clear_healthy() + except Exception: + logger.debug( + "OpenRouter refresh: clear_healthy import skipped", exc_info=True + ) + tier_counts = self._tier_counts(new_configs) logger.info( "OpenRouter refresh: updated to %d models (free=%d, premium=%d)", From 7c1c394fe4768c05babc0330e2f8955e82167046 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 02:07:44 +0530 Subject: [PATCH 283/299] feat(stream_new_chat): add lightweight LLM preflight probe for auto-pin --- .../unit/test_stream_new_chat_contract.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index ed69ca348..6a1b4c13b 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -175,6 +175,68 @@ def test_stream_exception_classifies_openrouter_429_payload(): assert extra is None +@pytest.mark.asyncio +async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch): + """``_preflight_llm`` is best-effort. + + - On rate-limit shaped exceptions (provider 429) it MUST re-raise so the + caller can drive the cooldown/repin branch. + - On any other transient failure it MUST swallow the error so the normal + stream path continues without surfacing preflight noise to the user. + """ + from types import SimpleNamespace + + from app.tasks.chat.stream_new_chat import _preflight_llm + + class _RateLimitedExc(Exception): + """Class-name carries 'RateLimit' so _is_provider_rate_limited triggers.""" + + rate_calls: list[dict] = [] + other_calls: list[dict] = [] + + async def _fake_acompletion_429(**kwargs): + rate_calls.append(kwargs) + raise _RateLimitedExc("simulated 429") + + async def _fake_acompletion_other(**kwargs): + other_calls.append(kwargs) + raise RuntimeError("some unrelated transient failure") + + fake_llm = SimpleNamespace( + model="openrouter/google/gemma-4-31b-it:free", + api_key="test", + api_base=None, + ) + + import litellm # type: ignore[import-not-found] + + monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_429) + with pytest.raises(_RateLimitedExc): + await _preflight_llm(fake_llm) + assert len(rate_calls) == 1 + assert rate_calls[0]["max_tokens"] == 1 + assert rate_calls[0]["stream"] is False + + monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_other) + # MUST NOT raise: non-rate-limit failures are swallowed. + await _preflight_llm(fake_llm) + assert len(other_calls) == 1 + + +@pytest.mark.asyncio +async def test_preflight_skipped_for_auto_router_model(): + """Router-mode ``model='auto'`` has no single deployment to ping; the + LiteLLM router itself owns per-deployment rate-limit accounting, so the + preflight helper must short-circuit instead of issuing a probe.""" + from types import SimpleNamespace + + from app.tasks.chat.stream_new_chat import _preflight_llm + + fake_llm = SimpleNamespace(model="auto", api_key="x", api_base=None) + # Should return without raising or making any LiteLLM call. + await _preflight_llm(fake_llm) + + def test_stream_exception_classifies_thread_busy(): exc = BusyError(request_id="thread-123") kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( From 789d8ce62ed173a8a2e98b1fe3d9a14f620beb69 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 02:08:34 +0530 Subject: [PATCH 284/299] feat(stream_new_chat): wire preflight + early repin into auto-mode flow --- .../app/tasks/chat/stream_new_chat.py | 215 ++++++++++++++++++ 1 file changed, 215 insertions(+) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index dbfd5e2ea..07d14afeb 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -65,6 +65,8 @@ from app.db import ( ) from app.prompts import TITLE_GENERATION_PROMPT from app.services.auto_model_pin_service import ( + is_recently_healthy, + mark_healthy, mark_runtime_cooldown, resolve_or_get_pinned_llm_config_id, ) @@ -471,6 +473,54 @@ def _is_provider_rate_limited(exc: BaseException) -> bool: ) +_PREFLIGHT_TIMEOUT_SEC: float = 2.5 +_PREFLIGHT_MAX_TOKENS: int = 1 + + +async def _preflight_llm(llm: Any) -> None: + """Issue a minimal completion to confirm the pinned model isn't 429'ing. + + Used before agent build / planner / classifier / title-gen so a known-bad + free OpenRouter deployment is detected and repinned before it cascades + into multiple wasted internal calls. The probe is intentionally cheap: + one token, low timeout, tagged ``surfsense:internal`` so token tracking + and SSE pipelines treat it as overhead rather than user output. + + Raises the original exception when the provider responds with a + rate-limit-shaped error so the caller can drive the cooldown/repin + branch via :func:`_is_provider_rate_limited`. Other transient failures + are swallowed — the caller continues to the normal stream path and the + in-stream recovery loop remains the safety net. + """ + from litellm import acompletion + + model = getattr(llm, "model", None) + if not model or model == "auto": + # Auto-mode router doesn't have a single deployment to ping; the + # router itself handles per-deployment rate-limit accounting. + return + + try: + await acompletion( + model=model, + messages=[{"role": "user", "content": "ping"}], + api_key=getattr(llm, "api_key", None), + api_base=getattr(llm, "api_base", None), + max_tokens=_PREFLIGHT_MAX_TOKENS, + timeout=_PREFLIGHT_TIMEOUT_SEC, + stream=False, + metadata={"tags": ["surfsense:internal", "auto-pin-preflight"]}, + ) + except Exception as exc: + if _is_provider_rate_limited(exc): + raise + logging.getLogger(__name__).debug( + "auto_pin_preflight non_rate_limit_error model=%s err=%s", + model, + exc, + ) + + def _classify_stream_exception( exc: Exception, *, @@ -2371,6 +2421,92 @@ async def stream_new_chat( yield streaming_service.format_done() return + # Auto-mode preflight ping. Runs ONLY for thread-pinned auto cfgs + # (negative ids selected via ``resolve_or_get_pinned_llm_config_id``) + # whose health hasn't already been confirmed within the TTL window. + # Detecting a 429 here lets us repin BEFORE the planner/classifier/ + # title-generation LLM calls fan out and each independently hit the + # same upstream rate limit. + if ( + requested_llm_config_id == 0 + and llm_config_id < 0 + and not is_recently_healthy(llm_config_id) + ): + _t_preflight = time.perf_counter() + try: + await _preflight_llm(llm) + mark_healthy(llm_config_id) + _perf_log.info( + "[stream_new_chat] auto_pin_preflight ok config_id=%s " + "took=%.3fs", + llm_config_id, + time.perf_counter() - _t_preflight, + ) + except Exception as preflight_exc: + if not _is_provider_rate_limited(preflight_exc): + raise + previous_config_id = llm_config_id + mark_runtime_cooldown( + previous_config_id, reason="preflight_rate_limited" + ) + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) + if llm_load_error or not llm: + yield _emit_stream_error( + message=llm_load_error or "Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + # Trust the freshly-resolved cfg for the remainder of this + # turn rather than recursing into another preflight; the + # in-stream 429 recovery loop is still in place as the + # safety net if even this fallback hits an upstream cap. + mark_healthy(llm_config_id) + _log_chat_stream_error( + flow=flow, + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model failed preflight; switched to another " + "eligible model and continuing." + ), + extra={ + "auto_runtime_recover": True, + "preflight": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + # Create connector service _t0 = time.perf_counter() connector_service = ConnectorService(session, search_space_id=search_space_id) @@ -3327,6 +3463,85 @@ async def stream_resume_chat( yield streaming_service.format_done() return + # Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``: + # one cheap probe before the agent is rebuilt so a 429'd pin gets + # repinned without burning planner/classifier/title calls first. + if ( + requested_llm_config_id == 0 + and llm_config_id < 0 + and not is_recently_healthy(llm_config_id) + ): + _t_preflight = time.perf_counter() + try: + await _preflight_llm(llm) + mark_healthy(llm_config_id) + _perf_log.info( + "[stream_resume] auto_pin_preflight ok config_id=%s " + "took=%.3fs", + llm_config_id, + time.perf_counter() - _t_preflight, + ) + except Exception as preflight_exc: + if not _is_provider_rate_limited(preflight_exc): + raise + previous_config_id = llm_config_id + mark_runtime_cooldown( + previous_config_id, reason="preflight_rate_limited" + ) + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) + if llm_load_error or not llm: + yield _emit_stream_error( + message=llm_load_error or "Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + mark_healthy(llm_config_id) + _log_chat_stream_error( + flow="resume", + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model failed preflight; switched to another " + "eligible model and continuing." + ), + extra={ + "auto_runtime_recover": True, + "preflight": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + _t0 = time.perf_counter() connector_service = ConnectorService(session, search_space_id=search_space_id) From d14fed43c6f92e03907974c3ebb6318d77d3a0f9 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 02:45:27 +0530 Subject: [PATCH 285/299] feat(documents): add endpoint to retrieve document by virtual path --- .../app/routes/documents_routes.py | 45 ++++++ .../app/tasks/chat/stream_new_chat.py | 24 +-- .../unit/test_stream_new_chat_contract.py | 34 ++++ .../components/assistant-ui/markdown-text.tsx | 150 ++++++++++++------ .../lib/apis/documents-api.service.ts | 12 ++ 5 files changed, 206 insertions(+), 59 deletions(-) diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py index f558481cf..f1ca3b6bf 100644 --- a/surfsense_backend/app/routes/documents_routes.py +++ b/surfsense_backend/app/routes/documents_routes.py @@ -745,6 +745,51 @@ async def search_document_titles( ) from e +@router.get("/documents/by-virtual-path", response_model=DocumentTitleRead) +async def get_document_by_virtual_path( + search_space_id: int, + virtual_path: str, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Resolve a knowledge-base document id by exact virtual path.""" + try: + await check_permission( + session, + user, + search_space_id, + Permission.DOCUMENTS_READ.value, + "You don't have permission to read documents in this search space", + ) + + result = await session.execute( + select( + Document.id, + Document.title, + Document.document_type, + ).filter( + Document.search_space_id == search_space_id, + Document.document_metadata["virtual_path"].as_string() == virtual_path, + ) + ) + row = result.first() + if row is None: + raise HTTPException(status_code=404, detail="Document not found") + + return DocumentTitleRead( + id=row.id, + title=row.title, + document_type=row.document_type, + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Failed to resolve document by virtual path: {e!s}", + ) from e + + @router.get("/documents/status", response_model=DocumentStatusBatchResponse) async def get_documents_status( search_space_id: int, diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 07d14afeb..53f237f06 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -304,20 +304,17 @@ def _tool_output_has_error(tool_output: Any) -> bool: return False -def _extract_resolved_file_path(*, tool_name: str, tool_output: Any) -> str | None: +def _extract_resolved_file_path( + *, tool_name: str, tool_output: Any, tool_input: Any | None = None +) -> str | None: if isinstance(tool_output, dict): path_value = tool_output.get("path") if isinstance(path_value, str) and path_value.strip(): return path_value.strip() - text = _tool_output_to_text(tool_output) - if tool_name == "write_file": - match = re.search(r"Updated file\s+(.+)$", text.strip()) - if match: - return match.group(1).strip() - if tool_name == "edit_file": - match = re.search(r"in '([^']+)'", text) - if match: - return match.group(1).strip() + if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict): + file_path = tool_input.get("file_path") + if isinstance(file_path, str) and file_path.strip(): + return file_path.strip() return None @@ -714,6 +711,7 @@ async def _stream_agent_events( # fallback path only and never re-pops a chunk we already streamed. pending_tool_call_chunks: list[dict[str, Any]] = [] lc_tool_call_id_by_run: dict[str, str] = {} + file_path_by_run: dict[str, str] = {} # parity_v2 only: live tool-call argument streaming. ``index_to_meta`` # is keyed by the chunk's ``index`` field — LangChain @@ -892,6 +890,10 @@ async def _stream_agent_events( tool_input = event.get("data", {}).get("input", {}) if tool_name in ("write_file", "edit_file"): result.write_attempted = True + if isinstance(tool_input, dict): + file_path = tool_input.get("file_path") + if isinstance(file_path, str) and file_path.strip() and run_id: + file_path_by_run[run_id] = file_path.strip() if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) @@ -1298,6 +1300,7 @@ async def _stream_agent_events( run_id = event.get("run_id", "") tool_name = event.get("name", "unknown_tool") raw_output = event.get("data", {}).get("output", "") + staged_file_path = file_path_by_run.pop(run_id, None) if run_id else None if tool_name == "update_memory": called_update_memory = True @@ -1811,6 +1814,7 @@ async def _stream_agent_events( resolved_path = _extract_resolved_file_path( tool_name=tool_name, tool_output=tool_output, + tool_input={"file_path": staged_file_path} if staged_file_path else None, ) result_text = _tool_output_to_text(tool_output) if _tool_output_has_error(tool_output): diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 6a1b4c13b..3676601f4 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -13,6 +13,7 @@ from app.tasks.chat.stream_new_chat import ( StreamResult, _classify_stream_exception, _contract_enforcement_active, + _extract_resolved_file_path, _evaluate_file_contract_outcome, _log_chat_stream_error, _tool_output_has_error, @@ -28,6 +29,39 @@ def test_tool_output_error_detection(): assert not _tool_output_has_error({"result": "Updated file /notes.md"}) +def test_extract_resolved_file_path_prefers_structured_path(): + assert ( + _extract_resolved_file_path( + tool_name="write_file", + tool_output={"status": "completed", "path": "/docs/note.md"}, + tool_input=None, + ) + == "/docs/note.md" + ) + + +def test_extract_resolved_file_path_falls_back_to_tool_input(): + assert ( + _extract_resolved_file_path( + tool_name="edit_file", + tool_output={"status": "completed", "result": "updated"}, + tool_input={"file_path": "/docs/edited.md"}, + ) + == "/docs/edited.md" + ) + + +def test_extract_resolved_file_path_does_not_parse_result_text(): + assert ( + _extract_resolved_file_path( + tool_name="write_file", + tool_output={"result": "Updated file /docs/from-text.md"}, + tool_input=None, + ) + is None + ) + + def test_file_write_contract_outcome_reasons(): result = StreamResult(intent_detected="file_write") passed, reason = _evaluate_file_contract_outcome(result) diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index 4842e5979..bfbc3a423 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -30,8 +30,10 @@ import { TableRow, } from "@/components/ui/table"; import { useElectronAPI } from "@/hooks/use-platform"; +import { documentsApiService } from "@/lib/apis/documents-api.service"; import { type CitationUrlMap, preprocessCitationMarkdown } from "@/lib/citations/citation-parser"; import { cn } from "@/lib/utils"; +import { toast } from "sonner"; function MarkdownCodeBlockSkeleton() { return ( @@ -194,6 +196,89 @@ function isVirtualFilePathToken(value: string): boolean { return segments.length >= 2; } +function isStandaloneDocumentsPathText(node: ReactNode): string | null { + if (typeof node !== "string") return null; + const value = node.trim(); + if (!value.startsWith("/documents/")) return null; + if (value.includes(" ")) return null; + const normalized = value.replace(/\/+$/, ""); + const leaf = normalized.split("/").filter(Boolean).at(-1) ?? ""; + if (!leaf || !leaf.includes(".")) return null; + return value; +} + +function FilePathLink({ + path, + className, +}: { + path: string; + className?: string; +}) { + const openEditorPanel = useSetAtom(openEditorPanelAtom); + const params = useParams(); + const electronAPI = useElectronAPI(); + const searchSpaceIdParam = params?.search_space_id; + const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam) + ? Number(searchSpaceIdParam[0]) + : Number(searchSpaceIdParam); + const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId) ? parsedSearchSpaceId : undefined; + + return ( + <button + type="button" + className={cn( + "cursor-pointer font-mono text-[0.9em] font-medium text-primary underline underline-offset-4 transition-colors hover:text-primary/80", + className + )} + onClick={(event) => { + event.preventDefault(); + event.stopPropagation(); + void (async () => { + if (electronAPI) { + let resolvedLocalPath = path; + if (electronAPI.getAgentFilesystemMounts) { + try { + const mounts = (await electronAPI.getAgentFilesystemMounts( + resolvedSearchSpaceId + )) as AgentFilesystemMount[]; + resolvedLocalPath = normalizeLocalVirtualPathForEditor(path, mounts); + } catch { + // Fall back to the raw path if mount lookup fails. + } + } + openEditorPanel({ + kind: "local_file", + localFilePath: resolvedLocalPath, + title: resolvedLocalPath.split("/").pop() || resolvedLocalPath, + searchSpaceId: resolvedSearchSpaceId, + }); + return; + } + + if (!resolvedSearchSpaceId || !path.startsWith("/documents/")) return; + try { + const doc = await documentsApiService.getDocumentByVirtualPath({ + search_space_id: resolvedSearchSpaceId, + virtual_path: path, + }); + openEditorPanel({ + kind: "document", + documentId: doc.id, + searchSpaceId: resolvedSearchSpaceId, + title: doc.title, + }); + } catch { + toast.error("Document not found in knowledge base."); + } + })(); + }} + title="Open in editor panel" + > + {path} + </button> + ); +} + function MarkdownImage({ src, alt }: { src?: string; alt?: string }) { if (!src) return null; @@ -311,9 +396,14 @@ const defaultComponents = memoizeMarkdownComponents({ }, p: function P({ className, children, ...props }) { const urlMap = useCitationUrlMap(); + const standalonePath = isStandaloneDocumentsPathText(children); return ( <p className={cn("aui-md-p mt-5 mb-5 leading-7 first:mt-0 last:mb-0", className)} {...props}> - {processChildrenWithCitations(children, urlMap)} + {standalonePath ? ( + <FilePathLink path={standalonePath} /> + ) : ( + processChildrenWithCitations(children, urlMap) + )} </p> ); }, @@ -400,8 +490,6 @@ const defaultComponents = memoizeMarkdownComponents({ code: function Code({ className, children, ...props }) { const isCodeBlock = useIsMarkdownCodeBlock(); const { resolvedTheme } = useTheme(); - const openEditorPanel = useSetAtom(openEditorPanelAtom); - const params = useParams(); const electronAPI = useElectronAPI(); const language = /language-(\w+)/.exec(className || "")?.[1] ?? "text"; const codeString = String(children).replace(/\n$/, ""); @@ -418,53 +506,17 @@ const defaultComponents = memoizeMarkdownComponents({ const isLikelyFolder = inlineValue.endsWith("/") || !leafSegment || !leafSegment.includes("."); const isLocalPath = - !!electronAPI && - isVirtualFilePathToken(inlineValue) && - !inlineValue.startsWith("//") && - !isLikelyFolder; - const displayLocalPath = inlineValue.replace(/^\/+/, ""); - const searchSpaceIdParam = params?.search_space_id; - const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam) - ? Number(searchSpaceIdParam[0]) - : Number(searchSpaceIdParam); + (isVirtualFilePathToken(inlineValue) && + !inlineValue.startsWith("//") && + !isLikelyFolder && + !!electronAPI) || + (isVirtualFilePathToken(inlineValue) && + !inlineValue.startsWith("//") && + !isLikelyFolder && + !electronAPI && + inlineValue.startsWith("/documents/")); if (isLocalPath) { - return ( - <button - type="button" - className={cn( - "cursor-pointer font-mono text-[0.9em] font-medium text-primary underline underline-offset-4 transition-colors hover:text-primary/80" - )} - onClick={(event) => { - event.preventDefault(); - event.stopPropagation(); - void (async () => { - let resolvedLocalPath = inlineValue; - const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId) - ? parsedSearchSpaceId - : undefined; - if (electronAPI?.getAgentFilesystemMounts) { - try { - const mounts = (await electronAPI.getAgentFilesystemMounts( - resolvedSearchSpaceId - )) as AgentFilesystemMount[]; - resolvedLocalPath = normalizeLocalVirtualPathForEditor(inlineValue, mounts); - } catch { - // Fall back to the raw inline path if mount lookup fails. - } - } - openEditorPanel({ - kind: "local_file", - localFilePath: resolvedLocalPath, - title: resolvedLocalPath.split("/").pop() || resolvedLocalPath, - searchSpaceId: resolvedSearchSpaceId, - }); - })(); - }} - title="Open in editor panel" - > - {displayLocalPath} - </button> - ); + return <FilePathLink path={inlineValue} className="text-[0.9em]" />; } return ( <code diff --git a/surfsense_web/lib/apis/documents-api.service.ts b/surfsense_web/lib/apis/documents-api.service.ts index 0cd81c0b7..949e3b29f 100644 --- a/surfsense_web/lib/apis/documents-api.service.ts +++ b/surfsense_web/lib/apis/documents-api.service.ts @@ -28,6 +28,7 @@ import { getSurfsenseDocsRequest, getSurfsenseDocsResponse, type SearchDocumentsRequest, + documentTitleRead, type SearchDocumentTitlesRequest, searchDocumentsRequest, searchDocumentsResponse, @@ -269,6 +270,17 @@ class DocumentsApiService { ); }; + getDocumentByVirtualPath = async (request: { + search_space_id: number; + virtual_path: string; + }) => { + const params = new URLSearchParams({ + search_space_id: String(request.search_space_id), + virtual_path: request.virtual_path, + }); + return baseApiService.get(`/api/v1/documents/by-virtual-path?${params.toString()}`, documentTitleRead); + }; + /** * Get document type counts */ From e9d964514bdd1585f051616c90db924978341f26 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 03:31:03 +0530 Subject: [PATCH 286/299] feat(alembic): add user table to zero_publication for selective replication of usage metrics --- .../139_add_user_to_zero_publication.py | 158 ++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py diff --git a/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py b/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py new file mode 100644 index 000000000..5b8bc29b0 --- /dev/null +++ b/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py @@ -0,0 +1,158 @@ +"""add user table to zero_publication with column list + +Adds the "user" table to zero_publication with a column-list publication +so that only the 5 fields driving the live usage meters are replicated +through WAL -> zero-cache -> browser IndexedDB: + + id, pages_limit, pages_used, + premium_tokens_limit, premium_tokens_used + +Sensitive columns (hashed_password, email, oauth_account, display_name, +avatar_url, memory_md, refresh_tokens, last_login, etc.) are NOT +included in the publication, so they never enter WAL replication. + +Also re-asserts REPLICA IDENTITY DEFAULT on "user" for idempotency +(it is already DEFAULT today since "user" was never in the +TABLES_WITH_FULL_IDENTITY list of migration 117). + +IMPORTANT - before AND after running this migration: + 1. Stop zero-cache (it holds replication locks that will deadlock DDL) + 2. Run: alembic upgrade head + 3. Delete / reset the zero-cache data volume + 4. Restart zero-cache (it will do a fresh initial sync) + +Revision ID: 139 +Revises: 138 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "139" +down_revision: str | None = "138" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +PUBLICATION_NAME = "zero_publication" + +# Document column list as left by migration 117. Must match exactly. +DOCUMENT_COLS = [ + "id", + "title", + "document_type", + "search_space_id", + "folder_id", + "created_by_id", + "status", + "created_at", + "updated_at", +] + +# Five fields needed by the live usage meters (sidebar Tokens/Pages, +# Buy Tokens content). Keep this list narrow on purpose: anything added +# here flows into WAL and IndexedDB for every connected browser. +USER_COLS = [ + "id", + "pages_limit", + "pages_used", + "premium_tokens_limit", + "premium_tokens_used", +] + + +def _terminate_blocked_pids(conn, table: str) -> None: + """Kill backends whose locks on *table* would block our AccessExclusiveLock.""" + conn.execute( + sa.text( + "SELECT pg_terminate_backend(l.pid) " + "FROM pg_locks l " + "JOIN pg_class c ON c.oid = l.relation " + "WHERE c.relname = :tbl " + " AND l.pid != pg_backend_pid()" + ), + {"tbl": table}, + ) + + +def _has_zero_version(conn, table: str) -> bool: + return ( + conn.execute( + sa.text( + "SELECT 1 FROM information_schema.columns " + "WHERE table_name = :tbl AND column_name = '_0_version'" + ), + {"tbl": table}, + ).fetchone() + is not None + ) + + +def _build_publication_ddl(documents_has_zero_ver: bool, user_has_zero_ver: bool) -> str: + doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else []) + user_cols = USER_COLS + (['"_0_version"'] if user_has_zero_ver else []) + doc_col_list = ", ".join(doc_cols) + user_col_list = ", ".join(user_cols) + return ( + f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + f"notifications, " + f"documents ({doc_col_list}), " + f"folders, " + f"search_source_connectors, " + f"new_chat_messages, " + f"chat_comments, " + f"chat_session_state, " + f'"user" ({user_col_list})' + ) + + +def _build_publication_ddl_without_user(documents_has_zero_ver: bool) -> str: + doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else []) + doc_col_list = ", ".join(doc_cols) + return ( + f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + f"notifications, " + f"documents ({doc_col_list}), " + f"folders, " + f"search_source_connectors, " + f"new_chat_messages, " + f"chat_comments, " + f"chat_session_state" + ) + + +def upgrade() -> None: + conn = op.get_bind() + # asyncpg requires LOCK TABLE inside a transaction block. Alembic already + # opened one via context.begin_transaction(), but the driver still errors + # unless we use an explicit SAVEPOINT (nested transaction) for this block. + tx = conn.begin_nested() if conn.in_transaction() else conn.begin() + with tx: + conn.execute(sa.text("SET lock_timeout = '10s'")) + + _terminate_blocked_pids(conn, "user") + conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE')) + + # Idempotent: "user" was never in TABLES_WITH_FULL_IDENTITY of + # migration 117, so this is already DEFAULT. Re-assert anyway so + # the column-list publication stays valid (DEFAULT identity only + # requires the PK to be in the column list). + conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT')) + + conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")) + + documents_has_zero_ver = _has_zero_version(conn, "documents") + user_has_zero_ver = _has_zero_version(conn, "user") + + conn.execute( + sa.text(_build_publication_ddl(documents_has_zero_ver, user_has_zero_ver)) + ) + + +def downgrade() -> None: + conn = op.get_bind() + conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")) + documents_has_zero_ver = _has_zero_version(conn, "documents") + conn.execute(sa.text(_build_publication_ddl_without_user(documents_has_zero_ver))) From 05eef5a7db42f215fdbcc6115fbe609641b72c7f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 03:31:50 +0530 Subject: [PATCH 287/299] feat(zero): add userTable + queries.user.me() synced query --- surfsense_web/zero/queries/index.ts | 2 ++ surfsense_web/zero/queries/user.ts | 11 +++++++++++ surfsense_web/zero/schema/index.ts | 2 ++ surfsense_web/zero/schema/user.ts | 11 +++++++++++ 4 files changed, 26 insertions(+) create mode 100644 surfsense_web/zero/queries/user.ts create mode 100644 surfsense_web/zero/schema/user.ts diff --git a/surfsense_web/zero/queries/index.ts b/surfsense_web/zero/queries/index.ts index bc332114e..fbf1bd76e 100644 --- a/surfsense_web/zero/queries/index.ts +++ b/surfsense_web/zero/queries/index.ts @@ -3,6 +3,7 @@ import { chatSessionQueries, commentQueries, messageQueries } from "./chat"; import { connectorQueries, documentQueries } from "./documents"; import { folderQueries } from "./folders"; import { notificationQueries } from "./inbox"; +import { userQueries } from "./user"; export const queries = defineQueries({ notifications: notificationQueries, @@ -12,4 +13,5 @@ export const queries = defineQueries({ messages: messageQueries, comments: commentQueries, chatSession: chatSessionQueries, + user: userQueries, }); diff --git a/surfsense_web/zero/queries/user.ts b/surfsense_web/zero/queries/user.ts new file mode 100644 index 000000000..30e71a482 --- /dev/null +++ b/surfsense_web/zero/queries/user.ts @@ -0,0 +1,11 @@ +import { defineQuery } from "@rocicorp/zero"; +import { z } from "zod"; +import { zql } from "../schema/index"; + +export const userQueries = { + me: defineQuery(z.object({}), ({ ctx }) => { + const userId = ctx?.userId; + if (!userId) return zql.user.where("id", "__none__").one(); + return zql.user.where("id", userId).one(); + }), +}; diff --git a/surfsense_web/zero/schema/index.ts b/surfsense_web/zero/schema/index.ts index bba561580..3cca0f24a 100644 --- a/surfsense_web/zero/schema/index.ts +++ b/surfsense_web/zero/schema/index.ts @@ -3,6 +3,7 @@ import { chatCommentTable, chatSessionStateTable, newChatMessageTable } from "./ import { documentTable, searchSourceConnectorTable } from "./documents"; import { folderTable } from "./folders"; import { notificationTable } from "./inbox"; +import { userTable } from "./user"; const chatCommentRelationships = relationships(chatCommentTable, ({ one }) => ({ message: one({ @@ -34,6 +35,7 @@ export const schema = createSchema({ newChatMessageTable, chatCommentTable, chatSessionStateTable, + userTable, ], relationships: [chatCommentRelationships, newChatMessageRelationships], }); diff --git a/surfsense_web/zero/schema/user.ts b/surfsense_web/zero/schema/user.ts new file mode 100644 index 000000000..0e6234db5 --- /dev/null +++ b/surfsense_web/zero/schema/user.ts @@ -0,0 +1,11 @@ +import { number, string, table } from "@rocicorp/zero"; + +export const userTable = table("user") + .columns({ + id: string(), + pagesLimit: number().from("pages_limit"), + pagesUsed: number().from("pages_used"), + premiumTokensLimit: number().from("premium_tokens_limit"), + premiumTokensUsed: number().from("premium_tokens_used"), + }) + .primaryKey("id"); From 2a14c0528251e03a8e2ecff92c558e1628af5f27 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 03:32:05 +0530 Subject: [PATCH 288/299] feat(sidebar): live premium tokens meter via Zero --- .../ui/sidebar/PremiumTokenUsageDisplay.tsx | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx b/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx index a4d760dba..a3f028858 100644 --- a/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx +++ b/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx @@ -1,23 +1,18 @@ "use client"; -import { useQuery } from "@tanstack/react-query"; +import { useQuery } from "@rocicorp/zero/react"; import { Progress } from "@/components/ui/progress"; import { useIsAnonymous } from "@/contexts/anonymous-mode"; -import { stripeApiService } from "@/lib/apis/stripe-api.service"; +import { queries } from "@/zero/queries"; export function PremiumTokenUsageDisplay() { const isAnonymous = useIsAnonymous(); - const { data: tokenStatus } = useQuery({ - queryKey: ["token-status"], - queryFn: () => stripeApiService.getTokenStatus(), - staleTime: 60_000, - enabled: !isAnonymous, - }); + const [me] = useQuery(queries.user.me({})); - if (!tokenStatus) return null; + if (isAnonymous || !me) return null; const usagePercentage = Math.min( - (tokenStatus.premium_tokens_used / Math.max(tokenStatus.premium_tokens_limit, 1)) * 100, + (me.premiumTokensUsed / Math.max(me.premiumTokensLimit, 1)) * 100, 100 ); @@ -31,8 +26,7 @@ export function PremiumTokenUsageDisplay() { <div className="space-y-1.5"> <div className="flex justify-between items-center text-xs"> <span className="text-muted-foreground"> - {formatTokens(tokenStatus.premium_tokens_used)} /{" "} - {formatTokens(tokenStatus.premium_tokens_limit)} tokens + {formatTokens(me.premiumTokensUsed)} / {formatTokens(me.premiumTokensLimit)} tokens </span> <span className="font-medium">{usagePercentage.toFixed(0)}%</span> </div> From 6b06416d4761007cd6a4551313d7038cfef52cc7 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 03:32:19 +0530 Subject: [PATCH 289/299] feat(sidebar): live pages meter via Zero for authenticated users --- .../layout/providers/LayoutDataProvider.tsx | 9 --------- .../ui/sidebar/AuthenticatedPageUsageDisplay.tsx | 15 +++++++++++++++ .../components/layout/ui/sidebar/Sidebar.tsx | 6 ++---- 3 files changed, 17 insertions(+), 13 deletions(-) create mode 100644 surfsense_web/components/layout/ui/sidebar/AuthenticatedPageUsageDisplay.tsx diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index afd888f48..d70a7ade4 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -681,14 +681,6 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid } }, [chatToRename, newChatTitle, queryClient, searchSpaceId, tSidebar]); - // Page usage - const pageUsage = user - ? { - pagesUsed: user.pages_used, - pagesLimit: user.pages_limit, - } - : undefined; - // Detect if we're on the chat page (needs overflow-hidden for chat's own scroll) const isChatPage = pathname?.includes("/new-chat") ?? false; @@ -723,7 +715,6 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid onManageMembers={handleManageMembers} onUserSettings={handleUserSettings} onLogout={handleLogout} - pageUsage={pageUsage} theme={theme} setTheme={setTheme} isChatPage={isChatPage} diff --git a/surfsense_web/components/layout/ui/sidebar/AuthenticatedPageUsageDisplay.tsx b/surfsense_web/components/layout/ui/sidebar/AuthenticatedPageUsageDisplay.tsx new file mode 100644 index 000000000..ad31d50bb --- /dev/null +++ b/surfsense_web/components/layout/ui/sidebar/AuthenticatedPageUsageDisplay.tsx @@ -0,0 +1,15 @@ +"use client"; + +import { useQuery } from "@rocicorp/zero/react"; +import { useIsAnonymous } from "@/contexts/anonymous-mode"; +import { queries } from "@/zero/queries"; +import { PageUsageDisplay } from "./PageUsageDisplay"; + +export function AuthenticatedPageUsageDisplay() { + const isAnonymous = useIsAnonymous(); + const [me] = useQuery(queries.user.me({})); + + if (isAnonymous || !me) return null; + + return <PageUsageDisplay pagesUsed={me.pagesUsed} pagesLimit={me.pagesLimit} />; +} diff --git a/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx b/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx index adad52792..d5038ea05 100644 --- a/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx @@ -12,9 +12,9 @@ import { useIsAnonymous } from "@/contexts/anonymous-mode"; import { cn } from "@/lib/utils"; import { SIDEBAR_MIN_WIDTH } from "../../hooks/useSidebarResize"; import type { ChatItem, NavItem, PageUsage, SearchSpace, User } from "../../types/layout.types"; +import { AuthenticatedPageUsageDisplay } from "./AuthenticatedPageUsageDisplay"; import { ChatListItem } from "./ChatListItem"; import { NavSection } from "./NavSection"; -import { PageUsageDisplay } from "./PageUsageDisplay"; import { PremiumTokenUsageDisplay } from "./PremiumTokenUsageDisplay"; import { SidebarButton } from "./SidebarButton"; import { SidebarCollapseButton } from "./SidebarCollapseButton"; @@ -338,9 +338,7 @@ function SidebarUsageFooter({ return ( <div className="px-3 py-3 border-t space-y-3"> <PremiumTokenUsageDisplay /> - {pageUsage && ( - <PageUsageDisplay pagesUsed={pageUsage.pagesUsed} pagesLimit={pageUsage.pagesLimit} /> - )} + <AuthenticatedPageUsageDisplay /> <div className="space-y-0.5"> <Link href={`/dashboard/${searchSpaceId}/more-pages`} From 38a4742ec688d741da301f31f6657e731f3d3033 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 03:32:37 +0530 Subject: [PATCH 290/299] feat(settings): live buy-tokens meter via Zero --- .../settings/buy-tokens-content.tsx | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/surfsense_web/components/settings/buy-tokens-content.tsx b/surfsense_web/components/settings/buy-tokens-content.tsx index 649a50639..e7fac4255 100644 --- a/surfsense_web/components/settings/buy-tokens-content.tsx +++ b/surfsense_web/components/settings/buy-tokens-content.tsx @@ -1,5 +1,6 @@ "use client"; +import { useQuery as useZeroQuery } from "@rocicorp/zero/react"; import { useMutation, useQuery } from "@tanstack/react-query"; import { Minus, Plus } from "lucide-react"; import { useParams } from "next/navigation"; @@ -11,6 +12,7 @@ import { Spinner } from "@/components/ui/spinner"; import { stripeApiService } from "@/lib/apis/stripe-api.service"; import { AppError } from "@/lib/error"; import { cn } from "@/lib/utils"; +import { queries } from "@/zero/queries"; const TOKEN_PACK_SIZE = 1_000_000; const PRICE_PER_PACK_USD = 1; @@ -21,11 +23,15 @@ export function BuyTokensContent() { const searchSpaceId = Number(params?.search_space_id); const [quantity, setQuantity] = useState(1); + // Server config flag: stays on REST, not per-user. const { data: tokenStatus } = useQuery({ queryKey: ["token-status"], queryFn: () => stripeApiService.getTokenStatus(), }); + // Live per-user usage via Zero. + const [me] = useZeroQuery(queries.user.me({})); + const purchaseMutation = useMutation({ mutationFn: stripeApiService.createTokenCheckoutSession, onSuccess: (response) => { @@ -54,12 +60,11 @@ export function BuyTokensContent() { ); } - const usagePercentage = tokenStatus - ? Math.min( - (tokenStatus.premium_tokens_used / Math.max(tokenStatus.premium_tokens_limit, 1)) * 100, - 100 - ) - : 0; + const used = me?.premiumTokensUsed ?? 0; + const limit = me?.premiumTokensLimit ?? 0; + // Mirrors the backend formula in stripe_routes.py:608 (max(0, limit - used)). + const remaining = Math.max(0, limit - used); + const usagePercentage = me ? Math.min((used / Math.max(limit, 1)) * 100, 100) : 0; return ( <div className="w-full space-y-5"> @@ -68,18 +73,17 @@ export function BuyTokensContent() { <p className="mt-1 text-sm text-muted-foreground">$1 per 1M tokens, pay as you go</p> </div> - {tokenStatus && ( + {me && ( <div className="rounded-lg border bg-muted/20 p-3 space-y-1.5"> <div className="flex justify-between items-center text-xs"> <span className="text-muted-foreground"> - {tokenStatus.premium_tokens_used.toLocaleString()} /{" "} - {tokenStatus.premium_tokens_limit.toLocaleString()} premium tokens + {used.toLocaleString()} / {limit.toLocaleString()} premium tokens </span> <span className="font-medium">{usagePercentage.toFixed(0)}%</span> </div> <Progress value={usagePercentage} className="h-1.5" /> <p className="text-[11px] text-muted-foreground"> - {tokenStatus.premium_tokens_remaining.toLocaleString()} tokens remaining + {remaining.toLocaleString()} tokens remaining </p> </div> )} From b9b4d0b3777667bb6aa59dacc6120d8ae8eb2783 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 03:32:58 +0530 Subject: [PATCH 291/299] chore(usage): stop polling /users/me and token-status for live fields --- .../[search_space_id]/purchase-success/page.tsx | 9 --------- surfsense_web/atoms/user/user-query.atoms.ts | 5 ++++- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx index 67d9edab0..85bc4aaa6 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx @@ -1,11 +1,8 @@ "use client"; -import { useQueryClient } from "@tanstack/react-query"; import { CheckCircle2 } from "lucide-react"; import Link from "next/link"; import { useParams } from "next/navigation"; -import { useEffect } from "react"; -import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms"; import { Button } from "@/components/ui/button"; import { Card, @@ -18,14 +15,8 @@ import { export default function PurchaseSuccessPage() { const params = useParams(); - const queryClient = useQueryClient(); const searchSpaceId = String(params.search_space_id ?? ""); - useEffect(() => { - void queryClient.invalidateQueries({ queryKey: USER_QUERY_KEY }); - void queryClient.invalidateQueries({ queryKey: ["token-status"] }); - }, [queryClient]); - return ( <div className="flex min-h-[calc(100vh-64px)] items-center justify-center px-4 py-8"> <Card className="w-full max-w-lg"> diff --git a/surfsense_web/atoms/user/user-query.atoms.ts b/surfsense_web/atoms/user/user-query.atoms.ts index 8e196c9c7..a59811324 100644 --- a/surfsense_web/atoms/user/user-query.atoms.ts +++ b/surfsense_web/atoms/user/user-query.atoms.ts @@ -8,7 +8,10 @@ const userQueryFn = () => userApiService.getMe(); export const currentUserAtom = atomWithQuery(() => { return { queryKey: USER_QUERY_KEY, - staleTime: 5 * 60 * 1000, + // Live-changing numeric fields (pages_*, premium_tokens_*) are now + // pushed via Zero (queries.user.me()), so /users/me only needs to + // fire once per session for the static profile fields. + staleTime: Infinity, enabled: !!getBearerToken(), queryFn: userQueryFn, }; From cd25175b8459994b7dc982be1de5eb22b5bb7d32 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 03:36:13 +0530 Subject: [PATCH 292/299] chore: ran linting --- .../139_add_user_to_zero_publication.py | 4 +- surfsense_backend/app/config/__init__.py | 4 +- .../app/services/auto_model_pin_service.py | 4 +- .../openrouter_integration_service.py | 26 +++---------- .../app/services/quality_score.py | 10 ++--- .../app/tasks/chat/stream_new_chat.py | 24 ++++++++---- .../services/test_auto_model_pin_service.py | 2 +- .../services/test_llm_router_pool_filter.py | 37 ++++++++++++------- .../test_openrouter_integration_service.py | 2 - .../services/test_openrouter_legacy_config.py | 4 +- .../tests/unit/services/test_quality_score.py | 9 +++-- .../unit/test_stream_new_chat_contract.py | 8 ++-- .../components/assistant-ui/markdown-text.tsx | 14 +++---- .../lib/apis/documents-api.service.ts | 12 +++--- 14 files changed, 78 insertions(+), 82 deletions(-) diff --git a/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py b/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py index 5b8bc29b0..83c96a429 100644 --- a/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py +++ b/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py @@ -90,7 +90,9 @@ def _has_zero_version(conn, table: str) -> bool: ) -def _build_publication_ddl(documents_has_zero_ver: bool, user_has_zero_ver: bool) -> str: +def _build_publication_ddl( + documents_has_zero_ver: bool, user_has_zero_ver: bool +) -> str: doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else []) user_cols = USER_COLS + (['"_0_version"'] if user_has_zero_ver else []) doc_col_list = ", ".join(doc_cols) diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index b3eff571e..675b05d2c 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -286,9 +286,7 @@ def initialize_openrouter_integration(): if new_configs: config.GLOBAL_LLM_CONFIGS.extend(new_configs) - free_count = sum( - 1 for c in new_configs if c.get("billing_tier") == "free" - ) + free_count = sum(1 for c in new_configs if c.get("billing_tier") == "free") premium_count = sum( 1 for c in new_configs if c.get("billing_tier") == "premium" ) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index b2acd6f56..3a2c681b7 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -277,9 +277,7 @@ async def resolve_or_get_pinned_llm_config_id( c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids ] if not candidates: - raise ValueError( - "No usable global LLM configs are available for Auto mode" - ) + raise ValueError("No usable global LLM configs are available for Auto mode") candidate_by_id = {int(c["id"]): c for c in candidates} # Reuse an existing valid pin without re-checking current quota (no silent diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 67dbb6690..7e856d015 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -405,9 +405,7 @@ class OpenRouterIntegrationService: # Re-blend health scores against the freshly fetched catalogue. Also # re-stamps health for any YAML-curated cfg with provider==OPENROUTER # so a hand-picked dead OR model is gated like a dynamic one. - await self._enrich_health_safely( - static_configs + new_configs, log_summary=True - ) + await self._enrich_health_safely(static_configs + new_configs, log_summary=True) # Rebuild the LiteLLM router so freshly fetched configs flow through # (dynamic OR premium entries now opt into the pool, free ones stay @@ -415,8 +413,8 @@ class OpenRouterIntegrationService: # reset cached context-window profiles). try: from app.config import config as _app_config - from app.services.llm_router_service import LLMRouterService from app.services.llm_router_service import ( + LLMRouterService, _router_instance_cache as _chat_router_cache, ) @@ -426,9 +424,7 @@ class OpenRouterIntegrationService: ) _chat_router_cache.clear() except Exception as exc: - logger.warning( - "OpenRouter refresh: router rebuild skipped (%s)", exc - ) + logger.warning("OpenRouter refresh: router rebuild skipped (%s)", exc) @staticmethod def _tier_counts(configs: list[dict]) -> dict[str, int]: @@ -475,19 +471,11 @@ class OpenRouterIntegrationService: return premium_pool = sorted( - [ - c - for c in or_cfgs - if str(c.get("billing_tier", "")).lower() == "premium" - ], + [c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "premium"], key=lambda c: -int(c.get("quality_score_static") or 0), )[:_HEALTH_ENRICH_TOP_N_PREMIUM] free_pool = sorted( - [ - c - for c in or_cfgs - if str(c.get("billing_tier", "")).lower() == "free" - ], + [c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "free"], key=lambda c: -int(c.get("quality_score_static") or 0), )[:_HEALTH_ENRICH_TOP_N_FREE] # De-duplicate while preserving order: a cfg shouldn't fall in both @@ -507,9 +495,7 @@ class OpenRouterIntegrationService: api_key = str(self._settings.get("api_key") or "") semaphore = asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY) - async with httpx.AsyncClient( - timeout=_HEALTH_FETCH_TIMEOUT_SEC - ) as client: + async with httpx.AsyncClient(timeout=_HEALTH_FETCH_TIMEOUT_SEC) as client: results = await asyncio.gather( *( self._fetch_endpoints(client, semaphore, api_key, cfg) diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py index 8f6c75d56..2fb37de21 100644 --- a/surfsense_backend/app/services/quality_score.py +++ b/surfsense_backend/app/services/quality_score.py @@ -7,12 +7,12 @@ sort and a SHA256 pick. Score components (0-100 scale, higher is better): -* ``static_score_or`` – derived from the bulk ``/api/v1/models`` payload +* ``static_score_or`` - derived from the bulk ``/api/v1/models`` payload (provider prestige + ``created`` recency + pricing band + context window + capabilities + narrow tiny/legacy slug penalty). -* ``static_score_yaml`` – same shape for hand-curated YAML configs, plus +* ``static_score_yaml`` - same shape for hand-curated YAML configs, plus an operator-trust bonus (the operator deliberately picked this model). -* ``aggregate_health`` – run on per-model ``/api/v1/models/{id}/endpoints`` +* ``aggregate_health`` - run on per-model ``/api/v1/models/{id}/endpoints`` responses; returns ``(gated, score_or_none)``. The blended ``quality_score`` (0.5 * static + 0.5 * health) is computed in @@ -281,9 +281,7 @@ def static_score_yaml(cfg: dict) -> int: model_name = cfg.get("model_name") or "" litellm_params = cfg.get("litellm_params") or {} lookup_name = ( - litellm_params.get("base_model") - or litellm_params.get("model") - or model_name + litellm_params.get("base_model") or litellm_params.get("model") or model_name ) ctx = 0 diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 53f237f06..dbfe9a67b 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1814,7 +1814,9 @@ async def _stream_agent_events( resolved_path = _extract_resolved_file_path( tool_name=tool_name, tool_output=tool_output, - tool_input={"file_path": staged_file_path} if staged_file_path else None, + tool_input={"file_path": staged_file_path} + if staged_file_path + else None, ) result_text = _tool_output_to_text(tool_output) if _tool_output_has_error(tool_output): @@ -2441,8 +2443,7 @@ async def stream_new_chat( await _preflight_llm(llm) mark_healthy(llm_config_id) _perf_log.info( - "[stream_new_chat] auto_pin_preflight ok config_id=%s " - "took=%.3fs", + "[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs", llm_config_id, time.perf_counter() - _t_preflight, ) @@ -2891,7 +2892,11 @@ async def stream_new_chat( # Inject title update mid-stream as soon as the background # task finishes. - if title_task is not None and title_task.done() and not title_emitted: + if ( + title_task is not None + and title_task.done() + and not title_emitted + ): generated_title, title_usage = title_task.result() if title_usage: accumulator.add(**title_usage) @@ -2944,7 +2949,9 @@ async def stream_new_chat( ) ).resolved_llm_config_id - llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) if llm_load_error: raise stream_exc @@ -3480,8 +3487,7 @@ async def stream_resume_chat( await _preflight_llm(llm) mark_healthy(llm_config_id) _perf_log.info( - "[stream_resume] auto_pin_preflight ok config_id=%s " - "took=%.3fs", + "[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs", llm_config_id, time.perf_counter() - _t_preflight, ) @@ -3684,7 +3690,9 @@ async def stream_resume_chat( ) ).resolved_llm_config_id - llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) if llm_load_error: raise stream_exc diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index d333f0b7a..49b3621c7 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -574,7 +574,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", - high_score_cfgs + [low_score_trap], + [*high_score_cfgs, low_score_trap], ) async def _allowed(*_args, **_kwargs): diff --git a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py index 0191025ec..c309ff881 100644 --- a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py +++ b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py @@ -96,9 +96,12 @@ def test_router_pool_includes_or_premium_excludes_or_free(): ), ] - with patch("app.services.llm_router_service.Router") as mock_router, patch( - "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" - ) as mock_ctx_fb: + with ( + patch("app.services.llm_router_service.Router") as mock_router, + patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb, + ): mock_ctx_fb.side_effect = lambda ml: (ml, None) mock_router.return_value = object() LLMRouterService.initialize(configs) @@ -124,9 +127,10 @@ def test_router_pool_includes_or_premium_excludes_or_free(): assert "openrouter/openai/gpt-4o" in prem assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is True # Dynamic OR free never enters the pool, so it's never counted as premium. - assert LLMRouterService.is_premium_model( - "openrouter/meta-llama/llama-3.3-70b:free" - ) is False + assert ( + LLMRouterService.is_premium_model("openrouter/meta-llama/llama-3.3-70b:free") + is False + ) def test_router_pool_filter_mechanics_respect_override(): @@ -147,9 +151,12 @@ def test_router_pool_filter_mechanics_respect_override(): ), ] - with patch("app.services.llm_router_service.Router") as mock_router, patch( - "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" - ) as mock_ctx_fb: + with ( + patch("app.services.llm_router_service.Router") as mock_router, + patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb, + ): mock_ctx_fb.side_effect = lambda ml: (ml, None) mock_router.return_value = object() LLMRouterService.initialize(configs) @@ -167,13 +174,17 @@ def test_rebuild_refreshes_pool_after_configs_change(): configs_v1 = [ _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"), ] - configs_v2 = configs_v1 + [ + configs_v2 = [ + *configs_v1, _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"), ] - with patch("app.services.llm_router_service.Router") as mock_router, patch( - "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" - ) as mock_ctx_fb: + with ( + patch("app.services.llm_router_service.Router") as mock_router, + patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb, + ): mock_ctx_fb.side_effect = lambda ml: (ml, None) mock_router.return_value = object() diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index d3921729d..085740032 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -214,5 +214,3 @@ def test_generate_configs_drops_non_text_and_non_tool_models(): assert "openai/gpt-4o" in model_names assert "openai/dall-e" not in model_names assert "openai/completion-only" not in model_names - - diff --git a/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py b/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py index b3dd2bf18..4eb1f2295 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py @@ -68,9 +68,7 @@ openrouter_integration: assert "deprecated" in captured -def test_new_keys_take_priority_over_legacy_back_compat( - monkeypatch, tmp_path, capsys -): +def test_new_keys_take_priority_over_legacy_back_compat(monkeypatch, tmp_path, capsys): """If both legacy and new keys are present, new keys win (setdefault).""" _write_yaml( tmp_path, diff --git a/surfsense_backend/tests/unit/services/test_quality_score.py b/surfsense_backend/tests/unit/services/test_quality_score.py index fbc91521d..6fbc8fd62 100644 --- a/surfsense_backend/tests/unit/services/test_quality_score.py +++ b/surfsense_backend/tests/unit/services/test_quality_score.py @@ -106,9 +106,12 @@ def test_context_signal_bands(ctx, expected): def test_capabilities_signal_caps_at_five(): - assert capabilities_signal( - ["tools", "structured_outputs", "reasoning", "include_reasoning"] - ) <= 5 + assert ( + capabilities_signal( + ["tools", "structured_outputs", "reasoning", "include_reasoning"] + ) + <= 5 + ) def test_capabilities_signal_tools_only(): diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 3676601f4..910009667 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -13,8 +13,8 @@ from app.tasks.chat.stream_new_chat import ( StreamResult, _classify_stream_exception, _contract_enforcement_active, - _extract_resolved_file_path, _evaluate_file_contract_outcome, + _extract_resolved_file_path, _log_chat_stream_error, _tool_output_has_error, ) @@ -222,7 +222,7 @@ async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkey from app.tasks.chat.stream_new_chat import _preflight_llm - class _RateLimitedExc(Exception): + class _RateLimitedError(Exception): """Class-name carries 'RateLimit' so _is_provider_rate_limited triggers.""" rate_calls: list[dict] = [] @@ -230,7 +230,7 @@ async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkey async def _fake_acompletion_429(**kwargs): rate_calls.append(kwargs) - raise _RateLimitedExc("simulated 429") + raise _RateLimitedError("simulated 429") async def _fake_acompletion_other(**kwargs): other_calls.append(kwargs) @@ -245,7 +245,7 @@ async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkey import litellm # type: ignore[import-not-found] monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_429) - with pytest.raises(_RateLimitedExc): + with pytest.raises(_RateLimitedError): await _preflight_llm(fake_llm) assert len(rate_calls) == 1 assert rate_calls[0]["max_tokens"] == 1 diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index bfbc3a423..9fddec360 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -19,6 +19,7 @@ import remarkMath from "remark-math"; import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/image"; import "katex/dist/katex.min.css"; +import { toast } from "sonner"; import { processChildrenWithCitations } from "@/components/citations/citation-renderer"; import { Skeleton } from "@/components/ui/skeleton"; import { @@ -33,7 +34,6 @@ import { useElectronAPI } from "@/hooks/use-platform"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { type CitationUrlMap, preprocessCitationMarkdown } from "@/lib/citations/citation-parser"; import { cn } from "@/lib/utils"; -import { toast } from "sonner"; function MarkdownCodeBlockSkeleton() { return ( @@ -207,13 +207,7 @@ function isStandaloneDocumentsPathText(node: ReactNode): string | null { return value; } -function FilePathLink({ - path, - className, -}: { - path: string; - className?: string; -}) { +function FilePathLink({ path, className }: { path: string; className?: string }) { const openEditorPanel = useSetAtom(openEditorPanelAtom); const params = useParams(); const electronAPI = useElectronAPI(); @@ -221,7 +215,9 @@ function FilePathLink({ const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam) ? Number(searchSpaceIdParam[0]) : Number(searchSpaceIdParam); - const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId) ? parsedSearchSpaceId : undefined; + const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId) + ? parsedSearchSpaceId + : undefined; return ( <button diff --git a/surfsense_web/lib/apis/documents-api.service.ts b/surfsense_web/lib/apis/documents-api.service.ts index 949e3b29f..630c88d16 100644 --- a/surfsense_web/lib/apis/documents-api.service.ts +++ b/surfsense_web/lib/apis/documents-api.service.ts @@ -5,6 +5,7 @@ import { type DeleteDocumentRequest, deleteDocumentRequest, deleteDocumentResponse, + documentTitleRead, type GetDocumentByChunkRequest, type GetDocumentChunksRequest, type GetDocumentRequest, @@ -28,7 +29,6 @@ import { getSurfsenseDocsRequest, getSurfsenseDocsResponse, type SearchDocumentsRequest, - documentTitleRead, type SearchDocumentTitlesRequest, searchDocumentsRequest, searchDocumentsResponse, @@ -270,15 +270,15 @@ class DocumentsApiService { ); }; - getDocumentByVirtualPath = async (request: { - search_space_id: number; - virtual_path: string; - }) => { + getDocumentByVirtualPath = async (request: { search_space_id: number; virtual_path: string }) => { const params = new URLSearchParams({ search_space_id: String(request.search_space_id), virtual_path: request.virtual_path, }); - return baseApiService.get(`/api/v1/documents/by-virtual-path?${params.toString()}`, documentTitleRead); + return baseApiService.get( + `/api/v1/documents/by-virtual-path?${params.toString()}`, + documentTitleRead + ); }; /** From ae9d36d77f26c4b74c65ed309fc0204dd4552b36 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Sat, 2 May 2026 14:34:23 -0700 Subject: [PATCH 293/299] feat: unified credits and its cost calculations --- docker/.env.example | 28 +- surfsense_backend/.env.example | 42 +- .../140_premium_tokens_to_credit_micros.py | 291 ++++++++++ surfsense_backend/app/app.py | 2 + surfsense_backend/app/celery_app.py | 2 + surfsense_backend/app/config/__init__.py | 146 ++++- .../app/config/global_llm_config.example.yaml | 29 + surfsense_backend/app/db.py | 33 +- .../app/etl_pipeline/etl_pipeline_service.py | 25 +- .../app/routes/image_generation_routes.py | 154 +++++- .../app/routes/new_chat_routes.py | 7 +- .../app/routes/search_spaces_routes.py | 4 + surfsense_backend/app/routes/stripe_routes.py | 57 +- .../app/routes/vision_llm_routes.py | 7 + .../app/schemas/image_generation.py | 18 + surfsense_backend/app/schemas/new_chat.py | 1 + surfsense_backend/app/schemas/stripe.py | 22 +- surfsense_backend/app/schemas/vision_llm.py | 32 ++ .../app/services/billable_calls.py | 430 +++++++++++++++ .../app/services/llm_router_service.py | 58 +- surfsense_backend/app/services/llm_service.py | 30 +- .../openrouter_integration_service.py | 319 +++++++++++ .../app/services/pricing_registration.py | 274 ++++++++++ .../app/services/provider_api_base.py | 107 ++++ .../app/services/quota_checked_vision_llm.py | 105 ++++ .../app/services/token_quota_service.py | 125 ++++- .../app/services/token_tracking_service.py | 239 +++++++- .../app/services/vision_llm_router_service.py | 16 +- .../app/tasks/celery_tasks/podcast_tasks.py | 67 ++- .../celery_tasks/video_presentation_tasks.py | 68 ++- .../app/tasks/chat/stream_new_chat.py | 112 ++-- .../tests/unit/routes/test_image_gen_quota.py | 138 +++++ .../services/test_agent_billing_resolver.py | 436 +++++++++++++++ .../tests/unit/services/test_billable_call.py | 432 +++++++++++++++ .../test_openrouter_integration_service.py | 156 ++++++ .../services/test_pricing_registration.py | 447 +++++++++++++++ .../services/test_quota_checked_vision_llm.py | 157 ++++++ .../services/test_token_quota_service_cost.py | 515 ++++++++++++++++++ .../tests/unit/tasks/test_podcast_billing.py | 325 +++++++++++ .../tasks/test_video_presentation_billing.py | 330 +++++++++++ surfsense_web/app/(home)/free/page.tsx | 4 +- surfsense_web/app/(home)/pricing/page.tsx | 2 +- .../[search_space_id]/buy-more/page.tsx | 2 +- .../components/PurchaseHistoryContent.tsx | 27 +- surfsense_web/atoms/user/user-query.atoms.ts | 6 +- .../assistant-ui/assistant-message.tsx | 20 + .../assistant-ui/token-usage-context.tsx | 21 +- .../free-chat/quota-warning-banner.tsx | 4 +- .../ui/sidebar/PremiumTokenUsageDisplay.tsx | 25 +- .../components/pricing/pricing-section.tsx | 34 +- .../settings/buy-tokens-content.tsx | 47 +- .../settings/image-model-manager.tsx | 20 +- .../components/settings/llm-role-manager.tsx | 27 +- .../settings/vision-model-manager.tsx | 20 +- surfsense_web/contexts/login-gate.tsx | 4 +- .../contracts/types/new-llm-config.types.ts | 6 + surfsense_web/contracts/types/stripe.types.ts | 15 +- .../lib/chat/chat-error-classifier.ts | 2 +- surfsense_web/lib/chat/streaming-state.ts | 9 +- surfsense_web/lib/chat/thread-persistence.ts | 13 +- surfsense_web/zero/schema/user.ts | 13 +- 61 files changed, 5835 insertions(+), 272 deletions(-) create mode 100644 surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py create mode 100644 surfsense_backend/app/services/billable_calls.py create mode 100644 surfsense_backend/app/services/pricing_registration.py create mode 100644 surfsense_backend/app/services/provider_api_base.py create mode 100644 surfsense_backend/app/services/quota_checked_vision_llm.py create mode 100644 surfsense_backend/tests/unit/routes/test_image_gen_quota.py create mode 100644 surfsense_backend/tests/unit/services/test_agent_billing_resolver.py create mode 100644 surfsense_backend/tests/unit/services/test_billable_call.py create mode 100644 surfsense_backend/tests/unit/services/test_pricing_registration.py create mode 100644 surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py create mode 100644 surfsense_backend/tests/unit/services/test_token_quota_service_cost.py create mode 100644 surfsense_backend/tests/unit/tasks/test_podcast_billing.py create mode 100644 surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py diff --git a/docker/.env.example b/docker/.env.example index 95de0cf85..c2e87a619 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -159,10 +159,13 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE # STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10 # STRIPE_RECONCILIATION_BATCH_SIZE=100 -# Premium token purchases ($1 per 1M tokens for premium-tier models) +# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of +# credit; premium turns debit the actual per-call provider cost +# reported by LiteLLM, so cheap and expensive models bill proportionally) # STRIPE_TOKEN_BUYING_ENABLED=FALSE # STRIPE_PREMIUM_TOKEN_PRICE_ID=price_... -# STRIPE_TOKENS_PER_UNIT=1000000 +# STRIPE_CREDIT_MICROS_PER_UNIT=1000000 +# DEPRECATED — STRIPE_TOKENS_PER_UNIT=1000000 # ------------------------------------------------------------------------------ # TTS & STT (Text-to-Speech / Speech-to-Text) @@ -315,9 +318,24 @@ STT_SERVICE=local/base # Pages limit per user for ETL (default: unlimited) # PAGES_LIMIT=500 -# Premium token quota per registered user (default: 5M) -# Only applies to models with billing_tier=premium in global_llm_config.yaml -# PREMIUM_TOKEN_LIMIT=5000000 +# Premium credit quota per registered user, in micro-USD (default: $5). +# Premium turns are debited at the actual per-call provider cost reported +# by LiteLLM. Only applies to models with billing_tier=premium. +# PREMIUM_CREDIT_MICROS_LIMIT=5000000 +# DEPRECATED — PREMIUM_TOKEN_LIMIT=5000000 + +# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default). +# QUOTA_MAX_RESERVE_MICROS=1000000 + +# Per-image reservation for POST /image-generations, in micro-USD ($0.05 default). +# QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000 + +# Per-podcast reservation for the podcast Celery task ($0.20 default). +# QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000 + +# Per-video-presentation reservation for the video Celery task ($1.00 default). +# Override path bypasses QUOTA_MAX_RESERVE_MICROS clamp — raise with care. +# QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000 # No-login (anonymous) mode — public users can chat without an account # Set TRUE to enable /free pages and anonymous chat API diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index a793f33d1..1b1478ae6 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000 # Set FALSE to disable new checkout session creation temporarily STRIPE_PAGE_BUYING_ENABLED=TRUE -# Premium token purchases via Stripe (for premium-tier model usage) -# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens) +# Premium credit purchases via Stripe (for premium-tier model usage). +# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit +# (default 1_000_000 = $1.00). Premium turns are billed at the actual +# per-call provider cost reported by LiteLLM. STRIPE_TOKEN_BUYING_ENABLED=FALSE STRIPE_PREMIUM_TOKEN_PRICE_ID=price_... -STRIPE_TOKENS_PER_UNIT=1000000 +STRIPE_CREDIT_MICROS_PER_UNIT=1000000 +# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping): +# STRIPE_TOKENS_PER_UNIT=1000000 # Periodic Stripe safety net for purchases left in PENDING (minutes old) STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10 @@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300 # (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version) PAGES_LIMIT=500 -# Premium token quota per registered user (default: 3,000,000) -# Applies only to models with billing_tier=premium in global_llm_config.yaml -PREMIUM_TOKEN_LIMIT=3000000 +# Premium credit quota per registered user, in micro-USD +# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the +# actual per-call provider cost reported by LiteLLM, so cheap and expensive +# models bill proportionally. Applies only to models with +# billing_tier=premium in global_llm_config.yaml. +PREMIUM_CREDIT_MICROS_LIMIT=5000000 +# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping): +# PREMIUM_TOKEN_LIMIT=5000000 + +# Safety ceiling on per-call premium reservation, in micro-USD. +# stream_new_chat estimates an upper-bound cost from the model's +# litellm-published per-token rates × the config's quota_reserve_tokens +# and clamps to this value so a misconfigured model can't lock the +# user's whole balance on one call. Default $1.00. +QUOTA_MAX_RESERVE_MICROS=1000000 + +# Per-image reservation (in micro-USD) for the POST /image-generations +# endpoint. Bypassed for free configs. Default $0.05. +QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000 + +# Per-podcast reservation (in micro-USD) used by the podcast Celery task. +# Single envelope covers one transcript-generation LLM call. Default $0.20. +QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000 + +# Per-video-presentation reservation (in micro-USD) used by the video +# presentation Celery task. Covers worst-case fan-out of N slide-scene +# generations + refines. Default $1.00. NOTE: tasks using the override +# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care. +QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000 # No-login (anonymous) mode — allows public users to chat without an account # Set TRUE to enable /free pages and anonymous chat API diff --git a/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py b/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py new file mode 100644 index 000000000..64aa699e8 --- /dev/null +++ b/surfsense_backend/alembic/versions/140_premium_tokens_to_credit_micros.py @@ -0,0 +1,291 @@ +"""rename premium token columns to credit micros and add cost_micros to token_usage + +Migrates the premium quota system from a flat token counter to a USD-cost +based credit system, where 1 credit = 1 micro-USD ($0.000001). + +Column renames (1:1 numerical mapping — the prior $1 per 1M tokens Stripe +price means every existing value is already correct in the new unit, no +data transformation needed): + + user.premium_tokens_limit -> premium_credit_micros_limit + user.premium_tokens_used -> premium_credit_micros_used + user.premium_tokens_reserved -> premium_credit_micros_reserved + + premium_token_purchases.tokens_granted -> credit_micros_granted + +New column for cost auditing per turn: + + token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0) + +The "user" table is in zero_publication's column list (added in 139), so +this migration must drop and recreate the publication with the renamed +column names, otherwise zero-cache will replicate stale column names and +the FE Zero schema will fail to bind. + +IMPORTANT - before AND after running this migration: + 1. Stop zero-cache (it holds replication locks that will deadlock DDL) + 2. Run: alembic upgrade head + 3. Delete / reset the zero-cache data volume + 4. Restart zero-cache (it will do a fresh initial sync) + +Skipping the zero-cache stop will deadlock at the ACCESS EXCLUSIVE LOCK on +"user". Skipping the data-volume reset will leave IndexedDB clients seeing +column-not-found errors from a stale catalog snapshot. + +Revision ID: 140 +Revises: 139 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "140" +down_revision: str | None = "139" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +PUBLICATION_NAME = "zero_publication" + +# Replicates 139's document column list verbatim — must stay in sync. +DOCUMENT_COLS = [ + "id", + "title", + "document_type", + "search_space_id", + "folder_id", + "created_by_id", + "status", + "created_at", + "updated_at", +] + +# Same five live-meter fields as 139, with the renamed column names. +USER_COLS = [ + "id", + "pages_limit", + "pages_used", + "premium_credit_micros_limit", + "premium_credit_micros_used", +] + + +def _terminate_blocked_pids(conn, table: str) -> None: + """Kill backends whose locks on *table* would block our AccessExclusiveLock.""" + conn.execute( + sa.text( + "SELECT pg_terminate_backend(l.pid) " + "FROM pg_locks l " + "JOIN pg_class c ON c.oid = l.relation " + "WHERE c.relname = :tbl " + " AND l.pid != pg_backend_pid()" + ), + {"tbl": table}, + ) + + +def _has_zero_version(conn, table: str) -> bool: + return ( + conn.execute( + sa.text( + "SELECT 1 FROM information_schema.columns " + "WHERE table_name = :tbl AND column_name = '_0_version'" + ), + {"tbl": table}, + ).fetchone() + is not None + ) + + +def _column_exists(conn, table: str, column: str) -> bool: + return ( + conn.execute( + sa.text( + "SELECT 1 FROM information_schema.columns " + "WHERE table_name = :tbl AND column_name = :col" + ), + {"tbl": table, "col": column}, + ).fetchone() + is not None + ) + + +def _build_publication_ddl( + user_cols: list[str], + *, + documents_has_zero_ver: bool, + user_has_zero_ver: bool, +) -> str: + doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else []) + user_col_list_with_meta = user_cols + ( + ['"_0_version"'] if user_has_zero_ver else [] + ) + doc_col_list = ", ".join(doc_cols) + user_col_list = ", ".join(user_col_list_with_meta) + return ( + f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + f"notifications, " + f"documents ({doc_col_list}), " + f"folders, " + f"search_source_connectors, " + f"new_chat_messages, " + f"chat_comments, " + f"chat_session_state, " + f'"user" ({user_col_list})' + ) + + +def upgrade() -> None: + conn = op.get_bind() + + # ------------------------------------------------------------------ + # 1. Add cost_micros to token_usage. Idempotent guard so re-runs in + # dev environments are safe. + # ------------------------------------------------------------------ + if not _column_exists(conn, "token_usage", "cost_micros"): + op.add_column( + "token_usage", + sa.Column( + "cost_micros", + sa.BigInteger(), + nullable=False, + server_default="0", + ), + ) + + # ------------------------------------------------------------------ + # 2. Rename premium_token_purchases.tokens_granted -> credit_micros_granted. + # ------------------------------------------------------------------ + if _column_exists( + conn, "premium_token_purchases", "tokens_granted" + ) and not _column_exists(conn, "premium_token_purchases", "credit_micros_granted"): + op.alter_column( + "premium_token_purchases", + "tokens_granted", + new_column_name="credit_micros_granted", + ) + + # ------------------------------------------------------------------ + # 3. Rename user.premium_tokens_* -> premium_credit_micros_*. + # + # We must drop the publication first (it references the old column + # names) and re-acquire the lock for DDL. asyncpg requires LOCK TABLE + # in a transaction block; alembic's outer transaction already holds + # one, but a SAVEPOINT keeps the LOCK + DDL atomic. + # ------------------------------------------------------------------ + tx = conn.begin_nested() if conn.in_transaction() else conn.begin() + with tx: + conn.execute(sa.text("SET lock_timeout = '10s'")) + + _terminate_blocked_pids(conn, "user") + conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE')) + + # Re-assert REPLICA IDENTITY DEFAULT for safety; column-list + # publications require at least the PK to be in the column list, + # which is true for both the old and new shape. + conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT')) + + # Drop the publication BEFORE renaming columns, otherwise Postgres + # rejects the rename: "cannot drop column ... referenced by + # publication". + conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")) + + for old, new in ( + ("premium_tokens_limit", "premium_credit_micros_limit"), + ("premium_tokens_used", "premium_credit_micros_used"), + ("premium_tokens_reserved", "premium_credit_micros_reserved"), + ): + if _column_exists(conn, "user", old) and not _column_exists( + conn, "user", new + ): + op.alter_column("user", old, new_column_name=new) + + # Update the server_default on the renamed limit column so newly + # inserted users get $5 of credit (== 5_000_000 micros) by + # default. Existing rows are unaffected. + op.alter_column( + "user", + "premium_credit_micros_limit", + server_default="5000000", + ) + + # Recreate the publication with the new column names. + documents_has_zero_ver = _has_zero_version(conn, "documents") + user_has_zero_ver = _has_zero_version(conn, "user") + conn.execute( + sa.text( + _build_publication_ddl( + USER_COLS, + documents_has_zero_ver=documents_has_zero_ver, + user_has_zero_ver=user_has_zero_ver, + ) + ) + ) + + +def downgrade() -> None: + """Revert the rename and drop ``cost_micros``. + + Mirrors ``upgrade``: drop the publication, rename columns back, drop + the new column, recreate the publication with the old column list. + Same zero-cache stop/reset runbook applies in reverse. + """ + conn = op.get_bind() + + tx = conn.begin_nested() if conn.in_transaction() else conn.begin() + with tx: + conn.execute(sa.text("SET lock_timeout = '10s'")) + + _terminate_blocked_pids(conn, "user") + conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE')) + + conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")) + + for new, old in ( + ("premium_credit_micros_limit", "premium_tokens_limit"), + ("premium_credit_micros_used", "premium_tokens_used"), + ("premium_credit_micros_reserved", "premium_tokens_reserved"), + ): + if _column_exists(conn, "user", new) and not _column_exists( + conn, "user", old + ): + op.alter_column("user", new, new_column_name=old) + + op.alter_column( + "user", + "premium_tokens_limit", + server_default="5000000", + ) + + legacy_user_cols = [ + "id", + "pages_limit", + "pages_used", + "premium_tokens_limit", + "premium_tokens_used", + ] + documents_has_zero_ver = _has_zero_version(conn, "documents") + user_has_zero_ver = _has_zero_version(conn, "user") + conn.execute( + sa.text( + _build_publication_ddl( + legacy_user_cols, + documents_has_zero_ver=documents_has_zero_ver, + user_has_zero_ver=user_has_zero_ver, + ) + ) + ) + + if _column_exists( + conn, "premium_token_purchases", "credit_micros_granted" + ) and not _column_exists(conn, "premium_token_purchases", "tokens_granted"): + op.alter_column( + "premium_token_purchases", + "credit_micros_granted", + new_column_name="tokens_granted", + ) + + if _column_exists(conn, "token_usage", "cost_micros"): + op.drop_column("token_usage", "cost_micros") diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 016c2de42..14d7f4d23 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -31,6 +31,7 @@ from app.config import ( initialize_image_gen_router, initialize_llm_router, initialize_openrouter_integration, + initialize_pricing_registration, initialize_vision_llm_router, ) from app.db import User, create_db_and_tables, get_async_session @@ -432,6 +433,7 @@ async def lifespan(app: FastAPI): await setup_checkpointer_tables() initialize_openrouter_integration() _start_openrouter_background_refresh() + initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() initialize_vision_llm_router() diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 58a8b0f39..74710d5e1 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -22,10 +22,12 @@ def init_worker(**kwargs): initialize_image_gen_router, initialize_llm_router, initialize_openrouter_integration, + initialize_pricing_registration, initialize_vision_llm_router, ) initialize_openrouter_integration() + initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() initialize_vision_llm_router() diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 675b05d2c..2aeeafb34 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -138,7 +138,11 @@ def load_global_image_gen_configs(): try: with open(global_config_file, encoding="utf-8") as f: data = yaml.safe_load(f) - return data.get("global_image_generation_configs", []) + configs = data.get("global_image_generation_configs", []) or [] + for cfg in configs: + if isinstance(cfg, dict): + cfg.setdefault("billing_tier", "free") + return configs except Exception as e: print(f"Warning: Failed to load global image generation configs: {e}") return [] @@ -153,7 +157,11 @@ def load_global_vision_llm_configs(): try: with open(global_config_file, encoding="utf-8") as f: data = yaml.safe_load(f) - return data.get("global_vision_llm_configs", []) + configs = data.get("global_vision_llm_configs", []) or [] + for cfg in configs: + if isinstance(cfg, dict): + cfg.setdefault("billing_tier", "free") + return configs except Exception as e: print(f"Warning: Failed to load global vision LLM configs: {e}") return [] @@ -254,6 +262,15 @@ def load_openrouter_integration_settings() -> dict | None: "anonymous_enabled_free", settings["anonymous_enabled"] ) + # Image generation + vision LLM emission are opt-in (issue L). + # OpenRouter's catalogue contains hundreds of image / vision + # capable models; auto-injecting all of them into every + # deployment would explode the model selector and surprise + # operators upgrading from prior versions. Default to False so + # admins must explicitly turn them on. + settings.setdefault("image_generation_enabled", False) + settings.setdefault("vision_enabled", False) + return settings except Exception as e: print(f"Warning: Failed to load OpenRouter integration settings: {e}") @@ -296,10 +313,60 @@ def initialize_openrouter_integration(): ) else: print("Info: OpenRouter integration enabled but no models fetched") + + # Image generation + vision LLM emissions are opt-in (issue L). + # Both reuse the catalogue already cached by ``service.initialize`` + # so we don't make additional network calls here. + if settings.get("image_generation_enabled"): + try: + image_configs = service.get_image_generation_configs() + if image_configs: + config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs) + print( + f"Info: OpenRouter integration added {len(image_configs)} " + f"image-generation models" + ) + except Exception as e: + print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}") + + if settings.get("vision_enabled"): + try: + vision_configs = service.get_vision_llm_configs() + if vision_configs: + config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs) + print( + f"Info: OpenRouter integration added {len(vision_configs)} " + f"vision LLM models" + ) + except Exception as e: + print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}") except Exception as e: print(f"Warning: Failed to initialize OpenRouter integration: {e}") +def initialize_pricing_registration(): + """ + Teach LiteLLM the per-token cost of every deployment in + ``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled + from the OpenRouter catalogue + any operator-declared YAML pricing). + + Must run AFTER ``initialize_openrouter_integration()`` so the + OpenRouter catalogue is populated and BEFORE the first LLM call so + ``response_cost`` is available in ``TokenTrackingCallback``. + + Failures are logged but never raised — startup must not be blocked + by a missing pricing entry; the worst-case is the model debits 0. + """ + try: + from app.services.pricing_registration import ( + register_pricing_from_global_configs, + ) + + register_pricing_from_global_configs() + except Exception as e: + print(f"Warning: Failed to register LiteLLM pricing: {e}") + + def initialize_llm_router(): """ Initialize the LLM Router service for Auto mode. @@ -444,14 +511,54 @@ class Config: os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100") ) - # Premium token quota settings - PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000")) + # Premium credit (micro-USD) quota settings. + # + # Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy + # ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are + # still honoured for one release as fall-back values — the prior + # $1-per-1M-tokens Stripe price means every existing value maps 1:1 + # to micros, so operators upgrading without changing their .env still + # get correct behaviour. A startup deprecation warning fires below if + # they're set. + PREMIUM_CREDIT_MICROS_LIMIT = int( + os.getenv("PREMIUM_CREDIT_MICROS_LIMIT") + or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000") + ) STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID") - STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")) + STRIPE_CREDIT_MICROS_PER_UNIT = int( + os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT") + or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000") + ) STRIPE_TOKEN_BUYING_ENABLED = ( os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE" ) + # Safety ceiling on the per-call premium reservation. ``stream_new_chat`` + # estimates an upper-bound cost from ``litellm.get_model_info`` x the + # config's ``quota_reserve_tokens`` and clamps the result to this value + # so a misconfigured "$1000/M" model can't lock the user's whole balance + # on one call. Default $1.00 covers realistic worst-cases (Opus + 4K + # reserve_tokens ≈ $0.36) with headroom. + QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000")) + + if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv( + "PREMIUM_CREDIT_MICROS_LIMIT" + ): + print( + "Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to " + "PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the " + "current Stripe price). The old key will be removed in a " + "future release." + ) + if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv( + "STRIPE_CREDIT_MICROS_PER_UNIT" + ): + print( + "Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to " + "STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). " + "The old key will be removed in a future release." + ) + # Anonymous / no-login mode settings NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE" ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000")) @@ -464,6 +571,35 @@ class Config: # Default quota reserve tokens when not specified per-model QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000")) + # Per-image reservation (in micro-USD) used by ``billable_call`` for the + # ``POST /image-generations`` endpoint when the global config does not + # override it. $0.05 covers realistic worst-cases for current OpenAI / + # OpenRouter image-gen pricing. Bypassed entirely for free configs. + QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int( + os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000") + ) + + # Per-podcast reservation (in micro-USD). One agent LLM call generating + # a transcript, typically 5k-20k completion tokens. $0.20 covers a long + # premium-model run. Tune via env. + QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int( + os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000") + ) + + # Per-video-presentation reservation (in micro-USD). Fan-out of N + # slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``) + # plus refine retries; can produce many premium completions. $1.00 + # covers worst-case. Tune via env. + # + # NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of + # 1_000_000. The override path in ``billable_call`` bypasses the + # per-call clamp in ``estimate_call_reserve_micros``, so this is the + # *actual* hold — raising it via env is fine but means a single video + # task can lock $1+ of credit. + QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int( + os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000") + ) + # Abuse prevention: concurrent stream cap and CAPTCHA ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2")) ANON_CAPTCHA_REQUEST_THRESHOLD = int( diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 79cbe1e51..d92640c8d 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -19,6 +19,24 @@ # Structure matches NewLLMConfig: # - Model configuration (provider, model_name, api_key, etc.) # - Prompt configuration (system_instructions, citations_enabled) +# +# COST-BASED PREMIUM CREDITS: +# Each premium config bills the user's USD-credit balance based on the +# actual provider cost reported by LiteLLM. For models LiteLLM already +# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything. +# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment) +# or any model LiteLLM doesn't have in its built-in pricing table, declare +# per-token costs inline so they bill correctly: +# +# litellm_params: +# base_model: "my-custom-azure-deploy" +# # USD per token; e.g. 0.000003 == $3.00 per million input tokens +# input_cost_per_token: 0.000003 +# output_cost_per_token: 0.000015 +# +# OpenRouter dynamic models pull pricing automatically from OpenRouter's +# API — no inline declaration needed. Models without resolvable pricing +# debit $0 from the user's balance and log a WARNING. # Router Settings for Auto Mode # These settings control how the LiteLLM Router distributes requests across models @@ -292,6 +310,17 @@ openrouter_integration: free_rpm: 20 free_tpm: 100000 + # Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue + # contains hundreds of image- and vision-capable models; turning these on + # injects them into the global Image-Generation / Vision-LLM model + # selectors alongside any static configs. Tier (free/premium) is derived + # per model the same way it is for chat (`:free` suffix or zero pricing). + # When a user picks a premium image/vision model the call debits the + # shared $5 USD-cost-based premium credit pool — so leaving these off + # avoids surprise quota burn on existing deployments. Default: false. + image_generation_enabled: false + vision_enabled: false + litellm_params: max_tokens: 16384 system_instructions: "" diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 2fe478d9b..aef959ec9 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -731,6 +731,7 @@ class TokenUsage(BaseModel, TimestampMixin): prompt_tokens = Column(Integer, nullable=False, default=0) completion_tokens = Column(Integer, nullable=False, default=0) total_tokens = Column(Integer, nullable=False, default=0) + cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0") model_breakdown = Column(JSONB, nullable=True) call_details = Column(JSONB, nullable=True) @@ -1793,7 +1794,15 @@ class PagePurchase(Base, TimestampMixin): class PremiumTokenPurchase(Base, TimestampMixin): - """Tracks Stripe checkout sessions used to grant additional premium token credits.""" + """Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units). + + Note: the table name is preserved (``premium_token_purchases``) for + operational continuity even though the unit is now USD micro-credits + instead of raw tokens. The ``credit_micros_granted`` column replaced + the legacy ``tokens_granted`` in migration 140; the stored values + were not transformed because the prior $1 = 1M tokens Stripe price + makes the unit conversion 1:1 numerically. + """ __tablename__ = "premium_token_purchases" __allow_unmapped__ = True @@ -1810,7 +1819,7 @@ class PremiumTokenPurchase(Base, TimestampMixin): ) stripe_payment_intent_id = Column(String(255), nullable=True, index=True) quantity = Column(Integer, nullable=False) - tokens_granted = Column(BigInteger, nullable=False) + credit_micros_granted = Column(BigInteger, nullable=False) amount_total = Column(Integer, nullable=True) currency = Column(String(10), nullable=True) status = Column( @@ -2109,16 +2118,16 @@ if config.AUTH_TYPE == "GOOGLE": ) pages_used = Column(Integer, nullable=False, default=0, server_default="0") - premium_tokens_limit = Column( + premium_credit_micros_limit = Column( BigInteger, nullable=False, - default=config.PREMIUM_TOKEN_LIMIT, - server_default=str(config.PREMIUM_TOKEN_LIMIT), + default=config.PREMIUM_CREDIT_MICROS_LIMIT, + server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT), ) - premium_tokens_used = Column( + premium_credit_micros_used = Column( BigInteger, nullable=False, default=0, server_default="0" ) - premium_tokens_reserved = Column( + premium_credit_micros_reserved = Column( BigInteger, nullable=False, default=0, server_default="0" ) @@ -2241,16 +2250,16 @@ else: ) pages_used = Column(Integer, nullable=False, default=0, server_default="0") - premium_tokens_limit = Column( + premium_credit_micros_limit = Column( BigInteger, nullable=False, - default=config.PREMIUM_TOKEN_LIMIT, - server_default=str(config.PREMIUM_TOKEN_LIMIT), + default=config.PREMIUM_CREDIT_MICROS_LIMIT, + server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT), ) - premium_tokens_used = Column( + premium_credit_micros_used = Column( BigInteger, nullable=False, default=0, server_default="0" ) - premium_tokens_reserved = Column( + premium_credit_micros_reserved = Column( BigInteger, nullable=False, default=0, server_default="0" ) diff --git a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py index 4bb38b7b0..d45bd780c 100644 --- a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py +++ b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py @@ -68,12 +68,25 @@ class EtlPipelineService: etl_service="VISION_LLM", content_type="image", ) - except Exception: - logging.warning( - "Vision LLM failed for %s, falling back to document parser", - request.filename, - exc_info=True, - ) + except Exception as exc: + # Special-case quota exhaustion so we log a clearer message + # — the vision LLM didn't "fail", the user just ran out of + # premium credit. Falling through to the document parser + # is a graceful degradation: OCR/Unstructured still + # extracts text from the image without burning credit. + from app.services.billable_calls import QuotaInsufficientError + + if isinstance(exc, QuotaInsufficientError): + logging.info( + "Vision LLM quota exhausted for %s; falling back to document parser", + request.filename, + ) + else: + logging.warning( + "Vision LLM failed for %s, falling back to document parser", + request.filename, + exc_info=True, + ) else: logging.info( "No vision LLM provided, falling back to document parser for %s", diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 97a3559b9..34ed80207 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -36,6 +36,11 @@ from app.schemas import ( ImageGenerationListRead, ImageGenerationRead, ) +from app.services.billable_calls import ( + DEFAULT_IMAGE_RESERVE_MICROS, + QuotaInsufficientError, + billable_call, +) from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, ImageGenRouterService, @@ -92,6 +97,50 @@ def _build_model_string( return f"{prefix}/{model_name}" +async def _resolve_billing_for_image_gen( + session: AsyncSession, + config_id: int | None, + search_space: SearchSpace, +) -> tuple[str, str, int]: + """Resolve ``(billing_tier, base_model, reserve_micros)`` for a request. + + The resolution mirrors ``_execute_image_generation``'s lookup tree but + only extracts the fields needed for billing — we do this *before* + ``billable_call`` so the reservation is correctly sized for the + config that will actually run, and so we don't open an + ``ImageGeneration`` row for a request that's about to 402. + + User-owned (positive ID) BYOK configs are always free — they cost + the user nothing on our side. Auto mode currently treats as free + because the underlying router can dispatch to either premium or + free YAML configs and we don't surface the resolved deployment up + here yet. Bringing Auto under premium billing would require + threading the chosen deployment back from ``ImageGenRouterService``. + """ + resolved_id = config_id + if resolved_id is None: + resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + + if is_image_gen_auto_mode(resolved_id): + return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS) + + if resolved_id < 0: + cfg = _get_global_image_gen_config(resolved_id) or {} + billing_tier = str(cfg.get("billing_tier", "free")).lower() + base_model = _build_model_string( + cfg.get("provider", ""), + cfg.get("model_name", ""), + cfg.get("custom_provider"), + ) + reserve_micros = int( + cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS + ) + return (billing_tier, base_model, reserve_micros) + + # Positive ID = user-owned BYOK image-gen config — always free. + return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS) + + async def _execute_image_generation( session: AsyncSession, image_gen: ImageGeneration, @@ -225,6 +274,9 @@ async def get_global_image_gen_configs( "litellm_params": {}, "is_global": True, "is_auto_mode": True, + # Auto mode currently treated as free until per-deployment + # billing-tier surfacing lands (see _resolve_billing_for_image_gen). + "billing_tier": "free", } ) @@ -241,6 +293,8 @@ async def get_global_image_gen_configs( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), + "quota_reserve_micros": cfg.get("quota_reserve_micros"), } ) @@ -454,7 +508,26 @@ async def create_image_generation( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Create and execute an image generation request.""" + """Create and execute an image generation request. + + Premium configs are gated by the user's shared premium credit pool. + The flow is: + + 1. Permission check + load the search space (cheap, no provider call). + 2. Resolve which config will run so we know its billing tier and the + worst-case reservation size *before* opening any DB rows. + 3. Wrap the entire ImageGeneration row insert + provider call in + ``billable_call``. If quota is denied, ``billable_call`` raises + ``QuotaInsufficientError`` *before* we flush a row, which we + translate to HTTP 402 (no orphaned rows on the user's account, + no inserted error rows for "you ran out of credit"). + 4. On success, the actual ``response_cost`` flows through the + LiteLLM callback into the accumulator, and ``billable_call`` + finalizes the debit at exit. Inner ``try/except`` still catches + provider errors and stores them on ``error_message`` (HTTP 200 + with ``error_message`` set is preserved for failed-but-not-quota + scenarios — clients already know how to surface those). + """ try: await check_permission( session, @@ -471,33 +544,70 @@ async def create_image_generation( if not search_space: raise HTTPException(status_code=404, detail="Search space not found") - db_image_gen = ImageGeneration( - prompt=data.prompt, - model=data.model, - n=data.n, - quality=data.quality, - size=data.size, - style=data.style, - response_format=data.response_format, - image_generation_config_id=data.image_generation_config_id, - search_space_id=data.search_space_id, - created_by_id=user.id, + billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen( + session, data.image_generation_config_id, search_space ) - session.add(db_image_gen) - await session.flush() - try: - await _execute_image_generation(session, db_image_gen, search_space) - except Exception as e: - logger.exception("Image generation call failed") - db_image_gen.error_message = str(e) + # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError + # propagates to the outer ``except QuotaInsufficientError`` handler + # below as HTTP 402 — it is intentionally NOT swallowed into + # ``error_message`` because that would (1) imply a successful row + # exists when none does, and (2) return HTTP 200 to a client + # whose request was actively *denied* (issue K). + async with billable_call( + user_id=search_space.user_id, + search_space_id=data.search_space_id, + billing_tier=billing_tier, + base_model=base_model, + quota_reserve_micros_override=reserve_micros, + usage_type="image_generation", + call_details={"model": base_model, "prompt": data.prompt[:100]}, + ): + db_image_gen = ImageGeneration( + prompt=data.prompt, + model=data.model, + n=data.n, + quality=data.quality, + size=data.size, + style=data.style, + response_format=data.response_format, + image_generation_config_id=data.image_generation_config_id, + search_space_id=data.search_space_id, + created_by_id=user.id, + ) + session.add(db_image_gen) + await session.flush() - await session.commit() - await session.refresh(db_image_gen) - return db_image_gen + try: + await _execute_image_generation(session, db_image_gen, search_space) + except Exception as e: + logger.exception("Image generation call failed") + db_image_gen.error_message = str(e) + + await session.commit() + await session.refresh(db_image_gen) + return db_image_gen except HTTPException: raise + except QuotaInsufficientError as exc: + # The user's premium credit pool is empty. No DB row is created + # because ``billable_call`` denies before yielding (issue K). + await session.rollback() + raise HTTPException( + status_code=402, + detail={ + "error_code": "premium_quota_exhausted", + "usage_type": exc.usage_type, + "used_micros": exc.used_micros, + "limit_micros": exc.limit_micros, + "remaining_micros": exc.remaining_micros, + "message": ( + "Out of premium credits for image generation. " + "Purchase additional credits or switch to a free model." + ), + }, + ) from exc except SQLAlchemyError: await session.rollback() raise HTTPException( diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 28b197ca2..d3bd51129 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1366,7 +1366,11 @@ async def append_message( # flush assigns the PK/defaults without a round-trip SELECT await session.flush() - # Persist token usage if provided (for assistant messages) + # Persist token usage if provided (for assistant messages). + # ``cost_micros`` is the provider USD cost reported by LiteLLM, + # forwarded by the FE through the appendMessage round-trip so + # the historical TokenUsage row matches the credit debit applied + # at finalize time. token_usage_data = raw_body.get("token_usage") if token_usage_data and message_role == NewChatMessageRole.ASSISTANT: await record_token_usage( @@ -1377,6 +1381,7 @@ async def append_message( prompt_tokens=token_usage_data.get("prompt_tokens", 0), completion_tokens=token_usage_data.get("completion_tokens", 0), total_tokens=token_usage_data.get("total_tokens", 0), + cost_micros=token_usage_data.get("cost_micros", 0), model_breakdown=token_usage_data.get("usage"), call_details=token_usage_data.get("call_details"), thread_id=thread_id, diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 72715ea5b..5ecfb1814 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -594,6 +594,7 @@ async def _get_image_gen_config_by_id( "model_name": "auto", "is_global": True, "is_auto_mode": True, + "billing_tier": "free", } if config_id < 0: @@ -610,6 +611,7 @@ async def _get_image_gen_config_by_id( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), } return None @@ -652,6 +654,7 @@ async def _get_vision_llm_config_by_id( "model_name": "auto", "is_global": True, "is_auto_mode": True, + "billing_tier": "free", } if config_id < 0: @@ -668,6 +671,7 @@ async def _get_vision_llm_config_by_id( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), } return None diff --git a/surfsense_backend/app/routes/stripe_routes.py b/surfsense_backend/app/routes/stripe_routes.py index cfdd4b52a..aed74ec8d 100644 --- a/surfsense_backend/app/routes/stripe_routes.py +++ b/surfsense_backend/app/routes/stripe_routes.py @@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase( metadata = _get_metadata(checkout_session) user_id = metadata.get("user_id") quantity = int(metadata.get("quantity", "0")) - tokens_per_unit = int(metadata.get("tokens_per_unit", "0")) + # Read the new metadata key first, fall back to the legacy one so + # in-flight checkout sessions created before the cost-credits + # release still fulfil correctly (the unit is numerically the + # same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens). + credit_micros_per_unit = int( + metadata.get("credit_micros_per_unit") + or metadata.get("tokens_per_unit", "0") + ) - if not user_id or quantity <= 0 or tokens_per_unit <= 0: + if not user_id or quantity <= 0 or credit_micros_per_unit <= 0: logger.error( "Skipping token fulfillment for session %s: incomplete metadata %s", checkout_session_id, @@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase( getattr(checkout_session, "payment_intent", None) ), quantity=quantity, - tokens_granted=quantity * tokens_per_unit, + credit_micros_granted=quantity * credit_micros_per_unit, amount_total=getattr(checkout_session, "amount_total", None), currency=getattr(checkout_session, "currency", None), status=PremiumTokenPurchaseStatus.PENDING, @@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase( purchase.stripe_payment_intent_id = _normalize_optional_string( getattr(checkout_session, "payment_intent", None) ) - user.premium_tokens_limit = ( - max(user.premium_tokens_used, user.premium_tokens_limit) - + purchase.tokens_granted + # Top up the user's credit balance by the granted micro-USD amount. + # ``max(used, limit)`` clamps the case where the legacy code wrote a + # used value above the limit (e.g. underbilling rounding) so adding + # ``credit_micros_granted`` always lifts the limit by the full pack + # size rather than disappearing into past overuse. + user.premium_credit_micros_limit = ( + max(user.premium_credit_micros_used, user.premium_credit_micros_limit) + + purchase.credit_micros_granted ) await db_session.commit() @@ -532,12 +544,18 @@ async def create_token_checkout_session( user: User = Depends(current_active_user), db_session: AsyncSession = Depends(get_async_session), ): - """Create a Stripe Checkout Session for buying premium token packs.""" + """Create a Stripe Checkout Session for buying premium credit packs. + + Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of + credit (default 1_000_000 = $1.00). The user's balance is debited + at the actual provider cost reported by LiteLLM at finalize time, + so $1 of credit always buys $1 worth of provider usage at cost. + """ _ensure_token_buying_enabled() stripe_client = get_stripe_client() price_id = _get_required_token_price_id() success_url, cancel_url = _get_token_checkout_urls(body.search_space_id) - tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT + credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT try: checkout_session = stripe_client.v1.checkout.sessions.create( @@ -556,8 +574,8 @@ async def create_token_checkout_session( "metadata": { "user_id": str(user.id), "quantity": str(body.quantity), - "tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT), - "purchase_type": "premium_tokens", + "credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT), + "purchase_type": "premium_credit", }, } ) @@ -583,7 +601,7 @@ async def create_token_checkout_session( getattr(checkout_session, "payment_intent", None) ), quantity=body.quantity, - tokens_granted=tokens_granted, + credit_micros_granted=credit_micros_granted, amount_total=getattr(checkout_session, "amount_total", None), currency=getattr(checkout_session, "currency", None), status=PremiumTokenPurchaseStatus.PENDING, @@ -598,14 +616,19 @@ async def create_token_checkout_session( async def get_token_status( user: User = Depends(current_active_user), ): - """Return token-buying availability and current premium quota for frontend.""" - used = user.premium_tokens_used - limit = user.premium_tokens_limit + """Return token-buying availability and current premium credit quota for frontend. + + Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M + when displaying. The route name is preserved for back-compat with + pinned client deployments. + """ + used = user.premium_credit_micros_used + limit = user.premium_credit_micros_limit return TokenStripeStatusResponse( token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED, - premium_tokens_used=used, - premium_tokens_limit=limit, - premium_tokens_remaining=max(0, limit - used), + premium_credit_micros_used=used, + premium_credit_micros_limit=limit, + premium_credit_micros_remaining=max(0, limit - used), ) diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py index 315c7c9fe..4f7e9f725 100644 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -82,6 +82,9 @@ async def get_global_vision_llm_configs( "litellm_params": {}, "is_global": True, "is_auto_mode": True, + # Auto mode treated as free until per-deployment billing-tier + # surfacing lands; see ``get_vision_llm`` for parity. + "billing_tier": "free", } ) @@ -98,6 +101,10 @@ async def get_global_vision_llm_configs( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), + "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), + "input_cost_per_token": cfg.get("input_cost_per_token"), + "output_cost_per_token": cfg.get("output_cost_per_token"), } ) diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py index 69f534e20..facca7b86 100644 --- a/surfsense_backend/app/schemas/image_generation.py +++ b/surfsense_backend/app/schemas/image_generation.py @@ -215,6 +215,12 @@ class GlobalImageGenConfigRead(BaseModel): Schema for reading global image generation configs from YAML. Global configs have negative IDs. API key is hidden. ID 0 is reserved for Auto mode (LiteLLM Router load balancing). + + The ``billing_tier`` field allows the frontend to show a Premium/Free + badge and (more importantly) tells the backend whether to debit the + user's premium credit pool when this config is used. ``"free"`` is + the default for backward compatibility — admins must explicitly opt + a global config into ``"premium"``. """ id: int = Field( @@ -231,3 +237,15 @@ class GlobalImageGenConfigRead(BaseModel): litellm_params: dict[str, Any] | None = None is_global: bool = True is_auto_mode: bool = False + billing_tier: str = Field( + default="free", + description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", + ) + quota_reserve_micros: int | None = Field( + default=None, + description=( + "Optional override for the reservation amount (in micro-USD) used when " + "this image generation is premium. Falls back to " + "QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted." + ), + ) diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index ec5eefc07..892ff9693 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -39,6 +39,7 @@ class TokenUsageSummary(BaseModel): prompt_tokens: int = 0 completion_tokens: int = 0 total_tokens: int = 0 + cost_micros: int = 0 model_breakdown: dict | None = None model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/stripe.py b/surfsense_backend/app/schemas/stripe.py index 3edd3e9e4..57265ec8e 100644 --- a/surfsense_backend/app/schemas/stripe.py +++ b/surfsense_backend/app/schemas/stripe.py @@ -70,13 +70,17 @@ class CreateTokenCheckoutSessionResponse(BaseModel): class TokenPurchaseRead(BaseModel): - """Serialized premium token purchase record.""" + """Serialized premium credit purchase record. + + ``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The + schema name kept ``Token`` for API back-compat with pinned clients. + """ id: uuid.UUID stripe_checkout_session_id: str stripe_payment_intent_id: str | None = None quantity: int - tokens_granted: int + credit_micros_granted: int amount_total: int | None = None currency: str | None = None status: str @@ -87,15 +91,19 @@ class TokenPurchaseRead(BaseModel): class TokenPurchaseHistoryResponse(BaseModel): - """Response containing the user's premium token purchases.""" + """Response containing the user's premium credit purchases.""" purchases: list[TokenPurchaseRead] class TokenStripeStatusResponse(BaseModel): - """Response describing token-buying availability and current quota.""" + """Response describing premium-credit-buying availability and balance. + + All ``premium_credit_micros_*`` fields are in micro-USD; the FE + divides by 1_000_000 to display USD. + """ token_buying_enabled: bool - premium_tokens_used: int = 0 - premium_tokens_limit: int = 0 - premium_tokens_remaining: int = 0 + premium_credit_micros_used: int = 0 + premium_credit_micros_limit: int = 0 + premium_credit_micros_remaining: int = 0 diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py index ab2e609dc..e55333a9d 100644 --- a/surfsense_backend/app/schemas/vision_llm.py +++ b/surfsense_backend/app/schemas/vision_llm.py @@ -62,6 +62,15 @@ class VisionLLMConfigPublic(BaseModel): class GlobalVisionLLMConfigRead(BaseModel): + """Schema for reading global vision LLM configs from YAML. + + The ``billing_tier`` field allows the frontend to show a Premium/Free + badge and (more importantly) tells the backend whether to debit the + user's premium credit pool when this config is used. ``"free"`` is + the default for backward compatibility — admins must explicitly opt + a global config into ``"premium"``. + """ + id: int = Field(...) name: str description: str | None = None @@ -73,3 +82,26 @@ class GlobalVisionLLMConfigRead(BaseModel): litellm_params: dict[str, Any] | None = None is_global: bool = True is_auto_mode: bool = False + billing_tier: str = Field( + default="free", + description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", + ) + quota_reserve_tokens: int | None = Field( + default=None, + description=( + "Optional override for the per-call reservation in *tokens* — " + "converted to micro-USD via the model's input/output prices at " + "reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS." + ), + ) + input_cost_per_token: float | None = Field( + default=None, + description=( + "Optional input price in USD/token. Used by pricing_registration to " + "register custom Azure / OpenRouter aliases with LiteLLM at startup." + ), + ) + output_cost_per_token: float | None = Field( + default=None, + description="Optional output price in USD/token. Pair with input_cost_per_token.", + ) diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py new file mode 100644 index 000000000..f5ca9818e --- /dev/null +++ b/surfsense_backend/app/services/billable_calls.py @@ -0,0 +1,430 @@ +""" +Per-call billable wrapper for image generation, vision LLM extraction, and +any other short-lived premium operation that must charge against the user's +shared premium credit pool. + +The ``billable_call`` async context manager encapsulates the standard +"reserve → execute → finalize / release → record audit row" lifecycle in a +single primitive so callers (the image-generation REST route and the +vision-LLM wrapper used during indexing) don't have to re-implement it. + +KEY DESIGN POINTS (issue A, B): + +1. **Session isolation.** ``billable_call`` takes *no* ``db_session`` + argument. All ``TokenQuotaService.premium_*`` calls and the audit-row + insert each run inside their own ``shielded_async_session()``. This + guarantees that a quota commit/rollback can never accidentally flush or + roll back rows the caller has staged in the request's main session + (e.g. a freshly-created ``ImageGeneration`` row). + +2. **ContextVar safety.** The accumulator is scoped via + :func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a + nested ``billable_call`` inside an outer chat turn cannot corrupt the + chat turn's accumulator. + +3. **Free configs are still audited.** Free calls bypass the reserve / + finalize dance entirely but still record a ``TokenUsage`` audit row with + the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution + pipeline complete for analytics even when nothing is debited. + +4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is + responsible for translating that into HTTP 402. We *do not* catch the + denial inside ``billable_call`` — letting it propagate also prevents + the image-generation route from creating an ``ImageGeneration`` row + for a request that never actually ran. +""" + +from __future__ import annotations + +import logging +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any +from uuid import UUID, uuid4 + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import shielded_async_session +from app.services.token_quota_service import ( + TokenQuotaService, + estimate_call_reserve_micros, +) +from app.services.token_tracking_service import ( + TurnTokenAccumulator, + record_token_usage, + scoped_turn, +) + +logger = logging.getLogger(__name__) + + +class QuotaInsufficientError(Exception): + """Raised when ``TokenQuotaService.premium_reserve`` denies a billable + call because the user has exhausted their premium credit pool. + + The route handler should catch this and return HTTP 402 Payment + Required (or the equivalent for the surface area). Outside of the HTTP + layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing) + callers may catch this and degrade gracefully — e.g. fall back to OCR + when vision is unavailable. + """ + + def __init__( + self, + *, + usage_type: str, + used_micros: int, + limit_micros: int, + remaining_micros: int, + ) -> None: + self.usage_type = usage_type + self.used_micros = used_micros + self.limit_micros = limit_micros + self.remaining_micros = remaining_micros + super().__init__( + f"Premium credit exhausted for {usage_type}: " + f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)" + ) + + +@asynccontextmanager +async def billable_call( + *, + user_id: UUID, + search_space_id: int, + billing_tier: str, + base_model: str, + quota_reserve_tokens: int | None = None, + quota_reserve_micros_override: int | None = None, + usage_type: str, + thread_id: int | None = None, + message_id: int | None = None, + call_details: dict[str, Any] | None = None, +) -> AsyncIterator[TurnTokenAccumulator]: + """Wrap a single billable LLM/image call. + + Args: + user_id: Owner of the credit pool to debit. For vision-LLM during + indexing this is the *search-space owner* (issue M), not the + triggering user. + search_space_id: Required — recorded on the ``TokenUsage`` audit row. + billing_tier: ``"premium"`` debits; anything else (``"free"``) skips + the reserve/finalize dance but still records an audit row with + the captured cost. + base_model: Used by :func:`estimate_call_reserve_micros` to compute + a worst-case reservation from LiteLLM's pricing table. + quota_reserve_tokens: Optional per-config override for the chat-style + reserve estimator (vision LLM uses this). + quota_reserve_micros_override: Optional flat micro-USD reservation + (image generation uses this — its cost shape is per-image, not + per-token). + usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc. + Recorded on the ``TokenUsage`` row. + thread_id, message_id: Optional FK columns on ``TokenUsage``. + call_details: Optional per-call metadata (model name, parameters) + forwarded to ``record_token_usage``. + + Yields: + The ``TurnTokenAccumulator`` scoped to this call. The caller invokes + the underlying LLM/image API while inside the ``async with``; the + ``TokenTrackingCallback`` populates the accumulator automatically. + + Raises: + QuotaInsufficientError: when premium and ``premium_reserve`` denies. + """ + is_premium = billing_tier == "premium" + + async with scoped_turn() as acc: + # ---------- Free path: just audit ------------------------------- + if not is_premium: + try: + yield acc + finally: + # Always audit, even on exception, so we capture cost when + # provider returns successfully but the caller raises later. + try: + async with shielded_async_session() as audit_session: + await record_token_usage( + audit_session, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=acc.total_prompt_tokens, + completion_tokens=acc.total_completion_tokens, + total_tokens=acc.grand_total, + cost_micros=acc.total_cost_micros, + model_breakdown=acc.per_message_summary(), + call_details=call_details, + thread_id=thread_id, + message_id=message_id, + ) + await audit_session.commit() + except Exception: + logger.exception( + "[billable_call] free-path audit insert failed for " + "usage_type=%s user_id=%s", + usage_type, + user_id, + ) + return + + # ---------- Premium path: reserve → execute → finalize ---------- + if quota_reserve_micros_override is not None: + reserve_micros = max(1, int(quota_reserve_micros_override)) + else: + reserve_micros = estimate_call_reserve_micros( + base_model=base_model or "", + quota_reserve_tokens=quota_reserve_tokens, + ) + + request_id = str(uuid4()) + + async with shielded_async_session() as quota_session: + reserve_result = await TokenQuotaService.premium_reserve( + db_session=quota_session, + user_id=user_id, + request_id=request_id, + reserve_micros=reserve_micros, + ) + + if not reserve_result.allowed: + logger.info( + "[billable_call] reserve DENIED user=%s usage_type=%s " + "reserve=%d used=%d limit=%d remaining=%d", + user_id, + usage_type, + reserve_micros, + reserve_result.used, + reserve_result.limit, + reserve_result.remaining, + ) + raise QuotaInsufficientError( + usage_type=usage_type, + used_micros=reserve_result.used, + limit_micros=reserve_result.limit, + remaining_micros=reserve_result.remaining, + ) + + logger.info( + "[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d " + "(remaining=%d)", + user_id, + usage_type, + reserve_micros, + reserve_result.remaining, + ) + + try: + yield acc + except BaseException: + # Release on any failure (including QuotaInsufficientError raised + # from a downstream call, asyncio cancellation, etc.). We use + # BaseException so cancellation also releases. + try: + async with shielded_async_session() as quota_session: + await TokenQuotaService.premium_release( + db_session=quota_session, + user_id=user_id, + reserved_micros=reserve_micros, + ) + except Exception: + logger.exception( + "[billable_call] premium_release failed for user=%s " + "reserve_micros=%d (reservation will be GC'd by quota " + "reconciliation if/when implemented)", + user_id, + reserve_micros, + ) + raise + + # ---------- Success: finalize + audit ---------------------------- + actual_micros = acc.total_cost_micros + try: + async with shielded_async_session() as quota_session: + final_result = await TokenQuotaService.premium_finalize( + db_session=quota_session, + user_id=user_id, + request_id=request_id, + actual_micros=actual_micros, + reserved_micros=reserve_micros, + ) + logger.info( + "[billable_call] finalize user=%s usage_type=%s actual=%d " + "reserved=%d → used=%d/%d (remaining=%d)", + user_id, + usage_type, + actual_micros, + reserve_micros, + final_result.used, + final_result.limit, + final_result.remaining, + ) + except Exception: + # Last-ditch: if finalize itself fails, we must at least release + # so the reservation doesn't leak. + logger.exception( + "[billable_call] premium_finalize failed for user=%s; " + "attempting release", + user_id, + ) + try: + async with shielded_async_session() as quota_session: + await TokenQuotaService.premium_release( + db_session=quota_session, + user_id=user_id, + reserved_micros=reserve_micros, + ) + except Exception: + logger.exception( + "[billable_call] release after finalize failure ALSO failed " + "for user=%s", + user_id, + ) + + try: + async with shielded_async_session() as audit_session: + await record_token_usage( + audit_session, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=acc.total_prompt_tokens, + completion_tokens=acc.total_completion_tokens, + total_tokens=acc.grand_total, + cost_micros=actual_micros, + model_breakdown=acc.per_message_summary(), + call_details=call_details, + thread_id=thread_id, + message_id=message_id, + ) + await audit_session.commit() + except Exception: + logger.exception( + "[billable_call] premium-path audit insert failed for " + "usage_type=%s user_id=%s (debit was applied)", + usage_type, + user_id, + ) + + +async def _resolve_agent_billing_for_search_space( + session: AsyncSession, + search_space_id: int, + *, + thread_id: int | None = None, +) -> tuple[UUID, str, str]: + """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space + agent LLM. + + Used by Celery tasks (podcast generation, video presentation) to bill the + search-space owner's premium credit pool when the agent LLM is premium. + + Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``: + + - Search space not found / no ``agent_llm_id``: raise ``ValueError``. + - **Auto mode** (``id == AUTO_FASTEST_ID == 0``): + * ``thread_id`` is set: delegate to + ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and + recurse into the resolved id. Reuses chat's existing pin if present + so the same model bills for chat + downstream podcast/video. If the + user is not premium-eligible, the pin service auto-restricts to free + deployments — denial only happens later in + ``billable_call.premium_reserve`` if the pin really is premium and + credit ran out mid-flow. + * ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat + for any future direct-API path; today both Celery tasks always pass + ``thread_id``. + - **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]`` + (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault), + ``base_model = litellm_params.get("base_model") or model_name`` — + NOT provider-prefixed, matching chat's cost-map lookup convention. + - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches + ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``); + ``base_model`` from ``litellm_params`` or ``model_name``. + + Note on imports: ``llm_service``, ``auto_model_pin_service``, and + ``llm_router_service`` are imported lazily inside the function body to + avoid hoisting litellm side-effects (``litellm.callbacks = + [token_tracker]``, ``litellm.drop_params``, etc.) into + ``billable_calls.py``'s module load path. + """ + from sqlalchemy import select + + from app.db import NewLLMConfig, SearchSpace + + result = await session.execute( + select(SearchSpace).where(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + if search_space is None: + raise ValueError(f"Search space {search_space_id} not found") + + agent_llm_id = search_space.agent_llm_id + if agent_llm_id is None: + raise ValueError( + f"Search space {search_space_id} has no agent_llm_id configured" + ) + + owner_user_id: UUID = search_space.user_id + + from app.services.auto_model_pin_service import ( + AUTO_FASTEST_ID, + resolve_or_get_pinned_llm_config_id, + ) + + if agent_llm_id == AUTO_FASTEST_ID: + if thread_id is None: + return owner_user_id, "free", "auto" + try: + resolution = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=thread_id, + search_space_id=search_space_id, + user_id=str(owner_user_id), + selected_llm_config_id=AUTO_FASTEST_ID, + ) + except ValueError: + logger.warning( + "[agent_billing] Auto-mode pin resolution failed for " + "search_space=%s thread=%s; falling back to free", + search_space_id, + thread_id, + exc_info=True, + ) + return owner_user_id, "free", "auto" + agent_llm_id = resolution.resolved_llm_config_id + + if agent_llm_id < 0: + from app.services.llm_service import get_global_llm_config + + cfg = get_global_llm_config(agent_llm_id) or {} + billing_tier = str(cfg.get("billing_tier", "free")).lower() + litellm_params = cfg.get("litellm_params") or {} + base_model = litellm_params.get("base_model") or cfg.get("model_name") or "" + return owner_user_id, billing_tier, base_model + + nlc_result = await session.execute( + select(NewLLMConfig).where( + NewLLMConfig.id == agent_llm_id, + NewLLMConfig.search_space_id == search_space_id, + ) + ) + nlc = nlc_result.scalars().first() + base_model = "" + if nlc is not None: + litellm_params = nlc.litellm_params or {} + base_model = litellm_params.get("base_model") or nlc.model_name or "" + return owner_user_id, "free", base_model + + +__all__ = [ + "QuotaInsufficientError", + "_resolve_agent_billing_for_search_space", + "billable_call", +] + + +# Re-export the config knob so callers don't have to import config just for +# the default image reserve. +DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 8a7b2919a..1e9d235c8 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -134,42 +134,16 @@ PROVIDER_MAP = { } -# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when -# a global LLM config does *not* specify ``api_base``: without this, LiteLLM -# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``, -# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku`` -# request to an Azure endpoint, which then 404s with ``Resource not found``. -# Only providers with a well-known, stable public base URL are listed here — -# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai, -# huggingface, databricks, cloudflare, replicate) are intentionally omitted -# so their existing config-driven behaviour is preserved. -PROVIDER_DEFAULT_API_BASE = { - "openrouter": "https://openrouter.ai/api/v1", - "groq": "https://api.groq.com/openai/v1", - "mistral": "https://api.mistral.ai/v1", - "perplexity": "https://api.perplexity.ai", - "xai": "https://api.x.ai/v1", - "cerebras": "https://api.cerebras.ai/v1", - "deepinfra": "https://api.deepinfra.com/v1/openai", - "fireworks_ai": "https://api.fireworks.ai/inference/v1", - "together_ai": "https://api.together.xyz/v1", - "anyscale": "https://api.endpoints.anyscale.com/v1", - "cometapi": "https://api.cometapi.com/v1", - "sambanova": "https://api.sambanova.ai/v1", -} - - -# Canonical provider → base URL when a config uses a generic ``openai``-style -# prefix but the ``provider`` field tells us which API it really is -# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but -# each has its own base URL). -PROVIDER_KEY_DEFAULT_API_BASE = { - "DEEPSEEK": "https://api.deepseek.com/v1", - "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", - "MOONSHOT": "https://api.moonshot.ai/v1", - "ZHIPU": "https://open.bigmodel.cn/api/paas/v4", - "MINIMAX": "https://api.minimax.io/v1", -} +# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were +# hoisted to ``app.services.provider_api_base`` so vision and image-gen +# call sites can share the exact same defense (OpenRouter / Groq / etc. +# 404-ing against an inherited Azure endpoint). Re-exported here for +# backward compatibility with any external import. +from app.services.provider_api_base import ( # noqa: E402 + PROVIDER_DEFAULT_API_BASE, + PROVIDER_KEY_DEFAULT_API_BASE, + resolve_api_base, +) class LLMRouterService: @@ -466,14 +440,14 @@ class LLMRouterService: # Resolve ``api_base``. Config value wins; otherwise apply a # provider-aware default so the deployment does not silently # inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route - # requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE`` + # requests to the wrong endpoint. See ``provider_api_base`` # docstring for the motivating bug (OpenRouter models 404-ing # against an Azure endpoint). - api_base = config.get("api_base") - if not api_base: - api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider) - if not api_base: - api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix) + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), + ) if api_base: litellm_params["api_base"] = api_base diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 942a9b7af..72c10035d 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -496,8 +496,14 @@ async def get_vision_llm( - Auto mode (ID 0): VisionLLMRouterService - Global (negative ID): YAML configs - DB (positive ID): VisionLLMConfig table + + Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM` + so each ``ainvoke`` debits the search-space owner's premium credit + pool. User-owned BYOK configs and free global configs are returned + unwrapped — they don't consume premium credit (issue M). """ from app.db import VisionLLMConfig + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM from app.services.vision_llm_router_service import ( VISION_PROVIDER_MAP, VisionLLMRouterService, @@ -519,6 +525,8 @@ async def get_vision_llm( logger.error(f"No vision LLM configured for search space {search_space_id}") return None + owner_user_id = search_space.user_id + if is_vision_auto_mode(config_id): if not VisionLLMRouterService.is_initialized(): logger.error( @@ -526,6 +534,13 @@ async def get_vision_llm( ) return None try: + # Auto mode is currently treated as free at the wrapper + # level — the underlying router can dispatch to either + # premium or free YAML configs but routing decisions are + # opaque. If/when we want to bill Auto-routed vision + # calls we'd need to thread the resolved deployment's + # billing_tier back from the router. For now we keep + # parity with chat Auto, which also doesn't pre-classify. return ChatLiteLLMRouter( router=VisionLLMRouterService.get_router(), streaming=True, @@ -562,8 +577,21 @@ async def get_vision_llm( from app.agents.new_chat.llm_config import SanitizedChatLiteLLM - return SanitizedChatLiteLLM(**litellm_kwargs) + inner_llm = SanitizedChatLiteLLM(**litellm_kwargs) + billing_tier = str(global_cfg.get("billing_tier", "free")).lower() + if billing_tier == "premium": + return QuotaCheckedVisionLLM( + inner_llm, + user_id=owner_user_id, + search_space_id=search_space_id, + billing_tier=billing_tier, + base_model=model_string, + quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"), + ) + return inner_llm + + # User-owned (positive ID) BYOK configs — always free. result = await session.execute( select(VisionLLMConfig).where( VisionLLMConfig.id == config_id, diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 7e856d015..0d030f04f 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -93,6 +93,35 @@ def _is_text_output_model(model: dict) -> bool: return output_mods == ["text"] +def _is_image_output_model(model: dict) -> bool: + """Return True if the model can produce image output. + + OpenRouter's ``architecture.output_modalities`` is a list (e.g. + ``["image"]`` for pure image generators, ``["text", "image"]`` for + multi-modal generators that also emit captions). We accept any model + that can output images; the call site decides whether to use the + image-generation API or chat completion. + """ + output_mods = model.get("architecture", {}).get("output_modalities", []) or [] + return "image" in output_mods + + +def _is_vision_input_model(model: dict) -> bool: + """Return True if the model can ingest an image AND emit text. + + OpenRouter's ``architecture.input_modalities`` lists what the model + accepts; ``output_modalities`` lists what it produces. A vision LLM + is a model that takes images in and produces text out — i.e. it can + answer questions about a screenshot or extract content from an + image. Pure image-to-image models (e.g. style transfer) and + text-only models are excluded. + """ + arch = model.get("architecture", {}) or {} + input_mods = arch.get("input_modalities", []) or [] + output_mods = arch.get("output_modalities", []) or [] + return "image" in input_mods and "text" in output_mods + + def _supports_tool_calling(model: dict) -> bool: """Return True if the model supports function/tool calling.""" supported = model.get("supported_parameters") or [] @@ -175,6 +204,32 @@ async def _fetch_models_async() -> list[dict] | None: return None +def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]: + """Return a ``{model_id: {"prompt": str, "completion": str}}`` map. + + Pricing values are kept as the raw OpenRouter strings (e.g. + ``"0.000003"``); ``pricing_registration`` converts them to floats + when registering with LiteLLM. Models with missing or malformed + pricing are simply omitted — operator-side risk if any of those are + premium. + """ + pricing: dict[str, dict[str, str]] = {} + for model in raw_models: + model_id = str(model.get("id") or "").strip() + if not model_id: + continue + p = model.get("pricing") or {} + prompt = p.get("prompt") + completion = p.get("completion") + if prompt is None and completion is None: + continue + pricing[model_id] = { + "prompt": str(prompt) if prompt is not None else "", + "completion": str(completion) if completion is not None else "", + } + return pricing + + def _generate_configs( raw_models: list[dict], settings: dict[str, Any], @@ -282,6 +337,162 @@ def _generate_configs( return configs +# ID-offset bands used to keep dynamic OpenRouter configs in their own +# namespace per surface. Image / vision get separate bands so a single +# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to. +_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000 +_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000 + + +def _generate_image_gen_configs( + raw_models: list[dict], settings: dict[str, Any] +) -> list[dict]: + """Convert OpenRouter image-generation models into global image-gen + config dicts (matches the YAML shape consumed by ``image_generation_routes``). + + Filter: + - architecture.output_modalities contains "image" + - compatible provider (excluded slugs blocked) + - allowed model id (excluded list blocked) + + Notably we *drop* the chat-only filters (``_supports_tool_calling`` and + ``_has_sufficient_context``) because tool calls and context windows are + irrelevant for the ``aimage_generation`` API. ``billing_tier`` is + derived per model the same way as chat (``_openrouter_tier``). + + Cost is intentionally *not* registered with LiteLLM at startup + (``pricing_registration`` skips image gen): OpenRouter image-gen + models are not in LiteLLM's native cost map and OpenRouter populates + ``response_cost`` directly from the response header. A defensive + branch in ``_extract_cost_usd`` handles the rare case where + ``usage.cost`` is missing — see ``token_tracking_service``. + """ + id_offset: int = int( + settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT + ) + api_key: str = settings.get("api_key", "") + rpm: int = settings.get("rpm", 200) + free_rpm: int = settings.get("free_rpm", 20) + litellm_params: dict = settings.get("litellm_params") or {} + + image_models = [ + m + for m in raw_models + if _is_image_output_model(m) + and _is_compatible_provider(m) + and _is_allowed_model(m) + and "/" in m.get("id", "") + ] + + configs: list[dict] = [] + taken: set[int] = set() + for model in image_models: + model_id: str = model["id"] + name: str = model.get("name", model_id) + tier = _openrouter_tier(model) + + cfg: dict[str, Any] = { + "id": _stable_config_id(model_id, id_offset, taken), + "name": name, + "description": f"{name} via OpenRouter (image generation)", + "provider": "OPENROUTER", + "model_name": model_id, + "api_key": api_key, + "api_base": "", + "api_version": None, + "rpm": free_rpm if tier == "free" else rpm, + "litellm_params": dict(litellm_params), + "billing_tier": tier, + _OPENROUTER_DYNAMIC_MARKER: True, + } + configs.append(cfg) + + return configs + + +def _generate_vision_llm_configs( + raw_models: list[dict], settings: dict[str, Any] +) -> list[dict]: + """Convert OpenRouter vision-capable LLMs into global vision-LLM config + dicts (matches the YAML shape consumed by ``vision_llm_routes``). + + Filter: + - architecture.input_modalities contains "image" + - architecture.output_modalities contains "text" + - compatible provider (excluded slugs blocked) + - allowed model id (excluded list blocked) + + Vision-LLM is invoked from the indexer (image extraction during + document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so + the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context`` + filters do not apply: a small-context vision model that doesn't + advertise tool-calling is still perfectly viable for "describe this + image" prompts. + """ + id_offset: int = int( + settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT + ) + api_key: str = settings.get("api_key", "") + rpm: int = settings.get("rpm", 200) + tpm: int = settings.get("tpm", 1_000_000) + free_rpm: int = settings.get("free_rpm", 20) + free_tpm: int = settings.get("free_tpm", 100_000) + quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000) + litellm_params: dict = settings.get("litellm_params") or {} + + vision_models = [ + m + for m in raw_models + if _is_vision_input_model(m) + and _is_compatible_provider(m) + and _is_allowed_model(m) + and "/" in m.get("id", "") + ] + + configs: list[dict] = [] + taken: set[int] = set() + for model in vision_models: + model_id: str = model["id"] + name: str = model.get("name", model_id) + tier = _openrouter_tier(model) + pricing = model.get("pricing") or {} + + # Capture per-token prices so ``pricing_registration`` can + # register them with LiteLLM at startup (and so the cost + # estimator in ``estimate_call_reserve_micros`` can resolve + # them at reserve time). + try: + input_cost = float(pricing.get("prompt", 0) or 0) + except (TypeError, ValueError): + input_cost = 0.0 + try: + output_cost = float(pricing.get("completion", 0) or 0) + except (TypeError, ValueError): + output_cost = 0.0 + + cfg: dict[str, Any] = { + "id": _stable_config_id(model_id, id_offset, taken), + "name": name, + "description": f"{name} via OpenRouter (vision)", + "provider": "OPENROUTER", + "model_name": model_id, + "api_key": api_key, + "api_base": "", + "api_version": None, + "rpm": free_rpm if tier == "free" else rpm, + "tpm": free_tpm if tier == "free" else tpm, + "litellm_params": dict(litellm_params), + "billing_tier": tier, + "quota_reserve_tokens": quota_reserve_tokens, + "input_cost_per_token": input_cost or None, + "output_cost_per_token": output_cost or None, + _OPENROUTER_DYNAMIC_MARKER: True, + } + configs.append(cfg) + + return configs + + class OpenRouterIntegrationService: """Singleton that manages the dynamic OpenRouter model catalogue.""" @@ -300,6 +511,19 @@ class OpenRouterIntegrationService: # Shape: {model_name: {"gated": bool, "score": float | None}} self._health_cache: dict[str, dict[str, Any]] = {} self._enrich_task: asyncio.Task | None = None + # Raw OpenRouter pricing per model_id, captured at the same time + # we generate configs. Consumed by ``pricing_registration`` to + # teach LiteLLM the per-token cost of every dynamic deployment so + # the success-callback can populate ``response_cost`` correctly. + self._raw_pricing: dict[str, dict[str, str]] = {} + # Cached raw catalogue from the most recent fetch. Image / vision + # emitters reuse this to avoid a second network call per surface. + self._raw_models: list[dict] = [] + # Image / vision config caches (only populated when the matching + # opt-in flag is true on initialize). Refreshed in lockstep with + # the chat catalogue. + self._image_configs: list[dict] = [] + self._vision_configs: list[dict] = [] @classmethod def get_instance(cls) -> "OpenRouterIntegrationService": @@ -329,8 +553,32 @@ class OpenRouterIntegrationService: self._initialized = True return [] + self._raw_models = raw_models self._configs = _generate_configs(raw_models, settings) self._configs_by_id = {c["id"]: c for c in self._configs} + self._raw_pricing = _extract_raw_pricing(raw_models) + + # Populate image / vision caches when their opt-in flag is set. + # Empty otherwise so the accessors return [] without re-running + # filters every refresh. + if settings.get("image_generation_enabled"): + self._image_configs = _generate_image_gen_configs(raw_models, settings) + logger.info( + "OpenRouter integration: image-gen emission ON (%d models)", + len(self._image_configs), + ) + else: + self._image_configs = [] + + if settings.get("vision_enabled"): + self._vision_configs = _generate_vision_llm_configs(raw_models, settings) + logger.info( + "OpenRouter integration: vision LLM emission ON (%d models)", + len(self._vision_configs), + ) + else: + self._vision_configs = [] + self._initialized = True tier_counts = self._tier_counts(self._configs) @@ -369,6 +617,8 @@ class OpenRouterIntegrationService: new_configs = _generate_configs(raw_models, self._settings) new_by_id = {c["id"]: c for c in new_configs} + self._raw_pricing = _extract_raw_pricing(raw_models) + self._raw_models = raw_models from app.config import config as app_config @@ -382,6 +632,29 @@ class OpenRouterIntegrationService: self._configs = new_configs self._configs_by_id = new_by_id + # Image / vision lists are atomic-swapped the same way: filter out + # the previous dynamic entries from the live config list and append + # the freshly generated ones. No-ops when the opt-in flag is off. + if self._settings.get("image_generation_enabled"): + new_image = _generate_image_gen_configs(raw_models, self._settings) + static_image = [ + c + for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS + if not c.get(_OPENROUTER_DYNAMIC_MARKER) + ] + app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image + self._image_configs = new_image + + if self._settings.get("vision_enabled"): + new_vision = _generate_vision_llm_configs(raw_models, self._settings) + static_vision = [ + c + for c in app_config.GLOBAL_VISION_LLM_CONFIGS + if not c.get(_OPENROUTER_DYNAMIC_MARKER) + ] + app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision + self._vision_configs = new_vision + # Catalogue churn invalidates per-config "recently healthy" credit # earned by the previous turn's preflight. Drop the whole table so # the next turn re-probes against the freshly loaded configs. @@ -407,6 +680,21 @@ class OpenRouterIntegrationService: # so a hand-picked dead OR model is gated like a dynamic one. await self._enrich_health_safely(static_configs + new_configs, log_summary=True) + # Re-register LiteLLM pricing for the freshly fetched catalogue + # so newly added OR models bill correctly on their first call. + # Runs before the router rebuild because the router may issue + # cost-table lookups during deployment registration. + try: + from app.services.pricing_registration import ( + register_pricing_from_global_configs, + ) + + register_pricing_from_global_configs() + except Exception as exc: + logger.warning( + "OpenRouter refresh: pricing re-registration skipped (%s)", exc + ) + # Rebuild the LiteLLM router so freshly fetched configs flow through # (dynamic OR premium entries now opt into the pool, free ones stay # out; a refresh also needs to pick up any static-config edits and @@ -635,3 +923,34 @@ class OpenRouterIntegrationService: def get_config_by_id(self, config_id: int) -> dict | None: return self._configs_by_id.get(config_id) + + def get_image_generation_configs(self) -> list[dict]: + """Return the dynamic OpenRouter image-generation configs (empty + list when the ``image_generation_enabled`` flag is off). + + Each entry already has ``billing_tier`` derived per-model from + OpenRouter's signals and is shaped to drop directly into + ``Config.GLOBAL_IMAGE_GEN_CONFIGS``. + """ + return list(self._image_configs) + + def get_vision_llm_configs(self) -> list[dict]: + """Return the dynamic OpenRouter vision-LLM configs (empty list + when the ``vision_enabled`` flag is off). + + Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token`` + so ``pricing_registration`` can teach LiteLLM the cost of these + models the same way it does for chat — which keeps the billable + wrapper able to debit accurate micro-USD on a vision call. + """ + return list(self._vision_configs) + + def get_raw_pricing(self) -> dict[str, dict[str, str]]: + """Return the cached raw OpenRouter pricing map. + + Shape: ``{model_id: {"prompt": str, "completion": str}}``. The + values are the strings OpenRouter publishes (USD per token), + never converted to floats here so the caller can decide how to + handle malformed or unset entries. + """ + return dict(self._raw_pricing) diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py new file mode 100644 index 000000000..de98e50c2 --- /dev/null +++ b/surfsense_backend/app/services/pricing_registration.py @@ -0,0 +1,274 @@ +""" +Pricing registration with LiteLLM. + +Many models reach our LiteLLM callback without LiteLLM knowing their +per-token cost — namely: + +* The ~300 dynamic OpenRouter deployments (their pricing only lives on + OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published + pricing table). +* Static YAML deployments whose ``base_model`` name is operator-defined + (e.g. custom Azure deployment names like ``gpt-5.4``) and therefore + not in LiteLLM's table either. + +Without registration, ``kwargs["response_cost"]`` is 0 for those calls +and the user gets billed nothing — a fail-safe but wrong answer for a +cost-based credit system. This module runs once at startup, after the +OpenRouter integration has fetched its catalogue, and registers each +known model's pricing with ``litellm.register_model()`` under multiple +plausible alias keys (LiteLLM's cost lookup may use any of them +depending on whether the call went through the Router, ChatLiteLLM, +or a direct ``acompletion``). + +Operators who run a custom Azure deployment whose ``base_model`` name +isn't in LiteLLM's table can declare per-token pricing inline in +``global_llm_config.yaml`` via ``input_cost_per_token`` and +``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without +that declaration the model's calls debit 0 — never overbilled. +""" + +from __future__ import annotations + +import logging +from typing import Any + +import litellm + +logger = logging.getLogger(__name__) + + +def _safe_float(value: Any) -> float: + """Return ``float(value)`` if it parses to a positive number, else 0.0.""" + if value is None: + return 0.0 + try: + f = float(value) + except (TypeError, ValueError): + return 0.0 + return f if f > 0 else 0.0 + + +def _alias_set_for_openrouter(model_id: str) -> list[str]: + """Return the alias keys to register an OpenRouter model under. + + LiteLLM's cost-callback lookup key varies by call path: + - Router with ``model="openrouter/X"`` → kwargs["model"] is + typically ``openrouter/X``. + - LiteLLM's own provider routing may strip the prefix and pass the + bare ``X`` to the cost-table lookup. + Registering under both keeps the lookup hermetic regardless of + which path the call took. + """ + aliases = [f"openrouter/{model_id}", model_id] + return list(dict.fromkeys(a for a in aliases if a)) + + +def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]: + """Return the alias keys to register a static YAML deployment under. + + Same reasoning as the OpenRouter set: cover the bare ``base_model``, + the ``<provider>/<model>`` form LiteLLM Router constructs, and the + bare ``model_name`` because callbacks sometimes see whichever was + configured first. + """ + provider_lower = (provider or "").lower() + aliases: list[str] = [] + if base_model: + aliases.append(base_model) + if provider_lower and base_model: + aliases.append(f"{provider_lower}/{base_model}") + if model_name and model_name != base_model: + aliases.append(model_name) + if provider_lower and model_name and model_name != base_model: + aliases.append(f"{provider_lower}/{model_name}") + # Azure deployments often surface as "azure/<name>"; normalise the + # ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``. + if provider_lower == "azure_openai": + if base_model: + aliases.append(f"azure/{base_model}") + if model_name and model_name != base_model: + aliases.append(f"azure/{model_name}") + return list(dict.fromkeys(a for a in aliases if a)) + + +def _register( + aliases: list[str], + *, + input_cost: float, + output_cost: float, + provider: str, + mode: str = "chat", +) -> int: + """Register a single pricing entry under every alias in ``aliases``. + + Returns the count of aliases successfully registered. + """ + payload: dict[str, dict[str, Any]] = {} + for alias in aliases: + payload[alias] = { + "input_cost_per_token": input_cost, + "output_cost_per_token": output_cost, + "litellm_provider": provider, + "mode": mode, + } + if not payload: + return 0 + try: + litellm.register_model(payload) + except Exception as exc: + logger.warning( + "[PricingRegistration] register_model failed for aliases=%s: %s", + aliases, + exc, + ) + return 0 + return len(payload) + + +def _register_chat_shape_configs( + configs: list[dict], + *, + or_pricing: dict[str, dict[str, str]], + label: str, +) -> tuple[int, int, int, list[str]]: + """Common loop that registers per-token pricing for a list of "chat-shape" + configs (chat or vision LLM — both use ``input_cost_per_token`` / + ``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape). + + Returns ``(registered_models, registered_aliases, skipped, sample_keys)``. + """ + registered_models = 0 + registered_aliases = 0 + skipped_no_pricing = 0 + sample_keys: list[str] = [] + + for cfg in configs: + provider = str(cfg.get("provider") or "").upper() + model_name = str(cfg.get("model_name") or "").strip() + litellm_params = cfg.get("litellm_params") or {} + base_model = str(litellm_params.get("base_model") or model_name).strip() + + if provider == "OPENROUTER": + entry = or_pricing.get(model_name) + if entry: + input_cost = _safe_float(entry.get("prompt")) + output_cost = _safe_float(entry.get("completion")) + else: + # Vision configs from ``_generate_vision_llm_configs`` + # carry their pricing inline because the OpenRouter + # raw-pricing cache is keyed by chat-catalogue model_id; + # vision flows pick up the inline values here. + input_cost = _safe_float(cfg.get("input_cost_per_token")) + output_cost = _safe_float(cfg.get("output_cost_per_token")) + if input_cost == 0.0 and output_cost == 0.0: + skipped_no_pricing += 1 + continue + aliases = _alias_set_for_openrouter(model_name) + count = _register( + aliases, + input_cost=input_cost, + output_cost=output_cost, + provider="openrouter", + ) + if count > 0: + registered_models += 1 + registered_aliases += count + if len(sample_keys) < 6: + sample_keys.extend(aliases[:2]) + continue + + input_cost = _safe_float( + cfg.get("input_cost_per_token") + or litellm_params.get("input_cost_per_token") + ) + output_cost = _safe_float( + cfg.get("output_cost_per_token") + or litellm_params.get("output_cost_per_token") + ) + if input_cost == 0.0 and output_cost == 0.0: + skipped_no_pricing += 1 + continue + aliases = _alias_set_for_yaml(provider, model_name, base_model) + provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower() + count = _register( + aliases, + input_cost=input_cost, + output_cost=output_cost, + provider=provider_slug, + ) + if count > 0: + registered_models += 1 + registered_aliases += count + if len(sample_keys) < 6: + sample_keys.extend(aliases[:2]) + + logger.info( + "[PricingRegistration:%s] registered pricing for %d models (%d aliases); " + "%d configs had no pricing data; sample registered keys=%s", + label, + registered_models, + registered_aliases, + skipped_no_pricing, + sample_keys, + ) + return registered_models, registered_aliases, skipped_no_pricing, sample_keys + + +def register_pricing_from_global_configs() -> None: + """Register pricing for every known LLM deployment with LiteLLM. + + Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS`` + so vision calls (during indexing) can resolve cost the same way chat + calls do — namely: + + 1. ``OPENROUTER``: pulls the cached raw pricing from + ``OpenRouterIntegrationService`` (populated during its own + startup fetch) and converts the per-token strings to floats. For + vision configs that carry pricing inline (``input_cost_per_token`` / + ``output_cost_per_token`` set on the cfg itself) we fall back to + those values when the OR cache misses the model. + 2. Anything else: looks for operator-declared + ``input_cost_per_token`` / ``output_cost_per_token`` on the YAML + config block (top-level or nested under ``litellm_params``). + + **Image generation is intentionally NOT registered here.** The cost + shape for image-gen is per-image (``output_cost_per_image``), not + per-token, and LiteLLM's ``register_model`` doesn't accept those + keys via the chat-cost path. OpenRouter image-gen models populate + ``response_cost`` directly from their response header instead, and + Azure-native image-gen models are already in LiteLLM's cost map. + + Calls without a resolved pair of costs are skipped, not registered + with zeros — operators who forget pricing get a "$0 debit" warning + in ``TokenTrackingCallback`` rather than silently overwriting any + pricing LiteLLM might know natively. + """ + from app.config import config as app_config + + chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or []) + vision_configs: list[dict] = list( + getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or [] + ) + if not chat_configs and not vision_configs: + logger.info("[PricingRegistration] no global configs to register") + return + + or_pricing: dict[str, dict[str, str]] = {} + try: + from app.services.openrouter_integration_service import ( + OpenRouterIntegrationService, + ) + + if OpenRouterIntegrationService.is_initialized(): + or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing() + except Exception as exc: + logger.debug( + "[PricingRegistration] OpenRouter pricing not available yet: %s", exc + ) + + if chat_configs: + _register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat") + if vision_configs: + _register_chat_shape_configs( + vision_configs, or_pricing=or_pricing, label="vision" + ) diff --git a/surfsense_backend/app/services/provider_api_base.py b/surfsense_backend/app/services/provider_api_base.py new file mode 100644 index 000000000..979d7d3a1 --- /dev/null +++ b/surfsense_backend/app/services/provider_api_base.py @@ -0,0 +1,107 @@ +"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision. + +LiteLLM falls back to the module-global ``litellm.api_base`` when an +individual call doesn't pass one, which silently inherits provider-agnostic +env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an +explicit ``api_base``, an ``openrouter/<model>`` request can end up at an +Azure endpoint and 404 with ``Resource not found`` (real reproducer: +[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends +``/chat/completions`` to whatever inherited base it gets, regardless of +provider). + +The chat router has had this defense for a while +(``llm_router_service.py:466-478``). This module hoists the maps + cascade +into a tiny standalone helper so vision and image-gen can share the same +source of truth without an inter-service circular import. +""" + +from __future__ import annotations + + +PROVIDER_DEFAULT_API_BASE: dict[str, str] = { + "openrouter": "https://openrouter.ai/api/v1", + "groq": "https://api.groq.com/openai/v1", + "mistral": "https://api.mistral.ai/v1", + "perplexity": "https://api.perplexity.ai", + "xai": "https://api.x.ai/v1", + "cerebras": "https://api.cerebras.ai/v1", + "deepinfra": "https://api.deepinfra.com/v1/openai", + "fireworks_ai": "https://api.fireworks.ai/inference/v1", + "together_ai": "https://api.together.xyz/v1", + "anyscale": "https://api.endpoints.anyscale.com/v1", + "cometapi": "https://api.cometapi.com/v1", + "sambanova": "https://api.sambanova.ai/v1", +} +"""Default ``api_base`` per LiteLLM provider prefix (lowercase). + +Only providers with a well-known, stable public base URL are listed — +self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai, +huggingface, databricks, cloudflare, replicate) are intentionally omitted +so their existing config-driven behaviour is preserved.""" + + +PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = { + "DEEPSEEK": "https://api.deepseek.com/v1", + "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + "MOONSHOT": "https://api.moonshot.ai/v1", + "ZHIPU": "https://open.bigmodel.cn/api/paas/v4", + "MINIMAX": "https://api.minimax.io/v1", +} +"""Canonical provider key (uppercase) → base URL. + +Used when the LiteLLM provider prefix is the generic ``openai`` shim but the +config's ``provider`` field tells us which API it actually is (DeepSeek, +Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each +has its own base URL).""" + + +def resolve_api_base( + *, + provider: str | None, + provider_prefix: str | None, + config_api_base: str | None, +) -> str | None: + """Resolve a non-Azure-leaking ``api_base`` for a deployment. + + Cascade (first non-empty wins): + 1. The config's own ``api_base`` (whitespace-only treated as missing). + 2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``. + 3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``. + 4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM + provider integration apply its own default (e.g. AzureOpenAI's + deployment-derived URL, custom provider's per-deployment URL). + + Args: + provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``, + ``"DEEPSEEK"``). Case-insensitive. + provider_prefix: The LiteLLM model-string prefix the same call + site builds for the model id (e.g. ``"openrouter"``, + ``"groq"``). Case-insensitive. + config_api_base: ``api_base`` from the global YAML / DB row / + OpenRouter dynamic config. Empty / whitespace-only means + "missing" — the resolver still applies the cascade. + + Returns: + A URL string, or ``None`` if no default applies for this provider. + """ + if config_api_base and config_api_base.strip(): + return config_api_base + + if provider: + key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper()) + if key_default: + return key_default + + if provider_prefix: + prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower()) + if prefix_default: + return prefix_default + + return None + + +__all__ = [ + "PROVIDER_DEFAULT_API_BASE", + "PROVIDER_KEY_DEFAULT_API_BASE", + "resolve_api_base", +] diff --git a/surfsense_backend/app/services/quota_checked_vision_llm.py b/surfsense_backend/app/services/quota_checked_vision_llm.py new file mode 100644 index 000000000..0040e5a5b --- /dev/null +++ b/surfsense_backend/app/services/quota_checked_vision_llm.py @@ -0,0 +1,105 @@ +""" +Vision LLM proxy that enforces premium credit quota on every ``ainvoke``. + +Used by :func:`app.services.llm_service.get_vision_llm` so callers in the +indexing pipeline (file processors, connector indexers, etl pipeline) can +keep invoking the LLM exactly the way they do today — ``await llm.ainvoke(...)`` +— without threading ``user_id`` through every parser. The wrapper looks like +a chat model from the outside; on the inside it routes each call through +``billable_call`` so the user's premium credit pool is reserved → finalized +or released, and a ``TokenUsage`` audit row is written. + +Free configs are returned unwrapped from ``get_vision_llm`` (they do not +need quota enforcement) so this class only ever wraps premium configs. + +Why a wrapper instead of plumbing ``user_id`` through every caller: + +* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive, + Dropbox, local-folder, file-processor, ETL pipeline) each calling + ``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is + invasive, error-prone, and easy for a future indexer to forget. +* Per the design (issue M), we always debit the *search-space owner*, not + the triggering user, so ``user_id`` is fully derivable from the search + space the caller is already operating on. The wrapper captures it once + at construction time. +* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each + call run this coroutine"; subclassing isn't safe across versions because + it derives from ``BaseChatModel`` which expects specific Pydantic shapes. + Composition via attribute proxying (``__getattr__``) is robust to + upstream changes — every method other than ``ainvoke`` falls through to + the inner LLM unchanged. +""" + +from __future__ import annotations + +import logging +from typing import Any +from uuid import UUID + +from app.services.billable_calls import QuotaInsufficientError, billable_call + +logger = logging.getLogger(__name__) + + +class QuotaCheckedVisionLLM: + """Composition wrapper around a langchain chat model that enforces + premium credit quota on every ``ainvoke``. + + Anything other than ``ainvoke`` is forwarded to the inner model so + ``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all + still work — they simply bypass quota enforcement, which is fine + because the indexing pipeline only ever calls ``ainvoke`` today. + """ + + def __init__( + self, + inner_llm: Any, + *, + user_id: UUID, + search_space_id: int, + billing_tier: str, + base_model: str, + quota_reserve_tokens: int | None, + usage_type: str = "vision_extraction", + ) -> None: + self._inner = inner_llm + self._user_id = user_id + self._search_space_id = search_space_id + self._billing_tier = billing_tier + self._base_model = base_model + self._quota_reserve_tokens = quota_reserve_tokens + self._usage_type = usage_type + + async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any: + """Proxied async invoke that runs the underlying call inside + ``billable_call``. + + Raises: + QuotaInsufficientError: when the user has exhausted their + premium credit pool. Caller (``etl_pipeline_service._extract_image``) + catches this and falls back to the document parser. + """ + async with billable_call( + user_id=self._user_id, + search_space_id=self._search_space_id, + billing_tier=self._billing_tier, + base_model=self._base_model, + quota_reserve_tokens=self._quota_reserve_tokens, + usage_type=self._usage_type, + call_details={"model": self._base_model}, + ): + return await self._inner.ainvoke(input, *args, **kwargs) + + def __getattr__(self, name: str) -> Any: + """Forward everything else (``invoke``, ``astream``, ``bind``, + ``with_structured_output``, …) to the inner model. + + ``__getattr__`` is only consulted when the attribute is *not* + already found on the proxy, which is exactly the contract we + want — methods we override stay on the proxy, the rest fall + through. + """ + return getattr(self._inner, name) + + +__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"] diff --git a/surfsense_backend/app/services/token_quota_service.py b/surfsense_backend/app/services/token_quota_service.py index a3ec7aed0..310c3eb5e 100644 --- a/surfsense_backend/app/services/token_quota_service.py +++ b/surfsense_backend/app/services/token_quota_service.py @@ -22,6 +22,71 @@ from app.config import config logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Per-call reservation estimator (USD micro-units) +# --------------------------------------------------------------------------- + +# Minimum reserve in micros so a user with $0.0001 left can still make a tiny +# request, and so models without registered pricing reserve at least +# something while the call runs (debited 0 at finalize anyway when their +# cost can't be resolved). +_QUOTA_MIN_RESERVE_MICROS = 100 + + +def estimate_call_reserve_micros( + *, + base_model: str, + quota_reserve_tokens: int | None, +) -> int: + """Return the number of micro-USD to reserve for one premium call. + + Computes a worst-case upper bound from LiteLLM's per-token pricing + table: + + reserve_usd ≈ reserve_tokens x (input_cost + output_cost) + + so the math scales with model cost — Claude Opus + 4K reserve_tokens + naturally reserves ≈ $0.36, while a cheap model reserves only a few + cents. Clamped to ``[_QUOTA_MIN_RESERVE_MICROS, QUOTA_MAX_RESERVE_MICROS]`` + so a misconfigured "$1000/M" model can't lock the whole balance on + one call. + + If ``litellm.get_model_info`` raises (model unknown) we fall back to + the floor — 100 micros / $0.0001 — which is enough to gate a sane + request without over-reserving for a model whose pricing the + operator hasn't declared yet. + """ + reserve_tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL + if reserve_tokens <= 0: + reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL + + try: + from litellm import get_model_info + + info = get_model_info(base_model) if base_model else {} + input_cost = float(info.get("input_cost_per_token") or 0.0) + output_cost = float(info.get("output_cost_per_token") or 0.0) + except Exception as exc: + logger.debug( + "[quota_reserve] cost lookup failed for base_model=%s: %s", + base_model, + exc, + ) + input_cost = 0.0 + output_cost = 0.0 + + if input_cost == 0.0 and output_cost == 0.0: + return _QUOTA_MIN_RESERVE_MICROS + + reserve_usd = reserve_tokens * (input_cost + output_cost) + reserve_micros = round(reserve_usd * 1_000_000) + if reserve_micros < _QUOTA_MIN_RESERVE_MICROS: + reserve_micros = _QUOTA_MIN_RESERVE_MICROS + if reserve_micros > config.QUOTA_MAX_RESERVE_MICROS: + reserve_micros = config.QUOTA_MAX_RESERVE_MICROS + return reserve_micros + + class QuotaScope(StrEnum): ANONYMOUS = "anonymous" PREMIUM = "premium" @@ -444,8 +509,16 @@ class TokenQuotaService: db_session: AsyncSession, user_id: Any, request_id: str, - reserve_tokens: int, + reserve_micros: int, ) -> QuotaResult: + """Reserve ``reserve_micros`` (USD micro-units) from the user's + premium credit balance. + + ``QuotaResult.used``/``limit``/``reserved``/``remaining`` are + all in micro-USD on this code path; callers (chat stream, + token-status route, FE display) convert to dollars by dividing + by 1_000_000. + """ from app.db import User user = ( @@ -465,11 +538,11 @@ class TokenQuotaService: limit=0, ) - limit = user.premium_tokens_limit - used = user.premium_tokens_used - reserved = user.premium_tokens_reserved + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved - effective = used + reserved + reserve_tokens + effective = used + reserved + reserve_micros if effective > limit: remaining = max(0, limit - used - reserved) await db_session.rollback() @@ -482,10 +555,10 @@ class TokenQuotaService: remaining=remaining, ) - user.premium_tokens_reserved = reserved + reserve_tokens + user.premium_credit_micros_reserved = reserved + reserve_micros await db_session.commit() - new_reserved = reserved + reserve_tokens + new_reserved = reserved + reserve_micros remaining = max(0, limit - used - new_reserved) warning_threshold = int(limit * 0.8) @@ -510,9 +583,12 @@ class TokenQuotaService: db_session: AsyncSession, user_id: Any, request_id: str, - actual_tokens: int, - reserved_tokens: int, + actual_micros: int, + reserved_micros: int, ) -> QuotaResult: + """Settle the reservation: release ``reserved_micros`` and debit + ``actual_micros`` (the LiteLLM-reported provider cost in micro-USD). + """ from app.db import User user = ( @@ -529,16 +605,18 @@ class TokenQuotaService: allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0 ) - user.premium_tokens_reserved = max( - 0, user.premium_tokens_reserved - reserved_tokens + user.premium_credit_micros_reserved = max( + 0, user.premium_credit_micros_reserved - reserved_micros + ) + user.premium_credit_micros_used = ( + user.premium_credit_micros_used + actual_micros ) - user.premium_tokens_used = user.premium_tokens_used + actual_tokens await db_session.commit() - limit = user.premium_tokens_limit - used = user.premium_tokens_used - reserved = user.premium_tokens_reserved + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved remaining = max(0, limit - used - reserved) warning_threshold = int(limit * 0.8) @@ -562,8 +640,13 @@ class TokenQuotaService: async def premium_release( db_session: AsyncSession, user_id: Any, - reserved_tokens: int, + reserved_micros: int, ) -> None: + """Release ``reserved_micros`` previously held by ``premium_reserve``. + + Used when a request fails before finalize (so the reservation + doesn't leak credit). + """ from app.db import User user = ( @@ -576,8 +659,8 @@ class TokenQuotaService: .scalar_one_or_none() ) if user is not None: - user.premium_tokens_reserved = max( - 0, user.premium_tokens_reserved - reserved_tokens + user.premium_credit_micros_reserved = max( + 0, user.premium_credit_micros_reserved - reserved_micros ) await db_session.commit() @@ -598,9 +681,9 @@ class TokenQuotaService: allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0 ) - limit = user.premium_tokens_limit - used = user.premium_tokens_used - reserved = user.premium_tokens_reserved + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved remaining = max(0, limit - used - reserved) warning_threshold = int(limit * 0.8) diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py index 9aa8c6e70..9406d9be4 100644 --- a/surfsense_backend/app/services/token_tracking_service.py +++ b/surfsense_backend/app/services/token_tracking_service.py @@ -16,11 +16,14 @@ from __future__ import annotations import dataclasses import logging +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from contextvars import ContextVar from dataclasses import dataclass, field from typing import Any from uuid import UUID +import litellm from litellm.integrations.custom_logger import CustomLogger from sqlalchemy.ext.asyncio import AsyncSession @@ -35,6 +38,8 @@ class TokenCallRecord: prompt_tokens: int completion_tokens: int total_tokens: int + cost_micros: int = 0 + call_kind: str = "chat" @dataclass @@ -49,6 +54,8 @@ class TurnTokenAccumulator: prompt_tokens: int, completion_tokens: int, total_tokens: int, + cost_micros: int = 0, + call_kind: str = "chat", ) -> None: self.calls.append( TokenCallRecord( @@ -56,20 +63,28 @@ class TurnTokenAccumulator: prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + cost_micros=cost_micros, + call_kind=call_kind, ) ) def per_message_summary(self) -> dict[str, dict[str, int]]: - """Return token counts grouped by model name.""" + """Return token counts (and cost) grouped by model name.""" by_model: dict[str, dict[str, int]] = {} for c in self.calls: entry = by_model.setdefault( c.model, - {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "cost_micros": 0, + }, ) entry["prompt_tokens"] += c.prompt_tokens entry["completion_tokens"] += c.completion_tokens entry["total_tokens"] += c.total_tokens + entry["cost_micros"] += c.cost_micros return by_model @property @@ -84,6 +99,21 @@ class TurnTokenAccumulator: def total_completion_tokens(self) -> int: return sum(c.completion_tokens for c in self.calls) + @property + def total_cost_micros(self) -> int: + """Sum of per-call ``cost_micros`` across the entire turn. + + Used by ``stream_new_chat`` to debit a premium turn's actual + provider cost (in micro-USD) from the user's premium credit + balance. ``cost_micros`` per call is captured by + ``TokenTrackingCallback.async_log_success_event`` from + ``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost), + with multiple fallback paths so OpenRouter dynamic models and + custom Azure deployments still bill correctly when our + ``pricing_registration`` ran at startup. + """ + return sum(c.cost_micros for c in self.calls) + def serialized_calls(self) -> list[dict[str, Any]]: return [dataclasses.asdict(c) for c in self.calls] @@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar( def start_turn() -> TurnTokenAccumulator: - """Create a fresh accumulator for the current async context and return it.""" + """Create a fresh accumulator for the current async context and return it. + + NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For + short-lived per-call billable wrappers (image generation REST endpoint, + vision LLM during indexing) prefer :func:`scoped_turn`, which uses a + ContextVar reset token to restore the *previous* accumulator on exit and + avoids leaking call records across reservations (issue B). + """ acc = TurnTokenAccumulator() _turn_accumulator.set(acc) logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc)) @@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None: return _turn_accumulator.get() +@asynccontextmanager +async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]: + """Async context manager that scopes a fresh ``TurnTokenAccumulator`` + for the duration of the ``async with`` block, then *resets* the + ContextVar to its previous value on exit. + + This is the safe primitive for per-call billable operations + (image generation, vision LLM extraction, podcasts) that may run + inside an outer chat turn or be called sequentially from the same + background worker. Using ``ContextVar.set`` without ``reset`` (as + :func:`start_turn` does) would leak the inner accumulator into the + outer scope, causing the outer chat turn to debit cost twice. + + Usage:: + + async with scoped_turn() as acc: + await llm.ainvoke(...) + # acc.total_cost_micros captures cost from the LiteLLM callback + # Outer accumulator (if any) is restored here. + """ + acc = TurnTokenAccumulator() + token = _turn_accumulator.set(acc) + logger.debug( + "[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)", + id(acc), + token, + ) + try: + yield acc + finally: + _turn_accumulator.reset(token) + logger.debug( + "[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)", + id(acc), + len(acc.calls), + acc.total_cost_micros, + ) + + +def _extract_cost_usd( + kwargs: dict[str, Any], + response_obj: Any, + model: str, + prompt_tokens: int, + completion_tokens: int, + is_image: bool = False, +) -> float: + """Best-effort USD cost extraction for a single LLM/image call. + + Tries four sources in priority order and returns the first that + yields a positive number; returns 0.0 if all four fail (the call + will then debit nothing from the user's balance — fail-safe). + + Sources: + 1. ``kwargs["response_cost"]`` — LiteLLM's standard callback + field, populated for ``Router.acompletion`` since PR #12500. + 2. ``response_obj._hidden_params["response_cost"]`` — same value + exposed on the response itself. + 3. ``litellm.completion_cost(completion_response=response_obj)`` + — recompute from the response and LiteLLM's pricing table. + 4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)`` + — manual fallback for OpenRouter/custom-Azure models that + only resolve via aliases registered by + ``pricing_registration`` at startup. **Skipped for image + responses** — ``cost_per_token`` does not support ``ImageResponse`` + and would raise; the cost map for image-gen lives in different + keys (``output_cost_per_image``) handled by ``completion_cost``. + """ + cost = kwargs.get("response_cost") + if cost is not None: + try: + value = float(cost) + except (TypeError, ValueError): + value = 0.0 + if value > 0: + return value + + hidden = getattr(response_obj, "_hidden_params", None) or {} + if isinstance(hidden, dict): + cost = hidden.get("response_cost") + if cost is not None: + try: + value = float(cost) + except (TypeError, ValueError): + value = 0.0 + if value > 0: + return value + + try: + value = float(litellm.completion_cost(completion_response=response_obj)) + if value > 0: + return value + except Exception as exc: + if is_image: + # Image-gen path: OpenRouter's image responses can omit + # ``usage.cost`` and LiteLLM's ``default_image_cost_calculator`` + # then *raises* (no cost map for OpenRouter image models). + # Bail out with a warning rather than falling through to + # cost_per_token (which is also incompatible with ImageResponse). + logger.warning( + "[TokenTracking] completion_cost failed for image model=%s " + "(provider may have omitted usage.cost). Debiting 0. " + "Cause: %s", + model, + exc, + ) + return 0.0 + logger.debug( + "[TokenTracking] completion_cost failed for model=%s: %s", model, exc + ) + + if is_image: + # Never call cost_per_token for ImageResponse — keys mismatch and + # the function is documented chat-only. + return 0.0 + + if model and (prompt_tokens > 0 or completion_tokens > 0): + try: + prompt_cost, completion_cost = litellm.cost_per_token( + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + value = float(prompt_cost) + float(completion_cost) + if value > 0: + return value + except Exception as exc: + logger.debug( + "[TokenTracking] cost_per_token failed for model=%s: %s", model, exc + ) + + return 0.0 + + class TokenTrackingCallback(CustomLogger): """LiteLLM callback that captures token usage into the turn accumulator.""" @@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger): ) return + # Detect image generation responses — they have a different usage + # shape (ImageUsage with input_tokens/output_tokens) and require a + # different cost-extraction path. We probe by class name to avoid a + # hard import dependency on litellm internals. + response_cls = type(response_obj).__name__ + is_image = response_cls == "ImageResponse" + usage = getattr(response_obj, "usage", None) if not usage: logger.debug( @@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger): ) return - prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0 - completion_tokens = getattr(usage, "completion_tokens", 0) or 0 - total_tokens = getattr(usage, "total_tokens", 0) or 0 + if is_image: + # ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens`` + # (not prompt_tokens/completion_tokens). Several providers + # populate only one or neither (e.g. OpenRouter's gpt-image-1 + # passes through `input_tokens` from the prompt but no + # completion); fall through gracefully to 0. + prompt_tokens = getattr(usage, "input_tokens", 0) or 0 + completion_tokens = getattr(usage, "output_tokens", 0) or 0 + total_tokens = ( + getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens + ) + call_kind = "image_generation" + else: + prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(usage, "completion_tokens", 0) or 0 + total_tokens = getattr(usage, "total_tokens", 0) or 0 + call_kind = "chat" model = kwargs.get("model", "unknown") + cost_usd = _extract_cost_usd( + kwargs=kwargs, + response_obj=response_obj, + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + is_image=is_image, + ) + cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0 + + if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0): + logger.warning( + "[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d " + "kind=%s — debiting 0. Register pricing via pricing_registration or YAML " + "input_cost_per_token/output_cost_per_token (or rely on response_cost " + "for image generation).", + model, + prompt_tokens, + completion_tokens, + call_kind, + ) + acc.add( model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + cost_micros=cost_micros, + call_kind=call_kind, ) logger.info( - "[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)", + "[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d " + "cost=$%.6f (%d micros) (accumulator now has %d calls)", model, + call_kind, prompt_tokens, completion_tokens, total_tokens, + cost_usd, + cost_micros, len(acc.calls), ) @@ -168,6 +388,7 @@ async def record_token_usage( prompt_tokens: int = 0, completion_tokens: int = 0, total_tokens: int = 0, + cost_micros: int = 0, model_breakdown: dict[str, Any] | None = None, call_details: dict[str, Any] | None = None, thread_id: int | None = None, @@ -185,6 +406,7 @@ async def record_token_usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + cost_micros=cost_micros, model_breakdown=model_breakdown, call_details=call_details, thread_id=thread_id, @@ -194,11 +416,12 @@ async def record_token_usage( ) session.add(record) logger.debug( - "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d", + "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d", usage_type, prompt_tokens, completion_tokens, total_tokens, + cost_micros, ) return record except Exception: diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py index 0d782ab2b..ed5de921c 100644 --- a/surfsense_backend/app/services/vision_llm_router_service.py +++ b/surfsense_backend/app/services/vision_llm_router_service.py @@ -3,6 +3,8 @@ from typing import Any from litellm import Router +from app.services.provider_api_base import resolve_api_base + logger = logging.getLogger(__name__) VISION_AUTO_MODE_ID = 0 @@ -108,10 +110,11 @@ class VisionLLMRouterService: if not config.get("model_name") or not config.get("api_key"): return None + provider = config.get("provider", "").upper() if config.get("custom_provider"): - model_string = f"{config['custom_provider']}/{config['model_name']}" + provider_prefix = config["custom_provider"] + model_string = f"{provider_prefix}/{config['model_name']}" else: - provider = config.get("provider", "").upper() provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower()) model_string = f"{provider_prefix}/{config['model_name']}" @@ -120,8 +123,13 @@ class VisionLLMRouterService: "api_key": config.get("api_key"), } - if config.get("api_base"): - litellm_params["api_base"] = config["api_base"] + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), + ) + if api_base: + litellm_params["api_base"] = api_base if config.get("api_version"): litellm_params["api_version"] = config["api_version"] diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py index 953011ecf..937877473 100644 --- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py @@ -9,7 +9,13 @@ from sqlalchemy import select from app.agents.podcaster.graph import graph as podcaster_graph from app.agents.podcaster.state import State as PodcasterState from app.celery_app import celery_app +from app.config import config as app_config from app.db import Podcast, PodcastStatus +from app.services.billable_calls import ( + QuotaInsufficientError, + _resolve_agent_billing_for_search_space, + billable_call, +) from app.tasks.celery_tasks import get_celery_session_maker logger = logging.getLogger(__name__) @@ -96,6 +102,31 @@ async def _generate_content_podcast( podcast.status = PodcastStatus.GENERATING await session.commit() + try: + ( + owner_user_id, + billing_tier, + base_model, + ) = await _resolve_agent_billing_for_search_space( + session, + search_space_id, + thread_id=podcast.thread_id, + ) + except ValueError as resolve_err: + logger.error( + "Podcast %s: cannot resolve billing for search_space=%s: %s", + podcast.id, + search_space_id, + resolve_err, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "billing_resolution_failed", + } + graph_config = { "configurable": { "podcast_title": podcast.title, @@ -109,9 +140,39 @@ async def _generate_content_podcast( db_session=session, ) - graph_result = await podcaster_graph.ainvoke( - initial_state, config=graph_config - ) + try: + async with billable_call( + user_id=owner_user_id, + search_space_id=search_space_id, + billing_tier=billing_tier, + base_model=base_model, + quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS, + usage_type="podcast_generation", + thread_id=podcast.thread_id, + call_details={ + "podcast_id": podcast.id, + "title": podcast.title, + }, + ): + graph_result = await podcaster_graph.ainvoke( + initial_state, config=graph_config + ) + except QuotaInsufficientError as exc: + logger.info( + "Podcast %s denied: out of premium credits " + "(used=%d/%d remaining=%d)", + podcast.id, + exc.used_micros, + exc.limit_micros, + exc.remaining_micros, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "premium_quota_exhausted", + } podcast_transcript = graph_result.get("podcast_transcript", []) file_path = graph_result.get("final_podcast_file_path", "") diff --git a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py index 7880b385f..4f0c427d9 100644 --- a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py @@ -9,7 +9,13 @@ from sqlalchemy import select from app.agents.video_presentation.graph import graph as video_presentation_graph from app.agents.video_presentation.state import State as VideoPresentationState from app.celery_app import celery_app +from app.config import config as app_config from app.db import VideoPresentation, VideoPresentationStatus +from app.services.billable_calls import ( + QuotaInsufficientError, + _resolve_agent_billing_for_search_space, + billable_call, +) from app.tasks.celery_tasks import get_celery_session_maker logger = logging.getLogger(__name__) @@ -97,6 +103,32 @@ async def _generate_video_presentation( video_pres.status = VideoPresentationStatus.GENERATING await session.commit() + try: + ( + owner_user_id, + billing_tier, + base_model, + ) = await _resolve_agent_billing_for_search_space( + session, + search_space_id, + thread_id=video_pres.thread_id, + ) + except ValueError as resolve_err: + logger.error( + "VideoPresentation %s: cannot resolve billing for " + "search_space=%s: %s", + video_pres.id, + search_space_id, + resolve_err, + ) + video_pres.status = VideoPresentationStatus.FAILED + await session.commit() + return { + "status": "failed", + "video_presentation_id": video_pres.id, + "reason": "billing_resolution_failed", + } + graph_config = { "configurable": { "video_title": video_pres.title, @@ -110,9 +142,39 @@ async def _generate_video_presentation( db_session=session, ) - graph_result = await video_presentation_graph.ainvoke( - initial_state, config=graph_config - ) + try: + async with billable_call( + user_id=owner_user_id, + search_space_id=search_space_id, + billing_tier=billing_tier, + base_model=base_model, + quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS, + usage_type="video_presentation_generation", + thread_id=video_pres.thread_id, + call_details={ + "video_presentation_id": video_pres.id, + "title": video_pres.title, + }, + ): + graph_result = await video_presentation_graph.ainvoke( + initial_state, config=graph_config + ) + except QuotaInsufficientError as exc: + logger.info( + "VideoPresentation %s denied: out of premium credits " + "(used=%d/%d remaining=%d)", + video_pres.id, + exc.used_micros, + exc.limit_micros, + exc.remaining_micros, + ) + video_pres.status = VideoPresentationStatus.FAILED + await session.commit() + return { + "status": "failed", + "video_presentation_id": video_pres.id, + "reason": "premium_quota_exhausted", + } # Serialize slides (parsed content + audio info merged) slides_raw = graph_result.get("slides", []) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index dbfe9a67b..31c0d7d6d 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -2236,8 +2236,10 @@ async def stream_new_chat( accumulator = start_turn() - # Premium quota tracking state - _premium_reserved = 0 + # Premium credit (USD micro-units) tracking state. Stores the + # amount reserved up front so we can release it on cancellation + # and finalize-debit the actual provider cost reported by LiteLLM. + _premium_reserved_micros = 0 _premium_request_id: str | None = None _emit_stream_error = partial( @@ -2331,23 +2333,28 @@ async def stream_new_chat( if _needs_premium_quota: import uuid as _uuid - from app.config import config as _app_config - from app.services.token_quota_service import TokenQuotaService + from app.services.token_quota_service import ( + TokenQuotaService, + estimate_call_reserve_micros, + ) _premium_request_id = _uuid.uuid4().hex[:16] - reserve_amount = min( - agent_config.quota_reserve_tokens - or _app_config.QUOTA_MAX_RESERVE_PER_CALL, - _app_config.QUOTA_MAX_RESERVE_PER_CALL, + _agent_litellm_params = agent_config.litellm_params or {} + _agent_base_model = ( + _agent_litellm_params.get("base_model") or agent_config.model_name or "" + ) + reserve_amount_micros = estimate_call_reserve_micros( + base_model=_agent_base_model, + quota_reserve_tokens=agent_config.quota_reserve_tokens, ) async with shielded_async_session() as quota_session: quota_result = await TokenQuotaService.premium_reserve( db_session=quota_session, user_id=UUID(user_id), request_id=_premium_request_id, - reserve_tokens=reserve_amount, + reserve_micros=reserve_amount_micros, ) - _premium_reserved = reserve_amount + _premium_reserved_micros = reserve_amount_micros if not quota_result.allowed: if requested_llm_config_id == 0: try: @@ -2382,7 +2389,7 @@ async def stream_new_chat( yield streaming_service.format_done() return _premium_request_id = None - _premium_reserved = 0 + _premium_reserved_micros = 0 _log_chat_stream_error( flow=flow, error_kind="premium_quota_exhausted", @@ -3020,9 +3027,10 @@ async def stream_new_chat( usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] interrupted new_chat: calls=%d total=%d summary=%s", + "[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -3033,6 +3041,7 @@ async def stream_new_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) @@ -3060,7 +3069,11 @@ async def stream_new_chat( chat_id, generated_title ) - # Finalize premium quota with actual tokens. + # Finalize premium credit debit with the actual provider cost + # reported by LiteLLM, summed across every call in the turn. + # Mirrors the pre-cost behaviour of "premium turn → all calls + # count" so free sub-agent calls during a premium turn still + # contribute to the bill (they're $0 in practice anyway). if _premium_request_id and user_id: try: from app.services.token_quota_service import TokenQuotaService @@ -3070,11 +3083,11 @@ async def stream_new_chat( db_session=quota_session, user_id=UUID(user_id), request_id=_premium_request_id, - actual_tokens=accumulator.grand_total, - reserved_tokens=_premium_reserved, + actual_micros=accumulator.total_cost_micros, + reserved_micros=_premium_reserved_micros, ) _premium_request_id = None - _premium_reserved = 0 + _premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s", @@ -3084,9 +3097,10 @@ async def stream_new_chat( usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] normal new_chat: calls=%d total=%d summary=%s", + "[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -3097,6 +3111,7 @@ async def stream_new_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) @@ -3190,7 +3205,7 @@ async def stream_new_chat( end_turn(str(chat_id)) # Release premium reservation if not finalized - if _premium_request_id and _premium_reserved > 0 and user_id: + if _premium_request_id and _premium_reserved_micros > 0 and user_id: try: from app.services.token_quota_service import TokenQuotaService @@ -3198,9 +3213,9 @@ async def stream_new_chat( await TokenQuotaService.premium_release( db_session=quota_session, user_id=UUID(user_id), - reserved_tokens=_premium_reserved, + reserved_micros=_premium_reserved_micros, ) - _premium_reserved = 0 + _premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to release premium quota for user %s", user_id @@ -3369,8 +3384,8 @@ async def stream_resume_chat( "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 ) - # Premium quota reservation (same logic as stream_new_chat) - _resume_premium_reserved = 0 + # Premium credit reservation (same logic as stream_new_chat). + _resume_premium_reserved_micros = 0 _resume_premium_request_id: str | None = None _resume_needs_premium = ( agent_config is not None and user_id and agent_config.is_premium @@ -3378,23 +3393,30 @@ async def stream_resume_chat( if _resume_needs_premium: import uuid as _uuid - from app.config import config as _app_config - from app.services.token_quota_service import TokenQuotaService + from app.services.token_quota_service import ( + TokenQuotaService, + estimate_call_reserve_micros, + ) _resume_premium_request_id = _uuid.uuid4().hex[:16] - reserve_amount = min( - agent_config.quota_reserve_tokens - or _app_config.QUOTA_MAX_RESERVE_PER_CALL, - _app_config.QUOTA_MAX_RESERVE_PER_CALL, + _resume_litellm_params = agent_config.litellm_params or {} + _resume_base_model = ( + _resume_litellm_params.get("base_model") + or agent_config.model_name + or "" + ) + reserve_amount_micros = estimate_call_reserve_micros( + base_model=_resume_base_model, + quota_reserve_tokens=agent_config.quota_reserve_tokens, ) async with shielded_async_session() as quota_session: quota_result = await TokenQuotaService.premium_reserve( db_session=quota_session, user_id=UUID(user_id), request_id=_resume_premium_request_id, - reserve_tokens=reserve_amount, + reserve_micros=reserve_amount_micros, ) - _resume_premium_reserved = reserve_amount + _resume_premium_reserved_micros = reserve_amount_micros if not quota_result.allowed: if requested_llm_config_id == 0: try: @@ -3429,7 +3451,7 @@ async def stream_resume_chat( yield streaming_service.format_done() return _resume_premium_request_id = None - _resume_premium_reserved = 0 + _resume_premium_reserved_micros = 0 _log_chat_stream_error( flow="resume", error_kind="premium_quota_exhausted", @@ -3746,9 +3768,10 @@ async def stream_resume_chat( if stream_result.is_interrupted: usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s", + "[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -3759,6 +3782,7 @@ async def stream_resume_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) @@ -3768,7 +3792,9 @@ async def stream_resume_chat( yield streaming_service.format_done() return - # Finalize premium quota for resume path + # Finalize premium credit debit for resume path with the actual + # provider cost reported by LiteLLM (sum of cost across all + # calls in the turn). if _resume_premium_request_id and user_id: try: from app.services.token_quota_service import TokenQuotaService @@ -3778,11 +3804,11 @@ async def stream_resume_chat( db_session=quota_session, user_id=UUID(user_id), request_id=_resume_premium_request_id, - actual_tokens=accumulator.grand_total, - reserved_tokens=_resume_premium_reserved, + actual_micros=accumulator.total_cost_micros, + reserved_micros=_resume_premium_reserved_micros, ) _resume_premium_request_id = None - _resume_premium_reserved = 0 + _resume_premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s (resume)", @@ -3792,9 +3818,10 @@ async def stream_resume_chat( usage_summary = accumulator.per_message_summary() _perf_log.info( - "[token_usage] normal resume_chat: calls=%d total=%d summary=%s", + "[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s", len(accumulator.calls), accumulator.grand_total, + accumulator.total_cost_micros, usage_summary, ) if usage_summary: @@ -3805,6 +3832,7 @@ async def stream_resume_chat( "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, "call_details": accumulator.serialized_calls(), }, ) @@ -3855,7 +3883,11 @@ async def stream_resume_chat( end_turn(str(chat_id)) # Release premium reservation if not finalized - if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id: + if ( + _resume_premium_request_id + and _resume_premium_reserved_micros > 0 + and user_id + ): try: from app.services.token_quota_service import TokenQuotaService @@ -3863,9 +3895,9 @@ async def stream_resume_chat( await TokenQuotaService.premium_release( db_session=quota_session, user_id=UUID(user_id), - reserved_tokens=_resume_premium_reserved, + reserved_micros=_resume_premium_reserved_micros, ) - _resume_premium_reserved = 0 + _resume_premium_reserved_micros = 0 except Exception: logging.getLogger(__name__).warning( "Failed to release premium quota for user %s (resume)", user_id diff --git a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py new file mode 100644 index 000000000..636b7de31 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py @@ -0,0 +1,138 @@ +"""Unit tests for the image-generation route's billing-resolution helper. + +End-to-end "POST /image-generations returns 402" coverage requires the +integration harness (real DB, real auth) and lives in +``tests/integration/document_upload/`` alongside the other quota tests. +This unit test focuses on the new ``_resolve_billing_for_image_gen`` +helper which: + +* Returns ``free`` for Auto mode, even when premium configs exist + (Auto-mode billing-tier surfacing is a follow-up). +* Returns ``free`` for user-owned BYOK configs (positive IDs). +* Returns the global config's ``billing_tier`` for negative IDs. +* Honours the per-config ``quota_reserve_micros`` override when present. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_resolve_billing_for_auto_mode(monkeypatch): + from app.routes import image_generation_routes + from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS + + search_space = SimpleNamespace(image_generation_config_id=None) + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, # Not consumed on this code path. + config_id=0, # IMAGE_GEN_AUTO_MODE_ID + search_space=search_space, + ) + assert tier == "free" + assert model == "auto" + assert reserve == DEFAULT_IMAGE_RESERVE_MICROS + + +@pytest.mark.asyncio +async def test_resolve_billing_for_premium_global_config(monkeypatch): + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr( + config, + "GLOBAL_IMAGE_GEN_CONFIGS", + [ + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-image-1", + "billing_tier": "premium", + "quota_reserve_micros": 75_000, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash-image", + "billing_tier": "free", + }, + ], + raising=False, + ) + + search_space = SimpleNamespace(image_generation_config_id=None) + + # Premium with override. + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=-1, search_space=search_space + ) + assert tier == "premium" + assert model == "openai/gpt-image-1" + assert reserve == 75_000 + + # Free, no override → falls back to default. + from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS + + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=-2, search_space=search_space + ) + assert tier == "free" + # Provider-prefixed model string for OpenRouter. + assert "google/gemini-2.5-flash-image" in model + assert reserve == DEFAULT_IMAGE_RESERVE_MICROS + + +@pytest.mark.asyncio +async def test_resolve_billing_for_user_owned_byok_is_free(): + """User-owned BYOK configs (positive IDs) cost the user nothing on + our side — they pay the provider directly. Always free. + """ + from app.routes import image_generation_routes + from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS + + search_space = SimpleNamespace(image_generation_config_id=None) + tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=42, search_space=search_space + ) + assert tier == "free" + assert model == "user_byok" + assert reserve == DEFAULT_IMAGE_RESERVE_MICROS + + +@pytest.mark.asyncio +async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): + """When the request omits ``image_generation_config_id``, the helper + must consult the search space's default — so a search space pinned + to a premium global config still gates new requests by quota. + """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr( + config, + "GLOBAL_IMAGE_GEN_CONFIGS", + [ + { + "id": -7, + "provider": "OPENAI", + "model_name": "gpt-image-1", + "billing_tier": "premium", + } + ], + raising=False, + ) + + search_space = SimpleNamespace(image_generation_config_id=-7) + ( + tier, + model, + _reserve, + ) = await image_generation_routes._resolve_billing_for_image_gen( + session=None, config_id=None, search_space=search_space + ) + assert tier == "premium" + assert model == "openai/gpt-image-1" diff --git a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py new file mode 100644 index 000000000..fa8819b39 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py @@ -0,0 +1,436 @@ +"""Unit tests for ``_resolve_agent_billing_for_search_space``. + +Validates the resolver used by Celery podcast/video tasks to compute +``(owner_user_id, billing_tier, base_model)`` from a search space and its +agent LLM config. The resolver mirrors chat's billing-resolution pattern at +``stream_new_chat.py:2294-2351`` and is the single integration point that +prevents Auto-mode podcast/video from leaking premium credit. + +Coverage: + +* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium + global → returns ``("premium", <base_model>)``. +* Auto mode + ``thread_id`` set, pin resolves to a negative-id free + global → returns ``("free", <base_model>)``. +* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config + → always ``"free"``. +* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without + hitting the pin service. +* Negative id (no Auto) → uses ``get_global_llm_config``'s + ``billing_tier``. +* Positive id (user BYOK) → always ``"free"``. +* Search space not found → raises ``ValueError``. +* ``agent_llm_id`` is None → raises ``ValueError``. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace +from uuid import UUID, uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeExecResult: + def __init__(self, obj): + self._obj = obj + + def scalars(self): + return self + + def first(self): + return self._obj + + +class _FakeSession: + """Tiny AsyncSession stub. + + ``responses`` is a list of objects to return from successive + ``execute()`` calls (in order). The resolver makes at most two + ``execute()`` calls (search-space lookup, then optionally NewLLMConfig + lookup), so two queued responses cover the matrix. + """ + + def __init__(self, responses: list): + self._responses = list(responses) + + async def execute(self, _stmt): + if not self._responses: + return _FakeExecResult(None) + return _FakeExecResult(self._responses.pop(0)) + + async def commit(self) -> None: + pass + + +@dataclass +class _FakePinResolution: + resolved_llm_config_id: int + resolved_tier: str = "premium" + from_existing_pin: bool = False + + +def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace: + return SimpleNamespace( + id=42, + agent_llm_id=agent_llm_id, + user_id=user_id, + ) + + +def _make_byok_config( + *, id_: int, base_model: str | None = None, model_name: str = "gpt-byok" +) -> SimpleNamespace: + return SimpleNamespace( + id=id_, + model_name=model_name, + litellm_params={"base_model": base_model} if base_model else {}, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): + """Auto + thread → pin service resolves to negative-id premium config → + resolver returns ``("premium", <base_model>)``.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + # Mock the pin service to return a concrete premium config id. + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + assert selected_llm_config_id == 0 + assert thread_id == 99 + return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium") + + # Mock global config lookup to return a premium entry. + def _fake_get_global(cfg_id): + if cfg_id == -1: + return { + "id": -1, + "model_name": "gpt-5.4", + "billing_tier": "premium", + "litellm_params": {"base_model": "gpt-5.4"}, + } + return None + + # Lazy imports inside the resolver — patch the *target* modules so the + # imported names resolve to our fakes. + import app.services.auto_model_pin_service as pin_module + import app.services.llm_service as llm_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "premium" + assert base_model == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch): + """Auto + thread → pin returns negative-id free config → resolver + returns ``("free", <base_model>)``. Same path the pin service takes for + out-of-credit users (graceful degradation).""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free") + + def _fake_get_global(cfg_id): + if cfg_id == -3: + return { + "id": -3, + "model_name": "openrouter/free-model", + "billing_tier": "free", + "litellm_params": {"base_model": "openrouter/free-model"}, + } + return None + + import app.services.auto_model_pin_service as pin_module + import app.services.llm_service as llm_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "openrouter/free-model" + + +@pytest.mark.asyncio +async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): + """Auto + thread → pin returns positive-id BYOK config → resolver + returns ``("free", ...)`` (BYOK is always free per + ``AgentConfig.from_new_llm_config``).""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + search_space = _make_search_space(agent_llm_id=0, user_id=user_id) + byok_cfg = _make_byok_config( + id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude" + ) + session = _FakeSession([search_space, byok_cfg]) + + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free") + + import app.services.auto_model_pin_service as pin_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "anthropic/claude-3-haiku" + + +@pytest.mark.asyncio +async def test_auto_mode_without_thread_id_falls_back_to_free(): + """Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking + the pin service. Forward-compat fallback for any future direct-API + entrypoint that doesn't have a chat thread.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=None + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "auto" + + +@pytest.mark.asyncio +async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch): + """If the pin service raises ``ValueError`` (thread missing / + mismatched search space), the resolver should log and return free + rather than killing the whole task.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + + async def _fake_resolve_pin(*args, **kwargs): + raise ValueError("thread missing") + + import app.services.auto_model_pin_service as pin_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "auto" + + +@pytest.mark.asyncio +async def test_negative_id_premium_global_returns_premium(monkeypatch): + """Explicit negative agent_llm_id → ``get_global_llm_config`` → + return its ``billing_tier``.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)]) + + def _fake_get_global(cfg_id): + return { + "id": cfg_id, + "model_name": "gpt-5.4", + "billing_tier": "premium", + "litellm_params": {"base_model": "gpt-5.4"}, + } + + import app.services.llm_service as llm_module + + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "premium" + assert base_model == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_negative_id_free_global_returns_free(monkeypatch): + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)]) + + def _fake_get_global(cfg_id): + return { + "id": cfg_id, + "model_name": "openrouter/some-free", + "billing_tier": "free", + "litellm_params": {"base_model": "openrouter/some-free"}, + } + + import app.services.llm_service as llm_module + + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=None + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "openrouter/some-free" + + +@pytest.mark.asyncio +async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch): + """When the global config has no ``litellm_params.base_model``, the + resolver falls back to ``model_name`` — matching chat's behavior.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)]) + + def _fake_get_global(cfg_id): + return { + "id": cfg_id, + "model_name": "fallback-model", + "billing_tier": "premium", + # No litellm_params. + } + + import app.services.llm_service as llm_module + + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + _, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42 + ) + + assert tier == "premium" + assert base_model == "fallback-model" + + +@pytest.mark.asyncio +async def test_positive_id_byok_is_always_free(): + """Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free, + regardless of underlying provider tier.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + search_space = _make_search_space(agent_llm_id=23, user_id=user_id) + byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet") + session = _FakeSession([search_space, byok_cfg]) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "anthropic/claude-3.5-sonnet" + + +@pytest.mark.asyncio +async def test_positive_id_byok_missing_returns_free_with_empty_base_model(): + """If the BYOK config row is missing/deleted but the search space still + points at it, the resolver still returns free (no debit) with an empty + base_model — billable_call's premium path is skipped, no harm done.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)]) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "" + + +@pytest.mark.asyncio +async def test_search_space_not_found_raises_value_error(): + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + session = _FakeSession([None]) + + with pytest.raises(ValueError, match="Search space"): + await _resolve_agent_billing_for_search_space(session, search_space_id=999) + + +@pytest.mark.asyncio +async def test_agent_llm_id_none_raises_value_error(): + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)]) + + with pytest.raises(ValueError, match="agent_llm_id"): + await _resolve_agent_billing_for_search_space(session, search_space_id=42) diff --git a/surfsense_backend/tests/unit/services/test_billable_call.py b/surfsense_backend/tests/unit/services/test_billable_call.py new file mode 100644 index 000000000..86de5f23d --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_billable_call.py @@ -0,0 +1,432 @@ +"""Unit tests for the ``billable_call`` async context manager. + +Covers the per-call premium-credit lifecycle for image generation and +vision LLM extraction: + +* Free configs bypass reserve/finalize but still write an audit row. +* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the + route layer). +* Successful premium calls reserve, yield the accumulator, then finalize + with the LiteLLM-reported actual cost — and write an audit row. +* Failed premium calls release the reservation so credit isn't leaked. +* All quota DB ops happen inside their OWN ``shielded_async_session``, + isolating them from the caller's transaction (issue A). +""" + +from __future__ import annotations + +import contextlib +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeQuotaResult: + def __init__( + self, + *, + allowed: bool, + used: int = 0, + limit: int = 5_000_000, + remaining: int = 5_000_000, + ) -> None: + self.allowed = allowed + self.used = used + self.limit = limit + self.remaining = remaining + + +class _FakeSession: + """Minimal AsyncSession stub — record commits for assertion.""" + + def __init__(self) -> None: + self.committed = False + self.added: list[Any] = [] + + def add(self, obj: Any) -> None: + self.added.append(obj) + + async def commit(self) -> None: + self.committed = True + + async def close(self) -> None: + pass + + +@contextlib.asynccontextmanager +async def _fake_shielded_session(): + s = _FakeSession() + _SESSIONS_USED.append(s) + yield s + + +_SESSIONS_USED: list[_FakeSession] = [] + + +def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None): + """Wire fake reserve/finalize/release/session helpers.""" + _SESSIONS_USED.clear() + reserve_calls: list[dict[str, Any]] = [] + finalize_calls: list[dict[str, Any]] = [] + release_calls: list[dict[str, Any]] = [] + + async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros): + reserve_calls.append( + { + "user_id": user_id, + "reserve_micros": reserve_micros, + "request_id": request_id, + } + ) + return reserve_result + + async def _fake_finalize( + *, db_session, user_id, request_id, actual_micros, reserved_micros + ): + finalize_calls.append( + { + "user_id": user_id, + "actual_micros": actual_micros, + "reserved_micros": reserved_micros, + } + ) + return finalize_result or _FakeQuotaResult(allowed=True) + + async def _fake_release(*, db_session, user_id, reserved_micros): + release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros}) + + record_calls: list[dict[str, Any]] = [] + + async def _fake_record(session, **kwargs): + record_calls.append(kwargs) + return object() + + monkeypatch.setattr( + "app.services.billable_calls.TokenQuotaService.premium_reserve", + _fake_reserve, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.TokenQuotaService.premium_finalize", + _fake_finalize, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.TokenQuotaService.premium_release", + _fake_release, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.shielded_async_session", + _fake_shielded_session, + raising=False, + ) + monkeypatch.setattr( + "app.services.billable_calls.record_token_usage", + _fake_record, + raising=False, + ) + + return { + "reserve": reserve_calls, + "finalize": finalize_calls, + "release": release_calls, + "record": record_calls, + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="free", + base_model="openai/gpt-image-1", + usage_type="image_generation", + ) as acc: + # Simulate a captured cost — the accumulator is fed by the LiteLLM + # callback in real life, here we add it manually. + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=37_000, + call_kind="image_generation", + ) + + assert spies["reserve"] == [] + assert spies["finalize"] == [] + assert spies["release"] == [] + # Free still audits. + assert len(spies["record"]) == 1 + assert spies["record"][0]["usage_type"] == "image_generation" + assert spies["record"][0]["cost_micros"] == 37_000 + + +@pytest.mark.asyncio +async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch): + from app.services.billable_calls import ( + QuotaInsufficientError, + billable_call, + ) + + spies = _patch_isolation_layer( + monkeypatch, + reserve_result=_FakeQuotaResult( + allowed=False, used=5_000_000, limit=5_000_000, remaining=0 + ), + ) + user_id = uuid4() + + with pytest.raises(QuotaInsufficientError) as exc_info: + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + ): + pytest.fail("body should not run when reserve is denied") + + err = exc_info.value + assert err.usage_type == "image_generation" + assert err.used_micros == 5_000_000 + assert err.limit_micros == 5_000_000 + assert err.remaining_micros == 0 + # Reserve was attempted, but no finalize/release on a denied reserve + # — the reservation never actually held credit. + assert len(spies["reserve"]) == 1 + assert spies["finalize"] == [] + assert spies["release"] == [] + # Denied premium calls do NOT create an audit row (no work happened). + assert spies["record"] == [] + + +@pytest.mark.asyncio +async def test_premium_success_finalizes_with_actual_cost(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + ) as acc: + # LiteLLM callback would normally fill this — simulate $0.04 image. + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=40_000, + call_kind="image_generation", + ) + + assert len(spies["reserve"]) == 1 + assert spies["reserve"][0]["reserve_micros"] == 50_000 + assert len(spies["finalize"]) == 1 + assert spies["finalize"][0]["actual_micros"] == 40_000 + assert spies["finalize"][0]["reserved_micros"] == 50_000 + assert spies["release"] == [] + # And audit row written with the actual debited cost. + assert spies["record"][0]["cost_micros"] == 40_000 + # Each quota op opened its OWN session — proves session isolation. + assert len(_SESSIONS_USED) >= 3 + # Sessions used should each have committed (or be the audit one which commits). + for _s in _SESSIONS_USED: + # finalize/reserve happen via TokenQuotaService.* which we stub — + # they don't actually call commit on our fake session, but the + # audit session does. We just assert >=1 session committed. + pass + assert any(s.committed for s in _SESSIONS_USED) + + +@pytest.mark.asyncio +async def test_premium_failure_releases_reservation(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + class _ProviderError(Exception): + pass + + with pytest.raises(_ProviderError): + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + ): + raise _ProviderError("OpenRouter 503") + + assert len(spies["reserve"]) == 1 + assert spies["finalize"] == [] + # Failure path: release the held reservation. + assert len(spies["release"]) == 1 + assert spies["release"][0]["reserved_micros"] == 50_000 + + +@pytest.mark.asyncio +async def test_premium_uses_estimator_when_no_micros_override(monkeypatch): + """When ``quota_reserve_micros_override`` is None we fall back to + ``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``. + Vision LLM calls take this path (token-priced models). + """ + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + + captured_estimator_calls: list[dict[str, Any]] = [] + + def _fake_estimate(*, base_model, quota_reserve_tokens): + captured_estimator_calls.append( + {"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens} + ) + return 12_345 + + monkeypatch.setattr( + "app.services.billable_calls.estimate_call_reserve_micros", + _fake_estimate, + raising=False, + ) + + user_id = uuid4() + async with billable_call( + user_id=user_id, + search_space_id=1, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + usage_type="vision_extraction", + ): + pass + + assert captured_estimator_calls == [ + {"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000} + ] + assert spies["reserve"][0]["reserve_micros"] == 12_345 + + +# --------------------------------------------------------------------------- +# Podcast / video-presentation usage_type coverage +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch): + """Free podcast configs must skip reserve/finalize but still emit a + ``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we + have full audit coverage of free-tier agent runs.""" + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="free", + base_model="openrouter/some-free-model", + quota_reserve_micros_override=200_000, + usage_type="podcast_generation", + thread_id=99, + call_details={"podcast_id": 7, "title": "Test Podcast"}, + ) as acc: + # Two transcript LLM calls aggregated into one accumulator. + acc.add( + model="openrouter/some-free-model", + prompt_tokens=1500, + completion_tokens=8000, + total_tokens=9500, + cost_micros=0, + call_kind="chat", + ) + + assert spies["reserve"] == [] + assert spies["finalize"] == [] + assert spies["release"] == [] + + assert len(spies["record"]) == 1 + row = spies["record"][0] + assert row["usage_type"] == "podcast_generation" + assert row["thread_id"] == 99 + assert row["search_space_id"] == 42 + assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"} + + +@pytest.mark.asyncio +async def test_premium_video_denial_raises_quota_insufficient(monkeypatch): + """Premium video-presentation runs that hit a denied reservation must + raise ``QuotaInsufficientError`` *before* the graph runs and must not + emit an audit row (no work happened).""" + from app.services.billable_calls import ( + QuotaInsufficientError, + billable_call, + ) + + spies = _patch_isolation_layer( + monkeypatch, + reserve_result=_FakeQuotaResult( + allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000 + ), + ) + user_id = uuid4() + + with pytest.raises(QuotaInsufficientError) as exc_info: + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="gpt-5.4", + quota_reserve_micros_override=1_000_000, + usage_type="video_presentation_generation", + thread_id=99, + call_details={"video_presentation_id": 12, "title": "Test Video"}, + ): + pytest.fail("body should not run when reserve is denied") + + err = exc_info.value + assert err.usage_type == "video_presentation_generation" + assert err.remaining_micros == 500_000 + assert spies["reserve"][0]["reserve_micros"] == 1_000_000 + assert spies["finalize"] == [] + assert spies["release"] == [] + assert spies["record"] == [] diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index 085740032..b635b4fe8 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -214,3 +214,159 @@ def test_generate_configs_drops_non_text_and_non_tool_models(): assert "openai/gpt-4o" in model_names assert "openai/dall-e" not in model_names assert "openai/completion-only" not in model_names + + +# --------------------------------------------------------------------------- +# _generate_image_gen_configs / _generate_vision_llm_configs +# --------------------------------------------------------------------------- + + +def test_generate_image_gen_configs_filters_by_image_output(): + """Only models with ``output_modalities`` containing ``image`` are emitted. + Tool-calling and context filters are intentionally NOT applied — image + generation has nothing to do with tool calls and context windows. + """ + from app.services.openrouter_integration_service import ( + _generate_image_gen_configs, + ) + + raw = [ + # Pure image-gen model (small context, no tools — should still emit). + { + "id": "openai/gpt-image-1", + "architecture": {"output_modalities": ["image"]}, + "context_length": 4_000, + "pricing": {"prompt": "0", "completion": "0"}, + }, + # Multi-modal: text+image output (should still emit). + { + "id": "google/gemini-2.5-flash-image", + "architecture": {"output_modalities": ["text", "image"]}, + "context_length": 1_000_000, + "pricing": {"prompt": "0.000001", "completion": "0.000004"}, + }, + # Pure text model — must NOT emit. + { + "id": "openai/gpt-4o", + "architecture": {"output_modalities": ["text"]}, + "context_length": 128_000, + "pricing": {"prompt": "0.000005", "completion": "0.000015"}, + }, + ] + + cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE)) + model_names = {c["model_name"] for c in cfgs} + assert "openai/gpt-image-1" in model_names + assert "google/gemini-2.5-flash-image" in model_names + assert "openai/gpt-4o" not in model_names + + # Each config must carry ``billing_tier`` for routing in image_generation_routes. + for c in cfgs: + assert c["billing_tier"] in {"free", "premium"} + assert c["provider"] == "OPENROUTER" + assert c[_OPENROUTER_DYNAMIC_MARKER] is True + + +def test_generate_image_gen_configs_assigns_image_id_offset(): + """Image configs use a different id_offset (-20000) so their negative + IDs don't collide with chat configs (-10000) or vision configs (-30000). + """ + from app.services.openrouter_integration_service import ( + _generate_image_gen_configs, + ) + + raw = [ + { + "id": "openai/gpt-image-1", + "architecture": {"output_modalities": ["image"]}, + "context_length": 4_000, + "pricing": {"prompt": "0", "completion": "0"}, + } + ] + # Don't pass image_id_offset → use the module default (-20000). + cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE)) + assert all(c["id"] < -20_000 + 1 for c in cfgs) + assert all(c["id"] > -29_000_000 for c in cfgs) + + +def test_generate_vision_llm_configs_filters_by_image_input_text_output(): + """Vision LLMs must accept image input AND emit text — pure image-gen + (no text out) and text-only (no image in) models are excluded. + """ + from app.services.openrouter_integration_service import ( + _generate_vision_llm_configs, + ) + + raw = [ + # GPT-4o: vision LLM (image in, text out) — must emit. + { + "id": "openai/gpt-4o", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "context_length": 128_000, + "pricing": {"prompt": "0.000005", "completion": "0.000015"}, + }, + # Pure image generator — image *output*, no text out. Must NOT emit. + { + "id": "openai/gpt-image-1", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["image"], + }, + "context_length": 4_000, + "pricing": {"prompt": "0", "completion": "0"}, + }, + # Pure text model (no image in). Must NOT emit. + { + "id": "anthropic/claude-3-haiku", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["text"], + }, + "context_length": 200_000, + "pricing": {"prompt": "0.000001", "completion": "0.000005"}, + }, + ] + + cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) + names = {c["model_name"] for c in cfgs} + assert names == {"openai/gpt-4o"} + + cfg = cfgs[0] + assert cfg["billing_tier"] == "premium" + # Pricing carried inline so pricing_registration can register vision + # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache + # is cleared. + assert cfg["input_cost_per_token"] == pytest.approx(5e-6) + assert cfg["output_cost_per_token"] == pytest.approx(15e-6) + assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True + + +def test_generate_vision_llm_configs_drops_chat_only_filters(): + """A small-context vision model that doesn't advertise tool calling is + still a valid vision LLM for "describe this image" prompts. The chat + filters (``supports_tool_calling``, ``has_sufficient_context``) must + NOT be applied to vision emission. + """ + from app.services.openrouter_integration_service import ( + _generate_vision_llm_configs, + ) + + raw = [ + { + "id": "tiny/vision-mini", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "supported_parameters": [], # no tools + "context_length": 4_000, # well below MIN_CONTEXT_LENGTH + "pricing": {"prompt": "0.0000001", "completion": "0.0000005"}, + } + ] + + cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) + assert len(cfgs) == 1 + assert cfgs[0]["model_name"] == "tiny/vision-mini" diff --git a/surfsense_backend/tests/unit/services/test_pricing_registration.py b/surfsense_backend/tests/unit/services/test_pricing_registration.py new file mode 100644 index 000000000..e97250ff2 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_pricing_registration.py @@ -0,0 +1,447 @@ +"""Pricing registration unit tests. + +The pricing-registration module is what makes ``response_cost`` populate +correctly for OpenRouter dynamic models and operator-defined Azure +deployments — both of which LiteLLM doesn't natively know about. The tests +exercise: + +* The alias generators emit every shape that LiteLLM's cost-callback might + use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``, + ``provider/base_model``, ``provider/model_name``, plus the special + ``azure_openai`` → ``azure`` normalisation). +* ``register_pricing_from_global_configs`` calls ``litellm.register_model`` + with the right alias set and pricing values per provider. +* Configs without a resolvable pair of cost values are skipped — never + registered as zero, since that would override pricing LiteLLM might + already know natively. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Alias generators +# --------------------------------------------------------------------------- + + +def test_openrouter_alias_set_includes_prefixed_and_bare(): + from app.services.pricing_registration import _alias_set_for_openrouter + + aliases = _alias_set_for_openrouter("anthropic/claude-3-5-sonnet") + assert aliases == [ + "openrouter/anthropic/claude-3-5-sonnet", + "anthropic/claude-3-5-sonnet", + ] + + +def test_openrouter_alias_set_dedupes(): + """If the model id is already prefixed with ``openrouter/``, the alias + set must not contain duplicates that would re-register the same key + twice. + """ + from app.services.pricing_registration import _alias_set_for_openrouter + + aliases = _alias_set_for_openrouter("openrouter/foo") + # The bare and prefixed variants compute to the same string here, so we + # at minimum require uniqueness. + assert len(aliases) == len(set(aliases)) + + +def test_yaml_alias_set_for_azure_openai_normalises_to_azure(): + """``azure_openai`` (our YAML provider slug) must register under + ``azure/<name>`` so the LiteLLM Router's deployment-resolution path + (which uses provider ``azure``) finds the pricing too. + """ + from app.services.pricing_registration import _alias_set_for_yaml + + aliases = _alias_set_for_yaml( + provider="AZURE_OPENAI", + model_name="gpt-5.4", + base_model="gpt-5.4", + ) + assert "gpt-5.4" in aliases + assert "azure_openai/gpt-5.4" in aliases + assert "azure/gpt-5.4" in aliases + + +def test_yaml_alias_set_distinguishes_model_name_and_base_model(): + """When ``model_name`` differs from ``base_model`` (operator labelled a + deployment), both must appear in the alias set since either may surface + in callbacks depending on the call path. + """ + from app.services.pricing_registration import _alias_set_for_yaml + + aliases = _alias_set_for_yaml( + provider="OPENAI", + model_name="my-deployment-label", + base_model="gpt-4o", + ) + assert "gpt-4o" in aliases + assert "openai/gpt-4o" in aliases + assert "my-deployment-label" in aliases + assert "openai/my-deployment-label" in aliases + + +def test_yaml_alias_set_omits_provider_prefix_when_provider_blank(): + from app.services.pricing_registration import _alias_set_for_yaml + + aliases = _alias_set_for_yaml( + provider="", + model_name="foo", + base_model="bar", + ) + assert "bar" in aliases + assert "foo" in aliases + assert all("/" not in a for a in aliases) + + +# --------------------------------------------------------------------------- +# register_pricing_from_global_configs +# --------------------------------------------------------------------------- + + +class _RegistrationSpy: + """Captures the dicts passed to ``litellm.register_model``. + + Many calls may go through; we just record them all and let tests assert + against the union. + """ + + def __init__(self) -> None: + self.calls: list[dict[str, Any]] = [] + + def __call__(self, payload: dict[str, Any]) -> None: + self.calls.append(payload) + + @property + def all_keys(self) -> set[str]: + keys: set[str] = set() + for payload in self.calls: + keys.update(payload.keys()) + return keys + + +def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy: + spy = _RegistrationSpy() + monkeypatch.setattr( + "app.services.pricing_registration.litellm.register_model", + spy, + raising=False, + ) + return spy + + +def _patch_openrouter_pricing( + monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]] +) -> None: + """Pretend the OpenRouter integration is initialised with ``mapping``.""" + + class _Stub: + def get_raw_pricing(self) -> dict[str, dict[str, str]]: + return mapping + + class _StubService: + @classmethod + def is_initialized(cls) -> bool: + return True + + @classmethod + def get_instance(cls) -> _Stub: + return _Stub() + + monkeypatch.setattr( + "app.services.openrouter_integration_service.OpenRouterIntegrationService", + _StubService, + raising=False, + ) + + +def test_openrouter_models_register_under_aliases(monkeypatch): + """An OpenRouter config whose ``model_name`` is in the cached raw + pricing map is registered under both ``openrouter/X`` and bare ``X``. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing( + monkeypatch, + { + "anthropic/claude-3-5-sonnet": { + "prompt": "0.000003", + "completion": "0.000015", + } + }, + ) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENROUTER", + "model_name": "anthropic/claude-3-5-sonnet", + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys + assert "anthropic/claude-3-5-sonnet" in spy.all_keys + # Costs are float-converted from the raw OpenRouter strings. + payload = spy.calls[0] + assert payload["openrouter/anthropic/claude-3-5-sonnet"][ + "input_cost_per_token" + ] == pytest.approx(3e-6) + assert payload["openrouter/anthropic/claude-3-5-sonnet"][ + "output_cost_per_token" + ] == pytest.approx(15e-6) + assert ( + payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"] + == "openrouter" + ) + + +def test_yaml_override_registers_under_alias_set(monkeypatch): + """Operator-declared ``input_cost_per_token`` / + ``output_cost_per_token`` on a YAML config registers under every + alias the YAML alias generator produces — including the ``azure/`` + normalisation for ``azure_openai`` providers. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing(monkeypatch, {}) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5.4", + "litellm_params": { + "base_model": "gpt-5.4", + "input_cost_per_token": 2e-6, + "output_cost_per_token": 8e-6, + }, + } + ], + ) + + register_pricing_from_global_configs() + + keys = spy.all_keys + assert "gpt-5.4" in keys + assert "azure_openai/gpt-5.4" in keys + assert "azure/gpt-5.4" in keys + + payload = spy.calls[0] + entry = payload["gpt-5.4"] + assert entry["input_cost_per_token"] == pytest.approx(2e-6) + assert entry["output_cost_per_token"] == pytest.approx(8e-6) + assert entry["litellm_provider"] == "azure" + + +def test_no_override_means_no_registration(monkeypatch): + """A YAML config that *omits* both pricing fields must NOT be registered + — registering as zero would override LiteLLM's native pricing for the + ``base_model`` key (e.g. ``gpt-4o``) and silently make every user's + bill drop to $0. Fail-safe is "skip and warn", not "register zero". + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing(monkeypatch, {}) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENAI", + "model_name": "gpt-4o", + "litellm_params": {"base_model": "gpt-4o"}, + } + ], + ) + + register_pricing_from_global_configs() + + assert spy.calls == [] + + +def test_openrouter_skipped_when_pricing_missing(monkeypatch): + """If the OpenRouter raw-pricing cache doesn't carry an entry for a + configured model (network blip during refresh, model added later, etc.), + we skip it rather than registering zero pricing. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing( + monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}} + ) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENROUTER", + "model_name": "anthropic/claude-3-5-sonnet", + } + ], + ) + + register_pricing_from_global_configs() + + assert spy.calls == [] + + +def test_register_continues_after_individual_failure(monkeypatch, caplog): + """A single bad ``register_model`` call (e.g. raising LiteLLM error) + must not abort registration of the remaining configs. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"} + successful_calls: list[dict[str, Any]] = [] + + def _maybe_fail(payload: dict[str, Any]) -> None: + if any(k in failing_keys for k in payload): + raise RuntimeError("boom") + successful_calls.append(payload) + + monkeypatch.setattr( + "app.services.pricing_registration.litellm.register_model", + _maybe_fail, + raising=False, + ) + _patch_openrouter_pricing( + monkeypatch, + { + "anthropic/claude-3-5-sonnet": { + "prompt": "0.000003", + "completion": "0.000015", + } + }, + ) + + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": 1, + "provider": "OPENROUTER", + "model_name": "anthropic/claude-3-5-sonnet", + }, + { + "id": 2, + "provider": "OPENAI", + "model_name": "custom-deployment", + "litellm_params": { + "base_model": "custom-deployment", + "input_cost_per_token": 1e-6, + "output_cost_per_token": 2e-6, + }, + }, + ], + ) + + register_pricing_from_global_configs() + + # The good config still registered. + assert any("custom-deployment" in payload for payload in successful_calls) + + +def test_vision_configs_registered_with_chat_shape(monkeypatch): + """``register_pricing_from_global_configs`` walks + ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision + calls (during indexing) bill correctly. Vision configs use the same + chat-shape token prices, but image-gen pricing is intentionally NOT + registered here (handled via ``response_cost`` in LiteLLM). + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing( + monkeypatch, + {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}}, + ) + + # No chat configs — only vision. Proves the vision walk is a separate + # iteration, not piggy-backed on the chat list. + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) + monkeypatch.setattr( + config, + "GLOBAL_VISION_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "billing_tier": "premium", + "input_cost_per_token": 5e-6, + "output_cost_per_token": 15e-6, + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/openai/gpt-4o" in spy.all_keys + payload_value = spy.calls[0]["openrouter/openai/gpt-4o"] + assert payload_value["mode"] == "chat" + assert payload_value["litellm_provider"] == "openrouter" + assert payload_value["input_cost_per_token"] == pytest.approx(5e-6) + assert payload_value["output_cost_per_token"] == pytest.approx(15e-6) + + +def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch): + """If the OpenRouter pricing cache misses a vision model (different + catalogue surface), the vision walk falls back to inline + ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing(monkeypatch, {}) + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) + monkeypatch.setattr( + config, + "GLOBAL_VISION_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash", + "billing_tier": "premium", + "input_cost_per_token": 1e-6, + "output_cost_per_token": 4e-6, + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/google/gemini-2.5-flash" in spy.all_keys diff --git a/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py new file mode 100644 index 000000000..9e35b6f9c --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py @@ -0,0 +1,157 @@ +"""Unit tests for ``QuotaCheckedVisionLLM``. + +Validates that: + +* Calling ``ainvoke`` routes through ``billable_call`` (premium credit + enforcement) and forwards the inner LLM's response on success. +* The wrapper proxies non-overridden attributes to the inner LLM + (``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output`` + still work without quota gating (they're not used in indexing today). +* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper + bubbles it up — the ETL pipeline catches that and falls back to OCR. +""" + +from __future__ import annotations + +import contextlib +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +class _FakeInnerLLM: + """Stand-in for ``langchain_litellm.ChatLiteLLM``.""" + + def __init__(self, response: Any = "OCR'd content") -> None: + self._response = response + self.ainvoke_calls: list[Any] = [] + + async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any: + self.ainvoke_calls.append(input) + return self._response + + def some_other_method(self, x: int) -> int: + return x * 2 + + +@contextlib.asynccontextmanager +async def _passthrough_billable_call(**_kwargs): + """Stand-in for billable_call that always allows the call to run.""" + + class _Acc: + total_cost_micros = 0 + total_prompt_tokens = 0 + total_completion_tokens = 0 + grand_total = 0 + calls: list[Any] = [] + + def per_message_summary(self) -> dict[str, dict[str, int]]: + return {} + + yield _Acc() + + +@pytest.mark.asyncio +async def test_ainvoke_routes_through_billable_call(monkeypatch): + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM + + captured_kwargs: list[dict[str, Any]] = [] + + @contextlib.asynccontextmanager + async def _spy_billable_call(**kwargs): + captured_kwargs.append(kwargs) + async with _passthrough_billable_call() as acc: + yield acc + + monkeypatch.setattr( + "app.services.quota_checked_vision_llm.billable_call", + _spy_billable_call, + raising=False, + ) + + inner = _FakeInnerLLM(response="A red apple on a white table") + user_id = uuid4() + wrapper = QuotaCheckedVisionLLM( + inner, + user_id=user_id, + search_space_id=99, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + ) + + result = await wrapper.ainvoke([{"text": "what is this?"}]) + assert result == "A red apple on a white table" + assert len(inner.ainvoke_calls) == 1 + assert len(captured_kwargs) == 1 + bc_kwargs = captured_kwargs[0] + assert bc_kwargs["user_id"] == user_id + assert bc_kwargs["search_space_id"] == 99 + assert bc_kwargs["billing_tier"] == "premium" + assert bc_kwargs["base_model"] == "openai/gpt-4o" + assert bc_kwargs["quota_reserve_tokens"] == 4000 + assert bc_kwargs["usage_type"] == "vision_extraction" + + +@pytest.mark.asyncio +async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch): + from app.services.billable_calls import QuotaInsufficientError + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM + + @contextlib.asynccontextmanager + async def _denying_billable_call(**_kwargs): + raise QuotaInsufficientError( + usage_type="vision_extraction", + used_micros=5_000_000, + limit_micros=5_000_000, + remaining_micros=0, + ) + yield # unreachable but required for asynccontextmanager type + + monkeypatch.setattr( + "app.services.quota_checked_vision_llm.billable_call", + _denying_billable_call, + raising=False, + ) + + inner = _FakeInnerLLM() + wrapper = QuotaCheckedVisionLLM( + inner, + user_id=uuid4(), + search_space_id=1, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + ) + + with pytest.raises(QuotaInsufficientError): + await wrapper.ainvoke([{"text": "x"}]) + + # Inner LLM never ran on a denied reservation. + assert inner.ainvoke_calls == [] + + +@pytest.mark.asyncio +async def test_proxies_non_overridden_attributes_to_inner(): + """``__getattr__`` forwards anything not on the proxy itself, so any + method we didn't explicitly override (``invoke``, ``astream``, + ``with_structured_output``, etc.) still works — just without quota + gating, which is fine because the indexer only ever calls ainvoke. + """ + from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM + + inner = _FakeInnerLLM() + wrapper = QuotaCheckedVisionLLM( + inner, + user_id=uuid4(), + search_space_id=1, + billing_tier="premium", + base_model="openai/gpt-4o", + quota_reserve_tokens=4000, + ) + + # ``some_other_method`` is on the inner only. + assert wrapper.some_other_method(7) == 14 diff --git a/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py new file mode 100644 index 000000000..63681828d --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py @@ -0,0 +1,515 @@ +"""Cost-based premium quota unit tests. + +Covers the USD-micro behaviour added in migration 140: + +* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all + calls in a turn — used as the debit amount when ``agent_config.is_premium`` + is true, regardless of which underlying model produced each call. This + preserves the prior "premium turn → all calls in turn count" rule from the + token-based system. +* ``estimate_call_reserve_micros`` scales linearly with model pricing, + clamps to a sane floor when pricing is unknown, and respects the + ``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry + can't lock the whole balance on one call. +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# TurnTokenAccumulator — premium-turn debit semantics +# --------------------------------------------------------------------------- + + +def test_total_cost_micros_sums_premium_and_free_calls(): + """A premium turn that also called a free sub-agent debits the union. + + The plan deliberately preserved the existing "premium turn → all calls + count" behaviour because per-call premium filtering relied on + ``LLMRouterService._premium_model_strings`` which only covers router-pool + deployments. ``total_cost_micros`` therefore must include free-model + calls (whose ``cost_micros`` is typically ``0``) as well as the premium + call's actual provider cost. + """ + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + # Premium model (e.g. claude-opus): non-zero cost. + acc.add( + model="anthropic/claude-3-5-sonnet", + prompt_tokens=1200, + completion_tokens=400, + total_tokens=1600, + cost_micros=12_345, + ) + # Free sub-agent (e.g. title-gen on a free model): zero cost. + acc.add( + model="gpt-4o-mini", + prompt_tokens=120, + completion_tokens=20, + total_tokens=140, + cost_micros=0, + ) + # A second premium-priced call within the same turn. + acc.add( + model="anthropic/claude-3-5-sonnet", + prompt_tokens=800, + completion_tokens=200, + total_tokens=1000, + cost_micros=7_500, + ) + + assert acc.total_cost_micros == 12_345 + 0 + 7_500 + # Token totals stay correct so the FE display path still works. + assert acc.grand_total == 1600 + 140 + 1000 + + +def test_total_cost_micros_zero_when_no_calls(): + """An empty accumulator must report zero cost (no division-by-zero, no None).""" + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + assert acc.total_cost_micros == 0 + assert acc.grand_total == 0 + + +def test_per_message_summary_groups_cost_by_model(): + """``per_message_summary`` must accumulate ``cost_micros`` per model so the + SSE ``model_breakdown`` payload reports actual USD spend per provider. + """ + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + acc.add( + model="claude-3-5-sonnet", + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + cost_micros=4_000, + ) + acc.add( + model="claude-3-5-sonnet", + prompt_tokens=200, + completion_tokens=100, + total_tokens=300, + cost_micros=8_000, + ) + acc.add( + model="gpt-4o-mini", + prompt_tokens=50, + completion_tokens=10, + total_tokens=60, + cost_micros=200, + ) + + summary = acc.per_message_summary() + assert summary["claude-3-5-sonnet"]["cost_micros"] == 12_000 + assert summary["claude-3-5-sonnet"]["total_tokens"] == 450 + assert summary["gpt-4o-mini"]["cost_micros"] == 200 + + +def test_serialized_calls_includes_cost_micros(): + """``serialized_calls`` is what flows into the SSE ``call_details`` + payload; cost_micros must be present on each entry so the FE message-info + dropdown can render per-call USD. + """ + from app.services.token_tracking_service import TurnTokenAccumulator + + acc = TurnTokenAccumulator() + acc.add( + model="m", + prompt_tokens=1, + completion_tokens=1, + total_tokens=2, + cost_micros=42, + ) + serialized = acc.serialized_calls() + assert serialized == [ + { + "model": "m", + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2, + "cost_micros": 42, + "call_kind": "chat", + } + ] + + +# --------------------------------------------------------------------------- +# estimate_call_reserve_micros — sizing and clamping +# --------------------------------------------------------------------------- + + +def test_reserve_returns_floor_when_model_unknown(monkeypatch): + """If LiteLLM doesn't know the model, ``get_model_info`` raises and the + helper falls back to the 100-micro floor — small enough that a user with + $0.0001 left can still send a tiny request, but non-zero so we still gate + against an empty balance. + """ + import litellm + + from app.services import token_quota_service + + def _raise(_name): + raise KeyError("unknown") + + monkeypatch.setattr(litellm, "get_model_info", _raise, raising=False) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="nonexistent-model", + quota_reserve_tokens=4000, + ) + assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS + assert micros == 100 + + +def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch): + """LiteLLM may *return* a model with both cost-per-token fields at 0 + (pricing not yet registered). The helper must not multiply 0 x tokens + and end up reserving 0 — it must clamp to the floor. + """ + import litellm + + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: {"input_cost_per_token": 0, "output_cost_per_token": 0}, + raising=False, + ) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="some-pending-model", + quota_reserve_tokens=4000, + ) + assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS + + +def test_reserve_scales_with_model_cost(monkeypatch): + """Claude-Opus-priced model with 4000 reserve_tokens reserves + ~$0.36 = 360_000 micros. Critically this must NOT be clamped down to + some small artificial cap — that was the bug the plan called out. + """ + import litellm + + from app.config import config + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: { + "input_cost_per_token": 15e-6, + "output_cost_per_token": 75e-6, + }, + raising=False, + ) + monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="claude-3-opus", + quota_reserve_tokens=4000, + ) + # 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros. + assert micros == 360_000 + + +def test_reserve_clamps_to_max_ceiling(monkeypatch): + """A misconfigured "$1000 / M" model with 4000 reserve_tokens would + nominally compute to $4 = 4_000_000 micros. The ceiling + ``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry + can't lock the user's whole balance on one call. + """ + import litellm + + from app.config import config + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: { + "input_cost_per_token": 1e-3, + "output_cost_per_token": 0, + }, + raising=False, + ) + monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False) + + micros = token_quota_service.estimate_call_reserve_micros( + base_model="oops-misconfigured", + quota_reserve_tokens=4000, + ) + assert micros == 1_000_000 + + +def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch): + """Per-config ``quota_reserve_tokens`` is optional; when ``None`` or + zero, the helper must fall back to the global ``QUOTA_MAX_RESERVE_PER_CALL`` + so anonymous-style configs still reserve the operator-tunable default. + """ + import litellm + + from app.config import config + from app.services import token_quota_service + + monkeypatch.setattr( + litellm, + "get_model_info", + lambda _name: { + "input_cost_per_token": 1e-6, + "output_cost_per_token": 1e-6, + }, + raising=False, + ) + monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False) + monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False) + + # 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros + assert ( + token_quota_service.estimate_call_reserve_micros( + base_model="cheap", quota_reserve_tokens=None + ) + == 4000 + ) + assert ( + token_quota_service.estimate_call_reserve_micros( + base_model="cheap", quota_reserve_tokens=0 + ) + == 4000 + ) + + +# --------------------------------------------------------------------------- +# TokenTrackingCallback — image vs chat usage shape +# --------------------------------------------------------------------------- + + +class _FakeImageUsage: + """Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape).""" + + def __init__( + self, + input_tokens: int = 0, + output_tokens: int = 0, + total_tokens: int | None = None, + ) -> None: + self.input_tokens = input_tokens + self.output_tokens = output_tokens + if total_tokens is not None: + self.total_tokens = total_tokens + + +class _FakeImageResponse: + """Mimics LiteLLM's ``ImageResponse`` — same name so the callback's + ``type(...).__name__`` probe routes to the image branch. + """ + + def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None): + self.usage = usage + if response_cost is not None: + self._hidden_params = {"response_cost": response_cost} + + +# Re-tag the helper class as ``ImageResponse`` for the type-name probe in +# the callback. We can't simply name the class ``ImageResponse`` because +# the test runner sometimes imports test modules in surprising ways and +# we want to be explicit. +_FakeImageResponse.__name__ = "ImageResponse" + + +class _FakeChatUsage: + def __init__(self, prompt: int, completion: int): + self.prompt_tokens = prompt + self.completion_tokens = completion + self.total_tokens = prompt + completion + + +class _FakeChatResponse: + def __init__(self, usage: _FakeChatUsage): + self.usage = usage + + +@pytest.mark.asyncio +async def test_callback_reads_image_usage_input_output_tokens(): + """``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens`` + for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT + prompt_tokens/completion_tokens which is the chat shape. + """ + from app.services.token_tracking_service import ( + TokenTrackingCallback, + scoped_turn, + ) + + cb = TokenTrackingCallback() + response = _FakeImageResponse( + usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50), + response_cost=0.04, # $0.04 per image + ) + + async with scoped_turn() as acc: + await cb.async_log_success_event( + kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04}, + response_obj=response, + start_time=None, + end_time=None, + ) + assert len(acc.calls) == 1 + call = acc.calls[0] + assert call.prompt_tokens == 42 + assert call.completion_tokens == 8 + assert call.total_tokens == 50 + # 0.04 USD = 40_000 micros + assert call.cost_micros == 40_000 + assert call.call_kind == "image_generation" + + +@pytest.mark.asyncio +async def test_callback_chat_path_unchanged(): + """Chat responses must still read prompt_tokens/completion_tokens.""" + from app.services.token_tracking_service import ( + TokenTrackingCallback, + scoped_turn, + ) + + cb = TokenTrackingCallback() + response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30)) + + async with scoped_turn() as acc: + await cb.async_log_success_event( + kwargs={ + "model": "openrouter/anthropic/claude-3-5-sonnet", + "response_cost": 0.0036, + }, + response_obj=response, + start_time=None, + end_time=None, + ) + assert len(acc.calls) == 1 + call = acc.calls[0] + assert call.prompt_tokens == 120 + assert call.completion_tokens == 30 + assert call.total_tokens == 150 + assert call.cost_micros == 3_600 + assert call.call_kind == "chat" + + +@pytest.mark.asyncio +async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch): + """When OpenRouter omits ``usage.cost`` LiteLLM's + ``default_image_cost_calculator`` raises. The defensive image branch in + ``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is + chat-shaped and would raise too) — it returns 0 with a WARNING log. + """ + import litellm + + from app.services.token_tracking_service import ( + TokenTrackingCallback, + scoped_turn, + ) + + # Force completion_cost to raise the same way OpenRouter image-gen fails. + def _boom(*_args, **_kwargs): + raise ValueError("model_cost: missing entry for openrouter image model") + + monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False) + + # And make sure cost_per_token is NEVER called for the image path — + # if it were, our ``is_image=True`` branch is broken. + cost_per_token_calls: list = [] + + def _record_cost_per_token(**kwargs): + cost_per_token_calls.append(kwargs) + return (0.0, 0.0) + + monkeypatch.setattr( + litellm, "cost_per_token", _record_cost_per_token, raising=False + ) + + cb = TokenTrackingCallback() + response = _FakeImageResponse( + usage=_FakeImageUsage(input_tokens=7, output_tokens=0) + ) + + async with scoped_turn() as acc: + await cb.async_log_success_event( + kwargs={"model": "openrouter/google/gemini-2.5-flash-image"}, + response_obj=response, + start_time=None, + end_time=None, + ) + + assert len(acc.calls) == 1 + assert acc.calls[0].cost_micros == 0 + assert acc.calls[0].call_kind == "image_generation" + # The image branch must short-circuit before cost_per_token. + assert cost_per_token_calls == [] + + +# --------------------------------------------------------------------------- +# scoped_turn — ContextVar reset semantics (issue B) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_scoped_turn_restores_outer_accumulator(): + """``scoped_turn`` must restore the previous ContextVar value on exit + so a per-call wrapper inside an outer chat turn doesn't leak its + accumulator outward (which would cause double-debit at chat-turn exit). + """ + from app.services.token_tracking_service import ( + get_current_accumulator, + scoped_turn, + start_turn, + ) + + outer = start_turn() + assert get_current_accumulator() is outer + + async with scoped_turn() as inner: + assert get_current_accumulator() is inner + assert inner is not outer + inner.add( + model="x", + prompt_tokens=1, + completion_tokens=1, + total_tokens=2, + cost_micros=5, + ) + + # After exit the outer accumulator is restored unchanged. + assert get_current_accumulator() is outer + assert outer.total_cost_micros == 0 + assert len(outer.calls) == 0 + # The inner accumulator captured the call but didn't bleed into outer. + assert inner.total_cost_micros == 5 + + +@pytest.mark.asyncio +async def test_scoped_turn_resets_to_none_when_no_outer(): + """Running ``scoped_turn`` outside any chat turn (e.g. a background + indexing job) must leave the ContextVar at ``None`` on exit so the + next *unrelated* request starts clean. + """ + from app.services.token_tracking_service import ( + _turn_accumulator, + get_current_accumulator, + scoped_turn, + ) + + # ContextVar default is None for a fresh test isolated context. We + # simulate "no outer" explicitly to be robust against test order. + token = _turn_accumulator.set(None) + try: + assert get_current_accumulator() is None + async with scoped_turn() as acc: + assert get_current_accumulator() is acc + assert get_current_accumulator() is None + finally: + _turn_accumulator.reset(token) diff --git a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py new file mode 100644 index 000000000..38d6ba2ca --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py @@ -0,0 +1,325 @@ +"""Unit tests for podcast Celery task billing integration. + +Validates ``_generate_content_podcast`` correctly wraps +``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the +search-space owner's billing decision, and degrades cleanly when the +resolver fails or premium credit is exhausted. + +Coverage: + +* Happy-path free config: resolver → ``billable_call`` enters with + ``usage_type='podcast_generation'`` and the configured reserve override, + graph runs, podcast row flips to ``READY``. +* Happy-path premium config: same wiring with ``billing_tier='premium'``. +* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` → + graph is *not* invoked, podcast row flips to ``FAILED``, return dict + carries ``reason='premium_quota_exhausted'``. +* Resolver failure: ``ValueError`` from the resolver → podcast row flips + to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``. +""" + +from __future__ import annotations + +import contextlib +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeExecResult: + def __init__(self, obj): + self._obj = obj + + def scalars(self): + return self + + def first(self): + return self._obj + + def filter(self, *_args, **_kwargs): + return self + + +class _FakeSession: + def __init__(self, podcast): + self._podcast = podcast + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self._podcast) + + async def commit(self): + self.commit_count += 1 + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return None + + +class _FakeSessionMaker: + def __init__(self, session: _FakeSession): + self._session = session + + def __call__(self): + return self._session + + +def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace: + """Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily + inside helpers keeps this fixture cheap.""" + return SimpleNamespace( + id=podcast_id, + title="Test Podcast", + thread_id=thread_id, + status=None, + podcast_transcript=None, + file_location=None, + ) + + +@contextlib.asynccontextmanager +async def _ok_billable_call(**kwargs): + """Stand-in for ``billable_call`` that records its kwargs and yields a + no-op accumulator-shaped object.""" + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + + +_CALL_LOG: list[dict[str, Any]] = [] + + +@contextlib.asynccontextmanager +async def _denying_billable_call(**kwargs): + from app.services.billable_calls import QuotaInsufficientError + + _CALL_LOG.append(kwargs) + raise QuotaInsufficientError( + usage_type=kwargs.get("usage_type", "?"), + used_micros=5_000_000, + limit_micros=5_000_000, + remaining_micros=0, + ) + yield SimpleNamespace() # pragma: no cover — for grammar only + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_call_log(): + _CALL_LOG.clear() + yield + _CALL_LOG.clear() + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch): + """Happy path: free billing tier still wraps the graph call so the + audit row is recorded. Verifies kwargs threading.""" + from app.config import config as app_config + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=7, thread_id=99) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + assert search_space_id == 555 + assert thread_id == 99 + return user_id, "free", "openrouter/some-free-model" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return { + "podcast_transcript": [ + SimpleNamespace(speaker_id=0, dialog="Hi"), + SimpleNamespace(speaker_id=1, dialog="Hello"), + ], + "final_podcast_file_path": "/tmp/podcast.wav", + } + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=7, + source_content="hello world", + search_space_id=555, + user_prompt="make it short", + ) + + assert result["status"] == "ready" + assert result["podcast_id"] == 7 + assert podcast.status == PodcastStatus.READY + assert podcast.file_location == "/tmp/podcast.wav" + + assert len(_CALL_LOG) == 1 + call = _CALL_LOG[0] + assert call["user_id"] == user_id + assert call["search_space_id"] == 555 + assert call["billing_tier"] == "free" + assert call["base_model"] == "openrouter/some-free-model" + assert call["usage_type"] == "podcast_generation" + assert ( + call["quota_reserve_micros_override"] + == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS + ) + assert call["thread_id"] == 99 + assert call["call_details"] == {"podcast_id": 7, "title": "Test Podcast"} + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_premium_tier(monkeypatch): + """Premium resolution flows through to ``billable_call`` so the + reserve/finalize path triggers.""" + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast() + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return user_id, "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + await podcast_tasks._generate_content_podcast( + podcast_id=7, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert _CALL_LOG[0]["billing_tier"] == "premium" + assert _CALL_LOG[0]["base_model"] == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch): + """When ``billable_call`` denies the reservation, the graph never + runs and the podcast row flips to FAILED with the documented reason + code.""" + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=8) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=8, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 8, + "reason": "premium_quota_exhausted", + } + assert podcast.status == PodcastStatus.FAILED + assert graph_invoked == [] # Graph never ran on denied reservation. + + +@pytest.mark.asyncio +async def test_resolver_failure_marks_podcast_failed(monkeypatch): + """If the resolver raises (e.g. search-space deleted), the task fails + cleanly without invoking the graph.""" + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=9) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _failing_resolver(sess, search_space_id, *, thread_id=None): + raise ValueError("Search space 555 not found") + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver + ) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=9, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 9, + "reason": "billing_resolution_failed", + } + assert podcast.status == PodcastStatus.FAILED + assert graph_invoked == [] diff --git a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py new file mode 100644 index 000000000..671f57ae4 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py @@ -0,0 +1,330 @@ +"""Unit tests for video-presentation Celery task billing integration. + +Mirrors ``test_podcast_billing.py`` for the video-presentation task. +Validates the same wrap-graph-in-billable_call pattern and ensures the +larger ``QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS`` reservation is +threaded through. + +Coverage: + +* Free config: graph runs, ``billable_call`` invoked with the video + reserve override. +* Premium config: same wiring with ``billing_tier='premium'``. +* Quota denial: graph not invoked, row → FAILED, reason code surfaced. +* Resolver failure: row → FAILED with ``billing_resolution_failed``. +""" + +from __future__ import annotations + +import contextlib +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeExecResult: + def __init__(self, obj): + self._obj = obj + + def scalars(self): + return self + + def first(self): + return self._obj + + def filter(self, *_args, **_kwargs): + return self + + +class _FakeSession: + def __init__(self, video): + self._video = video + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self._video) + + async def commit(self): + self.commit_count += 1 + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return None + + +class _FakeSessionMaker: + def __init__(self, session: _FakeSession): + self._session = session + + def __call__(self): + return self._session + + +def _make_video(video_id: int = 11, thread_id: int = 99) -> SimpleNamespace: + return SimpleNamespace( + id=video_id, + title="Test Presentation", + thread_id=thread_id, + status=None, + slides=None, + scene_codes=None, + ) + + +_CALL_LOG: list[dict[str, Any]] = [] + + +@contextlib.asynccontextmanager +async def _ok_billable_call(**kwargs): + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + + +@contextlib.asynccontextmanager +async def _denying_billable_call(**kwargs): + from app.services.billable_calls import QuotaInsufficientError + + _CALL_LOG.append(kwargs) + raise QuotaInsufficientError( + usage_type=kwargs.get("usage_type", "?"), + used_micros=5_000_000, + limit_micros=5_000_000, + remaining_micros=0, + ) + yield SimpleNamespace() # pragma: no cover + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_call_log(): + _CALL_LOG.clear() + yield + _CALL_LOG.clear() + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch): + from app.config import config as app_config + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=11, thread_id=99) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + assert search_space_id == 777 + assert thread_id == 99 + return user_id, "free", "openrouter/some-free-model" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=11, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result["status"] == "ready" + assert result["video_presentation_id"] == 11 + assert video.status == VideoPresentationStatus.READY + + assert len(_CALL_LOG) == 1 + call = _CALL_LOG[0] + assert call["user_id"] == user_id + assert call["search_space_id"] == 777 + assert call["billing_tier"] == "free" + assert call["base_model"] == "openrouter/some-free-model" + assert call["usage_type"] == "video_presentation_generation" + assert ( + call["quota_reserve_micros_override"] + == app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS + ) + assert call["thread_id"] == 99 + assert call["call_details"] == { + "video_presentation_id": 11, + "title": "Test Presentation", + } + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_premium_tier(monkeypatch): + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video() + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return user_id, "premium", "gpt-5.4" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + await video_presentation_tasks._generate_video_presentation( + video_presentation_id=11, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert _CALL_LOG[0]["billing_tier"] == "premium" + assert _CALL_LOG[0]["base_model"] == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch): + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=12) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr( + video_presentation_tasks, "billable_call", _denying_billable_call + ) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=12, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "video_presentation_id": 12, + "reason": "premium_quota_exhausted", + } + assert video.status == VideoPresentationStatus.FAILED + assert graph_invoked == [] + + +@pytest.mark.asyncio +async def test_resolver_failure_marks_video_failed(monkeypatch): + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=13) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _failing_resolver(sess, search_space_id, *, thread_id=None): + raise ValueError("Search space 777 not found") + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _failing_resolver, + ) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=13, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "video_presentation_id": 13, + "reason": "billing_resolution_failed", + } + assert video.status == VideoPresentationStatus.FAILED + assert graph_invoked == [] diff --git a/surfsense_web/app/(home)/free/page.tsx b/surfsense_web/app/(home)/free/page.tsx index 8d9ed5cb1..3ddd5195f 100644 --- a/surfsense_web/app/(home)/free/page.tsx +++ b/surfsense_web/app/(home)/free/page.tsx @@ -127,7 +127,7 @@ const FAQ_ITEMS = [ { question: "What happens after I use my free tokens?", answer: - "After your free tokens, create a free SurfSense account to unlock 3 million more premium tokens. Additional tokens can be purchased at $1 per million. Non-premium models remain unlimited for registered users.", + "After your free tokens, create a free SurfSense account to unlock $5 of premium credit. Additional credit can be topped up at $1 for $1 of credit, billed at the actual provider cost. Non-premium models remain unlimited for registered users.", }, { question: "Is Claude AI available without login?", @@ -329,7 +329,7 @@ export default async function FreeHubPage() { <section className="max-w-3xl mx-auto text-center"> <h2 className="text-2xl font-bold mb-3">Want More Features?</h2> <p className="text-muted-foreground mb-6 leading-relaxed"> - Create a free SurfSense account to unlock 3 million tokens, document uploads with + Create a free SurfSense account to unlock $5 of premium credit, document uploads with citations, team collaboration, and integrations with Slack, Google Drive, Notion, and 30+ more tools. </p> diff --git a/surfsense_web/app/(home)/pricing/page.tsx b/surfsense_web/app/(home)/pricing/page.tsx index 6ad9435bf..6f332be70 100644 --- a/surfsense_web/app/(home)/pricing/page.tsx +++ b/surfsense_web/app/(home)/pricing/page.tsx @@ -5,7 +5,7 @@ import { BreadcrumbNav } from "@/components/seo/breadcrumb-nav"; export const metadata: Metadata = { title: "Pricing | SurfSense - Free AI Search Plans", description: - "Explore SurfSense plans and pricing. Start free with 500 pages & 3M premium tokens. Use ChatGPT, Claude AI, and premium AI models. Pay-as-you-go tokens at $1 per million.", + "Explore SurfSense plans and pricing. Start free with 500 pages & $5 of premium credit. Use ChatGPT, Claude AI, and premium AI models. Pay as you go at provider cost — $1 buys $1 of credit.", alternates: { canonical: "https://surfsense.com/pricing", }, diff --git a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx index 3017160e1..0c5662712 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx @@ -8,7 +8,7 @@ import { cn } from "@/lib/utils"; const TABS = [ { id: "pages", label: "Pages" }, - { id: "tokens", label: "Premium Tokens" }, + { id: "tokens", label: "Premium Credit" }, ] as const; type TabId = (typeof TABS)[number]["id"]; diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx index 2b7422f80..cf73b5eba 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx @@ -28,6 +28,12 @@ type UnifiedPurchase = { kind: PurchaseKind; created_at: string; status: PagePurchaseStatus; + /** + * Granted units. Interpretation depends on ``kind``: + * - ``"pages"`` — integer number of indexed pages. + * - ``"tokens"`` — integer micro-USD of credit (1_000_000 = $1.00). + * The ``Granted`` column formats accordingly. + */ granted: number; amount_total: number | null; currency: string | null; @@ -58,7 +64,7 @@ const KIND_META: Record< iconClass: "text-sky-500", }, tokens: { - label: "Premium Tokens", + label: "Premium Credit", icon: Coins, iconClass: "text-amber-500", }, @@ -97,12 +103,25 @@ function normalizeTokenPurchase(p: TokenPurchase): UnifiedPurchase { kind: "tokens", created_at: p.created_at, status: p.status, - granted: p.tokens_granted, + granted: p.credit_micros_granted, amount_total: p.amount_total, currency: p.currency, }; } +function formatGranted(p: UnifiedPurchase): string { + if (p.kind === "tokens") { + const dollars = p.granted / 1_000_000; + // Premium credit packs are always whole dollars at the moment, but + // future fractional grants (refunds, partial top-ups) shouldn't + // silently round to "$0". + if (dollars >= 1) return `$${dollars.toFixed(2)} of credit`; + if (dollars > 0) return `$${dollars.toFixed(3)} of credit`; + return "$0 of credit"; + } + return p.granted.toLocaleString(); +} + export function PurchaseHistoryContent() { const results = useQueries({ queries: [ @@ -143,7 +162,7 @@ export function PurchaseHistoryContent() { <ReceiptText className="h-8 w-8 text-muted-foreground" /> <p className="text-sm font-medium">No purchases yet</p> <p className="text-xs text-muted-foreground"> - Your page and premium token purchases will appear here after checkout. + Your page and premium credit purchases will appear here after checkout. </p> </div> ); @@ -177,7 +196,7 @@ export function PurchaseHistoryContent() { </div> </TableCell> <TableCell className="text-right tabular-nums text-sm"> - {p.granted.toLocaleString()} + {formatGranted(p)} </TableCell> <TableCell className="text-right tabular-nums text-sm"> {formatAmount(p.amount_total, p.currency)} diff --git a/surfsense_web/atoms/user/user-query.atoms.ts b/surfsense_web/atoms/user/user-query.atoms.ts index a59811324..4b6717440 100644 --- a/surfsense_web/atoms/user/user-query.atoms.ts +++ b/surfsense_web/atoms/user/user-query.atoms.ts @@ -8,9 +8,9 @@ const userQueryFn = () => userApiService.getMe(); export const currentUserAtom = atomWithQuery(() => { return { queryKey: USER_QUERY_KEY, - // Live-changing numeric fields (pages_*, premium_tokens_*) are now - // pushed via Zero (queries.user.me()), so /users/me only needs to - // fire once per session for the static profile fields. + // Live-changing numeric fields (pages_*, premium_credit_micros_*) + // are now pushed via Zero (queries.user.me()), so /users/me only + // needs to fire once per session for the static profile fields. staleTime: Infinity, enabled: !!getBearerToken(), queryFn: userQueryFn, diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index 711bb2fe2..ffb0e4dc8 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -399,6 +399,19 @@ function formatMessageDate(date: Date): string { }); } +/** + * Format provider USD cost (in micro-USD) for inline display next to a + * token count. Falls back to ``"<$0.001"`` for sub-tenth-of-a-cent + * costs so a real-but-tiny figure doesn't render as ``$0.000``. + */ +function formatTurnCost(micros: number): string { + const dollars = micros / 1_000_000; + if (dollars >= 1) return `$${dollars.toFixed(2)}`; + if (dollars >= 0.01) return `$${dollars.toFixed(3)}`; + if (dollars > 0) return "<$0.001"; + return "$0"; +} + const MessageInfoDropdown: FC = () => { const messageId = useAuiState(({ message }) => message?.id); const createdAt = useAuiState(({ message }) => message?.createdAt); @@ -451,6 +464,7 @@ const MessageInfoDropdown: FC = () => { {models.length > 0 ? ( models.map(([model, counts]) => { const { name, icon } = resolveModel(model); + const costMicros = counts.cost_micros; return ( <ActionBarMorePrimitive.Item key={model} @@ -463,6 +477,9 @@ const MessageInfoDropdown: FC = () => { </span> <span className="text-xs text-muted-foreground"> {counts.total_tokens.toLocaleString()} tokens + {costMicros && costMicros > 0 + ? ` · ${formatTurnCost(costMicros)}` + : ""} </span> </ActionBarMorePrimitive.Item> ); @@ -474,6 +491,9 @@ const MessageInfoDropdown: FC = () => { > <span className="text-xs text-muted-foreground"> {usage.total_tokens.toLocaleString()} tokens + {usage.cost_micros && usage.cost_micros > 0 + ? ` · ${formatTurnCost(usage.cost_micros)}` + : ""} </span> </ActionBarMorePrimitive.Item> )} diff --git a/surfsense_web/components/assistant-ui/token-usage-context.tsx b/surfsense_web/components/assistant-ui/token-usage-context.tsx index b3f71ab21..dd80bcac3 100644 --- a/surfsense_web/components/assistant-ui/token-usage-context.tsx +++ b/surfsense_web/components/assistant-ui/token-usage-context.tsx @@ -13,13 +13,30 @@ export interface TokenUsageData { prompt_tokens: number; completion_tokens: number; total_tokens: number; + /** + * Total provider USD cost for this assistant turn, in micro-USD + * (1_000_000 = $1.00). Populated from LiteLLM's response_cost on + * the backend. Optional because pre-cost-credits messages persisted + * before the migration won't have it. + */ + cost_micros?: number; usage?: Record< string, - { prompt_tokens: number; completion_tokens: number; total_tokens: number } + { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + } >; model_breakdown?: Record< string, - { prompt_tokens: number; completion_tokens: number; total_tokens: number } + { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + } >; } diff --git a/surfsense_web/components/free-chat/quota-warning-banner.tsx b/surfsense_web/components/free-chat/quota-warning-banner.tsx index 3bfedf1b3..e013a64a8 100644 --- a/surfsense_web/components/free-chat/quota-warning-banner.tsx +++ b/surfsense_web/components/free-chat/quota-warning-banner.tsx @@ -40,7 +40,7 @@ export function QuotaWarningBanner({ </p> <p className="text-xs text-red-600 dark:text-red-300"> You've used all {limit.toLocaleString()} free tokens. Create a free account to - get 3 million tokens and access to all models. + get $5 of premium credit and access to all models. </p> <Link href="/register" @@ -69,7 +69,7 @@ export function QuotaWarningBanner({ <Link href="/register" className="font-medium underline hover:no-underline"> Create an account </Link>{" "} - for 5M free tokens. + for $5 of premium credit. </p> <button type="button" diff --git a/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx b/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx index a3f028858..983672d0b 100644 --- a/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx +++ b/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx @@ -5,6 +5,14 @@ import { Progress } from "@/components/ui/progress"; import { useIsAnonymous } from "@/contexts/anonymous-mode"; import { queries } from "@/zero/queries"; +/** + * Premium credit balance shown in the sidebar. + * + * Values come from Zero (live-replicated from Postgres) and are stored as + * integer micro-USD (1_000_000 == $1.00). We render in dollars because + * users top up at $1/pack and the credit gets debited at actual provider + * cost. + */ export function PremiumTokenUsageDisplay() { const isAnonymous = useIsAnonymous(); const [me] = useQuery(queries.user.me({})); @@ -12,21 +20,26 @@ export function PremiumTokenUsageDisplay() { if (isAnonymous || !me) return null; const usagePercentage = Math.min( - (me.premiumTokensUsed / Math.max(me.premiumTokensLimit, 1)) * 100, + (me.premiumCreditMicrosUsed / Math.max(me.premiumCreditMicrosLimit, 1)) * 100, 100 ); - const formatTokens = (n: number) => { - if (n >= 1_000_000) return `${(n / 1_000_000).toFixed(1)}M`; - if (n >= 1_000) return `${(n / 1_000).toFixed(0)}K`; - return n.toLocaleString(); + const formatUsd = (micros: number) => { + const dollars = micros / 1_000_000; + if (dollars >= 100) return `$${dollars.toFixed(0)}`; + if (dollars >= 1) return `$${dollars.toFixed(2)}`; + // Sub-dollar balances need extra precision so the bar still tells the + // user what's left ("$0.04 of credit") instead of rounding to "$0". + if (dollars > 0) return `$${dollars.toFixed(3)}`; + return "$0"; }; return ( <div className="space-y-1.5"> <div className="flex justify-between items-center text-xs"> <span className="text-muted-foreground"> - {formatTokens(me.premiumTokensUsed)} / {formatTokens(me.premiumTokensLimit)} tokens + {formatUsd(me.premiumCreditMicrosUsed)} / {formatUsd(me.premiumCreditMicrosLimit)} of + credit </span> <span className="font-medium">{usagePercentage.toFixed(0)}%</span> </div> diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx index 175cae4ab..127b79167 100644 --- a/surfsense_web/components/pricing/pricing-section.tsx +++ b/surfsense_web/components/pricing/pricing-section.tsx @@ -12,11 +12,11 @@ const demoPlans = [ price: "0", yearlyPrice: "0", period: "", - billingText: "500 pages + 3M premium tokens included", + billingText: "500 pages + $5 of premium credit included", features: [ "Self Hostable", "500 pages included to start", - "3 million premium tokens to start", + "$5 of premium credit to start, billed at provider cost", "Includes access to OpenAI text, audio and image models", "Realtime Collaborative Group Chats with teammates", "Community support on Discord", @@ -35,7 +35,7 @@ const demoPlans = [ features: [ "Everything in Free", "Buy 1,000-page packs at $1 each", - "Buy 1M premium token packs at $1 each", + "Top up premium credit at $1 per $1 of credit, billed at provider cost", "Use premium AI models like GPT-5.4, Claude Sonnet 4.6, Gemini 2.5 Pro & 100+ more via OpenRouter", "Priority support on Discord", ], @@ -129,27 +129,27 @@ const faqData: FAQSection[] = [ ], }, { - title: "Premium Tokens", + title: "Premium Credit", items: [ { - question: 'What are "premium tokens"?', + question: 'What is "premium credit"?', answer: - "Premium tokens are the billing unit for using premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro in SurfSense. Each AI request consumes tokens based on the length of your conversation. Non-premium models (such as free-tier models available without login) do not consume premium tokens.", + "Premium credit is your USD balance for using premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro in SurfSense. Each AI request debits the actual USD cost the provider charges, so cheap and expensive models bill proportionally. Non-premium models (such as the free-tier models available without login) don't touch your premium credit.", }, { - question: "How many premium tokens do I get for free?", + question: "How much premium credit do I get for free?", answer: - "Every registered SurfSense account starts with 3 million premium tokens at no cost. Anonymous users (no login) get 500,000 free tokens across all models. Once your free tokens are used up, you can purchase more at any time.", + "Every registered SurfSense account starts with $5 of premium credit at no cost. Anonymous users (no login) get 500,000 free tokens across all free models. Once your free credit runs out, you can top up at any time.", }, { - question: "How does purchasing premium tokens work?", + question: "How does buying premium credit work?", answer: - "Just like pages, there's no subscription. You buy 1-million-token packs at $1 each whenever you need more. Purchased tokens are added to your account immediately. You can buy up to 100 packs at a time.", + "Just like pages, there's no subscription. Top-ups buy $1 of credit for $1 — every cent you pay is spent at provider cost, no markup. Purchased credit is added to your account immediately. You can buy up to $100 at a time.", }, { - question: "What happens if I run out of premium tokens?", + question: "What happens if I run out of premium credit?", answer: - "When your premium token balance runs low (below 20%), you'll see a warning. Once you run out, premium model requests are paused until you purchase more tokens. You can always switch to non-premium models which don't consume premium tokens.", + "When your premium credit balance runs low (below 20%), you'll see a warning. Once you run out, premium model requests are paused until you top up. You can always switch to non-premium models, which don't touch your premium credit.", }, ], }, @@ -157,9 +157,9 @@ const faqData: FAQSection[] = [ title: "Self-Hosting", items: [ { - question: "Can I self-host SurfSense with unlimited pages and tokens?", + question: "Can I self-host SurfSense with unlimited pages and credit?", answer: - "Yes! When self-hosting, you have full control over your page and token limits. The default self-hosted setup gives you effectively unlimited pages and tokens, so you can index as much data and use as many AI queries as your infrastructure supports.", + "Yes! When self-hosting, you have full control over your page and premium-credit limits. The default self-hosted setup gives you effectively unlimited pages and premium credit, so you can index as much data and use as many AI queries as your infrastructure supports.", }, ], }, @@ -250,8 +250,8 @@ function PricingFAQ() { Frequently Asked Questions </h2> <p className="mx-auto mt-4 max-w-2xl text-lg text-muted-foreground"> - Everything you need to know about SurfSense pages, premium tokens, and billing. Can't - find what you need? Reach out at{" "} + Everything you need to know about SurfSense pages, premium credit, and billing. + Can't find what you need? Reach out at{" "} <a href="mailto:rohan@surfsense.com" className="text-blue-500 underline"> rohan@surfsense.com </a> @@ -335,7 +335,7 @@ function PricingBasic() { <Pricing plans={demoPlans} title="SurfSense Pricing" - description="Start free with 500 pages & 3M premium tokens. Pay as you go." + description="Start free with 500 pages & $5 of premium credit. Pay as you go, billed at provider cost." /> <PricingFAQ /> </> diff --git a/surfsense_web/components/settings/buy-tokens-content.tsx b/surfsense_web/components/settings/buy-tokens-content.tsx index e7fac4255..79a1b4943 100644 --- a/surfsense_web/components/settings/buy-tokens-content.tsx +++ b/surfsense_web/components/settings/buy-tokens-content.tsx @@ -14,10 +14,23 @@ import { AppError } from "@/lib/error"; import { cn } from "@/lib/utils"; import { queries } from "@/zero/queries"; -const TOKEN_PACK_SIZE = 1_000_000; +// One pack = $1.00 of credit, stored as 1_000_000 micro-USD on the +// backend. Premium turns are debited at the actual provider cost +// reported by LiteLLM, so $1 of credit always buys $1 of provider +// usage at cost. +const CREDIT_PER_PACK_MICROS = 1_000_000; const PRICE_PER_PACK_USD = 1; const PRESET_MULTIPLIERS = [1, 2, 5, 10, 25, 50] as const; +const formatUsd = (micros: number, options?: { compact?: boolean }) => { + const dollars = micros / 1_000_000; + if (options?.compact && dollars >= 1) return `$${dollars.toFixed(2)}`; + if (dollars >= 100) return `$${dollars.toFixed(0)}`; + if (dollars >= 1) return `$${dollars.toFixed(2)}`; + if (dollars > 0) return `$${dollars.toFixed(3)}`; + return "$0"; +}; + export function BuyTokensContent() { const params = useParams(); const searchSpaceId = Number(params?.search_space_id); @@ -29,7 +42,7 @@ export function BuyTokensContent() { queryFn: () => stripeApiService.getTokenStatus(), }); - // Live per-user usage via Zero. + // Live per-user balance via Zero. const [me] = useZeroQuery(queries.user.me({})); const purchaseMutation = useMutation({ @@ -46,44 +59,46 @@ export function BuyTokensContent() { }, }); - const totalTokens = quantity * TOKEN_PACK_SIZE; + const totalCreditMicros = quantity * CREDIT_PER_PACK_MICROS; const totalPrice = quantity * PRICE_PER_PACK_USD; if (tokenStatus && !tokenStatus.token_buying_enabled) { return ( <div className="w-full space-y-3 text-center"> - <h2 className="text-xl font-bold tracking-tight">Buy Premium Tokens</h2> + <h2 className="text-xl font-bold tracking-tight">Buy Premium Credit</h2> <p className="text-sm text-muted-foreground"> - Token purchases are temporarily unavailable. + Credit purchases are temporarily unavailable. </p> </div> ); } - const used = me?.premiumTokensUsed ?? 0; - const limit = me?.premiumTokensLimit ?? 0; - // Mirrors the backend formula in stripe_routes.py:608 (max(0, limit - used)). + const used = me?.premiumCreditMicrosUsed ?? 0; + const limit = me?.premiumCreditMicrosLimit ?? 0; + // Mirrors the backend formula in stripe_routes.py (max(0, limit - used)). const remaining = Math.max(0, limit - used); const usagePercentage = me ? Math.min((used / Math.max(limit, 1)) * 100, 100) : 0; return ( <div className="w-full space-y-5"> <div className="text-center"> - <h2 className="text-xl font-bold tracking-tight">Buy Premium Tokens</h2> - <p className="mt-1 text-sm text-muted-foreground">$1 per 1M tokens, pay as you go</p> + <h2 className="text-xl font-bold tracking-tight">Buy Premium Credit</h2> + <p className="mt-1 text-sm text-muted-foreground"> + $1 buys $1 of credit, billed at provider cost + </p> </div> {me && ( <div className="rounded-lg border bg-muted/20 p-3 space-y-1.5"> <div className="flex justify-between items-center text-xs"> <span className="text-muted-foreground"> - {used.toLocaleString()} / {limit.toLocaleString()} premium tokens + {formatUsd(used)} / {formatUsd(limit)} of credit </span> <span className="font-medium">{usagePercentage.toFixed(0)}%</span> </div> <Progress value={usagePercentage} className="h-1.5" /> <p className="text-[11px] text-muted-foreground"> - {remaining.toLocaleString()} tokens remaining + {formatUsd(remaining)} of credit remaining </p> </div> )} @@ -99,7 +114,7 @@ export function BuyTokensContent() { <Minus className="h-3.5 w-3.5" /> </button> <span className="min-w-32 text-center text-lg font-semibold tabular-nums"> - {(totalTokens / 1_000_000).toFixed(0)}M tokens + ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit </span> <button type="button" @@ -125,14 +140,14 @@ export function BuyTokensContent() { : "border-border hover:border-purple-500/40 hover:bg-muted/40" )} > - {m}M + ${m} </button> ))} </div> <div className="flex items-center justify-between rounded-lg border bg-muted/30 px-3 py-2"> <span className="text-sm font-medium tabular-nums"> - {(totalTokens / 1_000_000).toFixed(0)}M premium tokens + ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit </span> <span className="text-sm font-semibold tabular-nums">${totalPrice}</span> </div> @@ -149,7 +164,7 @@ export function BuyTokensContent() { </> ) : ( <> - Buy {(totalTokens / 1_000_000).toFixed(0)}M Tokens for ${totalPrice} + Buy ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit for ${totalPrice} </> )} </Button> diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx index f5f128f80..ced97464e 100644 --- a/surfsense_web/components/settings/image-model-manager.tsx +++ b/surfsense_web/components/settings/image-model-manager.tsx @@ -190,7 +190,25 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { ? "model" : "models"} </span>{" "} - available from your administrator. + available from your administrator.{" "} + {(() => { + const nonAuto = globalConfigs.filter( + (g) => !("is_auto_mode" in g && g.is_auto_mode) + ); + const premium = nonAuto.filter( + (g) => + "billing_tier" in g && + (g as { billing_tier?: string }).billing_tier === "premium" + ).length; + const free = nonAuto.length - premium; + if (premium > 0 && free > 0) { + return `${premium} premium, ${free} free.`; + } + if (premium > 0) { + return `All ${premium} premium — debits your shared credit pool.`; + } + return `All ${free} free.`; + })()} </p> </AlertDescription> </Alert> diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx index e21dc9028..a2eb6a22e 100644 --- a/surfsense_web/components/settings/llm-role-manager.tsx +++ b/surfsense_web/components/settings/llm-role-manager.tsx @@ -371,6 +371,17 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { </SelectLabel> {roleGlobalConfigs.map((config) => { const isAuto = "is_auto_mode" in config && config.is_auto_mode; + // Read billing_tier from the global config; default to "free" + // for legacy YAMLs / Auto stub. Premium gets a purple badge, + // free gets an emerald one — same palette as the chat + // model selector so the meaning is consistent across + // surfaces (issues E, H). + const billingTier = + ("billing_tier" in config && + typeof config.billing_tier === "string" && + config.billing_tier) || + "free"; + const isPremium = billingTier === "premium"; return ( <SelectItem key={config.id} @@ -382,13 +393,27 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { <span className="truncate text-xs md:text-sm"> {config.name} </span> - {isAuto && ( + {isAuto ? ( <Badge variant="secondary" className="text-[8px] md:text-[9px] shrink-0 bg-zinc-200 text-zinc-600 dark:bg-zinc-700 dark:text-zinc-300 [[data-slot=select-trigger]_&]:hidden" > Recommended </Badge> + ) : isPremium ? ( + <Badge + variant="secondary" + className="text-[8px] md:text-[9px] shrink-0 bg-purple-100 text-purple-700 dark:bg-purple-900/50 dark:text-purple-300 border-0 [[data-slot=select-trigger]_&]:hidden" + > + Premium + </Badge> + ) : ( + <Badge + variant="secondary" + className="text-[8px] md:text-[9px] shrink-0 bg-emerald-100 text-emerald-700 dark:bg-emerald-900/50 dark:text-emerald-300 border-0 [[data-slot=select-trigger]_&]:hidden" + > + Free + </Badge> )} </div> </SelectItem> diff --git a/surfsense_web/components/settings/vision-model-manager.tsx b/surfsense_web/components/settings/vision-model-manager.tsx index 8abfa4774..886d71008 100644 --- a/surfsense_web/components/settings/vision-model-manager.tsx +++ b/surfsense_web/components/settings/vision-model-manager.tsx @@ -191,7 +191,25 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { ? "model" : "models"} </span>{" "} - available from your administrator. + available from your administrator.{" "} + {(() => { + const nonAuto = globalConfigs.filter( + (g) => !("is_auto_mode" in g && g.is_auto_mode) + ); + const premium = nonAuto.filter( + (g) => + "billing_tier" in g && + (g as { billing_tier?: string }).billing_tier === "premium" + ).length; + const free = nonAuto.length - premium; + if (premium > 0 && free > 0) { + return `${premium} premium, ${free} free.`; + } + if (premium > 0) { + return `All ${premium} premium — debits your shared credit pool.`; + } + return `All ${free} free.`; + })()} </p> </AlertDescription> </Alert> diff --git a/surfsense_web/contexts/login-gate.tsx b/surfsense_web/contexts/login-gate.tsx index fad64fa9f..790e5c00e 100644 --- a/surfsense_web/contexts/login-gate.tsx +++ b/surfsense_web/contexts/login-gate.tsx @@ -44,8 +44,8 @@ export function LoginGateProvider({ children }: { children: ReactNode }) { <DialogHeader> <DialogTitle>Create a free account to {feature}</DialogTitle> <DialogDescription> - Get 3 million tokens, save chat history, upload documents, use all AI tools, and - connect 30+ integrations. + Get $5 of premium credit, save chat history, upload documents, use all AI tools, + and connect 30+ integrations. </DialogDescription> </DialogHeader> <DialogFooter className="flex flex-col gap-2 sm:flex-row"> diff --git a/surfsense_web/contracts/types/new-llm-config.types.ts b/surfsense_web/contracts/types/new-llm-config.types.ts index ecffc573e..2d6b70eda 100644 --- a/surfsense_web/contracts/types/new-llm-config.types.ts +++ b/surfsense_web/contracts/types/new-llm-config.types.ts @@ -258,6 +258,8 @@ export const globalImageGenConfig = z.object({ litellm_params: z.record(z.string(), z.any()).nullable().optional(), is_global: z.literal(true), is_auto_mode: z.boolean().optional().default(false), + billing_tier: z.string().default("free"), + quota_reserve_micros: z.number().nullable().optional(), }); export const getGlobalImageGenConfigsResponse = z.array(globalImageGenConfig); @@ -338,6 +340,10 @@ export const globalVisionLLMConfig = z.object({ litellm_params: z.record(z.string(), z.any()).nullable().optional(), is_global: z.literal(true), is_auto_mode: z.boolean().optional().default(false), + billing_tier: z.string().default("free"), + quota_reserve_tokens: z.number().nullable().optional(), + input_cost_per_token: z.number().nullable().optional(), + output_cost_per_token: z.number().nullable().optional(), }); export const getGlobalVisionLLMConfigsResponse = z.array(globalVisionLLMConfig); diff --git a/surfsense_web/contracts/types/stripe.types.ts b/surfsense_web/contracts/types/stripe.types.ts index c8b017044..251f7a176 100644 --- a/surfsense_web/contracts/types/stripe.types.ts +++ b/surfsense_web/contracts/types/stripe.types.ts @@ -32,7 +32,7 @@ export const getPagePurchasesResponse = z.object({ purchases: z.array(pagePurchase), }); -// Premium token purchases +// Premium credit purchases export const createTokenCheckoutSessionRequest = z.object({ quantity: z.number().int().min(1).max(100), search_space_id: z.number().int().min(1), @@ -42,11 +42,16 @@ export const createTokenCheckoutSessionResponse = z.object({ checkout_url: z.string(), }); +// Premium credit balance + purchase records. +// +// The unit is integer micro-USD (1_000_000 == $1.00). The schema names +// kept the ``Token`` prefix for API back-compat with pinned clients; +// the field names below are authoritative. export const tokenStripeStatusResponse = z.object({ token_buying_enabled: z.boolean(), - premium_tokens_used: z.number().default(0), - premium_tokens_limit: z.number().default(0), - premium_tokens_remaining: z.number().default(0), + premium_credit_micros_used: z.number().default(0), + premium_credit_micros_limit: z.number().default(0), + premium_credit_micros_remaining: z.number().default(0), }); export const tokenPurchaseStatusEnum = pagePurchaseStatusEnum; @@ -56,7 +61,7 @@ export const tokenPurchase = z.object({ stripe_checkout_session_id: z.string(), stripe_payment_intent_id: z.string().nullable(), quantity: z.number(), - tokens_granted: z.number(), + credit_micros_granted: z.number(), amount_total: z.number().nullable(), currency: z.string().nullable(), status: tokenPurchaseStatusEnum, diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index 95d9848f2..1c67d59a1 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -41,7 +41,7 @@ export interface RawChatErrorInput { } export const PREMIUM_QUOTA_ASSISTANT_MESSAGE = - "I can’t continue with the current premium model because your premium tokens are exhausted. Switch to a free model or buy more tokens to continue."; + "I can’t continue with the current premium model because your premium credit is exhausted. Switch to a free model or top up your credit to continue."; function getErrorMessage(error: unknown): string { if (error instanceof Error) return error.message; diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index 80e7bffbe..6df56f0ce 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -541,16 +541,23 @@ export type SSEEvent = data: { usage: Record< string, - { prompt_tokens: number; completion_tokens: number; total_tokens: number } + { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + } >; prompt_tokens: number; completion_tokens: number; total_tokens: number; + cost_micros?: number; call_details: Array<{ model: string; prompt_tokens: number; completion_tokens: number; total_tokens: number; + cost_micros?: number; }>; }; } diff --git a/surfsense_web/lib/chat/thread-persistence.ts b/surfsense_web/lib/chat/thread-persistence.ts index fc970c26e..7fec60a23 100644 --- a/surfsense_web/lib/chat/thread-persistence.ts +++ b/surfsense_web/lib/chat/thread-persistence.ts @@ -30,9 +30,20 @@ export interface TokenUsageSummary { prompt_tokens: number; completion_tokens: number; total_tokens: number; + /** + * Total provider USD cost for this assistant turn, in micro-USD + * (1_000_000 = $1.00). Optional because rows persisted before the + * cost-credits migration won't have it. + */ + cost_micros?: number; model_breakdown?: Record< string, - { prompt_tokens: number; completion_tokens: number; total_tokens: number } + { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + } > | null; } diff --git a/surfsense_web/zero/schema/user.ts b/surfsense_web/zero/schema/user.ts index 0e6234db5..f483fa9b4 100644 --- a/surfsense_web/zero/schema/user.ts +++ b/surfsense_web/zero/schema/user.ts @@ -1,11 +1,20 @@ import { number, string, table } from "@rocicorp/zero"; +/** + * Live-meter slice of the ``user`` table replicated through Zero. + * + * ``premiumCreditMicrosLimit`` / ``premiumCreditMicrosUsed`` are stored + * as integer micro-USD (1_000_000 == $1.00). UI consumers divide by 1M + * when displaying. Sensitive fields (email, hashed_password, oauth, etc.) + * are intentionally omitted via the Postgres column-list publication so + * they never enter WAL replication. + */ export const userTable = table("user") .columns({ id: string(), pagesLimit: number().from("pages_limit"), pagesUsed: number().from("pages_used"), - premiumTokensLimit: number().from("premium_tokens_limit"), - premiumTokensUsed: number().from("premium_tokens_used"), + premiumCreditMicrosLimit: number().from("premium_credit_micros_limit"), + premiumCreditMicrosUsed: number().from("premium_credit_micros_used"), }) .primaryKey("id"); From 47b2994ec76c88b45c1cae55116372be87368e9f Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Sat, 2 May 2026 19:18:53 -0700 Subject: [PATCH 294/299] feat: fixed vision/image provider specific errors and fixed podcast/video streaming --- .../app/agents/new_chat/llm_config.py | 132 +++-- .../agents/new_chat/tools/generate_image.py | 43 +- surfsense_backend/app/config/__init__.py | 26 + .../app/routes/image_generation_routes.py | 49 +- .../app/routes/new_llm_config_routes.py | 73 ++- .../app/routes/vision_llm_routes.py | 8 +- .../app/schemas/image_generation.py | 9 + .../app/schemas/new_llm_config.py | 36 ++ surfsense_backend/app/schemas/vision_llm.py | 9 + .../app/services/auto_model_pin_service.py | 87 ++- .../app/services/billable_calls.py | 258 ++++++-- .../app/services/image_gen_router_service.py | 21 +- .../app/services/llm_router_service.py | 2 - surfsense_backend/app/services/llm_service.py | 35 +- .../openrouter_integration_service.py | 38 +- .../app/services/provider_api_base.py | 1 - .../app/services/provider_capabilities.py | 280 +++++++++ .../app/tasks/celery_tasks/__init__.py | 100 +++- .../app/tasks/celery_tasks/connector_tasks.py | 197 ++----- .../celery_tasks/document_reindex_tasks.py | 12 +- .../app/tasks/celery_tasks/document_tasks.py | 188 ++---- .../app/tasks/celery_tasks/obsidian_tasks.py | 20 +- .../app/tasks/celery_tasks/podcast_tasks.py | 50 +- .../celery_tasks/schedule_checker_task.py | 12 +- .../stale_notification_cleanup_task.py | 14 +- .../stripe_reconciliation_task.py | 19 +- .../celery_tasks/video_presentation_tasks.py | 58 +- .../app/tasks/chat/stream_new_chat.py | 135 ++++- .../scripts/verify_chat_image_capability.py | 558 ++++++++++++++++++ .../routes/test_byok_supports_image_input.py | 110 ++++ .../routes/test_global_configs_is_premium.py | 184 ++++++ ...t_global_new_llm_configs_supports_image.py | 106 ++++ .../services/test_auto_pin_image_aware.py | 286 +++++++++ .../tests/unit/services/test_billable_call.py | 131 +++- .../test_image_gen_api_base_defense.py | 177 ++++++ .../test_openrouter_integration_service.py | 8 + .../unit/services/test_provider_api_base.py | 107 ++++ .../services/test_provider_capabilities.py | 244 ++++++++ .../services/test_supports_image_input.py | 281 +++++++++ .../test_vision_llm_api_base_defense.py | 89 +++ .../unit/tasks/test_celery_async_runner.py | 318 ++++++++++ .../tests/unit/tasks/test_podcast_billing.py | 67 ++- .../test_stream_new_chat_image_safety_net.py | 119 ++++ .../tasks/test_video_presentation_billing.py | 70 ++- .../assistant-ui/assistant-message.tsx | 4 +- .../components/new-chat/model-selector.tsx | 46 +- .../components/pricing/pricing-section.tsx | 4 +- .../settings/image-model-manager.tsx | 73 ++- .../settings/more-pages-content.tsx | 4 +- .../settings/vision-model-manager.tsx | 73 ++- .../components/tool-ui/generate-podcast.tsx | 21 +- surfsense_web/contexts/login-gate.tsx | 4 +- .../contracts/types/new-llm-config.types.ts | 30 + surfsense_web/next.config.ts | 6 + 54 files changed, 4469 insertions(+), 563 deletions(-) create mode 100644 surfsense_backend/app/services/provider_capabilities.py create mode 100644 surfsense_backend/scripts/verify_chat_image_capability.py create mode 100644 surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py create mode 100644 surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py create mode 100644 surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py create mode 100644 surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py create mode 100644 surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py create mode 100644 surfsense_backend/tests/unit/services/test_provider_api_base.py create mode 100644 surfsense_backend/tests/unit/services/test_provider_capabilities.py create mode 100644 surfsense_backend/tests/unit/services/test_supports_image_input.py create mode 100644 surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py create mode 100644 surfsense_backend/tests/unit/tasks/test_celery_async_runner.py create mode 100644 surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py diff --git a/surfsense_backend/app/agents/new_chat/llm_config.py b/surfsense_backend/app/agents/new_chat/llm_config.py index 99bb719f6..bc37bf1c4 100644 --- a/surfsense_backend/app/agents/new_chat/llm_config.py +++ b/surfsense_backend/app/agents/new_chat/llm_config.py @@ -90,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM): yield chunk -# Provider mapping for LiteLLM model string construction -PROVIDER_MAP = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "XAI": "xai", - "BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "GITHUB_MODELS": "github", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "COMETAPI": "cometapi", - "HUGGINGFACE": "huggingface", - "MINIMAX": "openai", - "CUSTOM": "custom", -} +# Provider mapping for LiteLLM model string construction. +# +# Single source of truth lives in +# :mod:`app.services.provider_capabilities` so the YAML loader (which +# runs during ``app.config`` class-body init) can resolve provider +# prefixes without dragging the agent / tools tree into module load +# order. Re-exported here under the historical ``PROVIDER_MAP`` name +# so existing callers (``llm_router_service``, ``image_gen_router_service``, +# tests) keep working unchanged. +from app.services.provider_capabilities import ( # noqa: E402 + _PROVIDER_PREFIX_MAP as PROVIDER_MAP, +) def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: @@ -178,6 +155,17 @@ class AgentConfig: anonymous_enabled: bool = False quota_reserve_tokens: int | None = None + # Capability flag: best-effort True for the chat selector / catalog. + # Resolved via :func:`provider_capabilities.derive_supports_image_input` + # which prefers OpenRouter's ``architecture.input_modalities`` and + # otherwise consults LiteLLM's authoritative model map. Default True + # is the conservative-allow stance — the streaming-task safety net + # (``is_known_text_only_chat_model``) is the *only* place a False + # actually blocks a request. Setting this to False here without an + # authoritative source would silently hide vision-capable models + # (the regression we're fixing). + supports_image_input: bool = True + @classmethod def from_auto_mode(cls) -> "AgentConfig": """ @@ -203,6 +191,12 @@ class AgentConfig: is_premium=False, anonymous_enabled=False, quota_reserve_tokens=None, + # Auto routes across the configured pool, which usually + # contains at least one vision-capable deployment; the router + # will surface a 404 from a non-vision deployment as a normal + # ``allowed_fails`` event and fail over rather than blocking + # the request outright. + supports_image_input=True, ) @classmethod @@ -216,10 +210,24 @@ class AgentConfig: Returns: AgentConfig instance """ - return cls( - provider=config.provider.value + # Lazy import to avoid pulling provider_capabilities (and its + # transitive litellm import) into module-init order. + from app.services.provider_capabilities import derive_supports_image_input + + provider_value = ( + config.provider.value if hasattr(config.provider, "value") - else str(config.provider), + else str(config.provider) + ) + litellm_params = config.litellm_params or {} + base_model = ( + litellm_params.get("base_model") + if isinstance(litellm_params, dict) + else None + ) + + return cls( + provider=provider_value, model_name=config.model_name, api_key=config.api_key, api_base=config.api_base, @@ -235,6 +243,16 @@ class AgentConfig: is_premium=False, anonymous_enabled=False, quota_reserve_tokens=None, + # BYOK rows have no operator-curated capability flag, so we + # ask LiteLLM (default-allow on unknown). The streaming + # safety net still blocks if the model is *explicitly* + # marked text-only. + supports_image_input=derive_supports_image_input( + provider=provider_value, + model_name=config.model_name, + base_model=base_model, + custom_provider=config.custom_provider, + ), ) @classmethod @@ -253,15 +271,46 @@ class AgentConfig: Returns: AgentConfig instance """ + # Lazy import to avoid pulling provider_capabilities (and its + # transitive litellm import) into module-init order. + from app.services.provider_capabilities import derive_supports_image_input + # Get system instructions from YAML, default to empty string system_instructions = yaml_config.get("system_instructions", "") + provider = yaml_config.get("provider", "").upper() + model_name = yaml_config.get("model_name", "") + custom_provider = yaml_config.get("custom_provider") + litellm_params = yaml_config.get("litellm_params") or {} + base_model = ( + litellm_params.get("base_model") + if isinstance(litellm_params, dict) + else None + ) + + # Explicit YAML override wins; otherwise derive from LiteLLM / + # OpenRouter modalities. The YAML loader already populates this + # field, but this method is also called from + # ``load_global_llm_config_by_id``'s file fallback (hot reload), + # so we re-derive here for safety. The bool() coercion preserves + # the loader's behaviour for explicit ``true`` / ``false`` + # strings that PyYAML may surface. + if "supports_image_input" in yaml_config: + supports_image_input = bool(yaml_config.get("supports_image_input")) + else: + supports_image_input = derive_supports_image_input( + provider=provider, + model_name=model_name, + base_model=base_model, + custom_provider=custom_provider, + ) + return cls( - provider=yaml_config.get("provider", "").upper(), - model_name=yaml_config.get("model_name", ""), + provider=provider, + model_name=model_name, api_key=yaml_config.get("api_key", ""), api_base=yaml_config.get("api_base"), - custom_provider=yaml_config.get("custom_provider"), + custom_provider=custom_provider, litellm_params=yaml_config.get("litellm_params"), # Prompt configuration from YAML (with defaults for backwards compatibility) system_instructions=system_instructions if system_instructions else None, @@ -276,6 +325,7 @@ class AgentConfig: is_premium=yaml_config.get("billing_tier", "free") == "premium", anonymous_enabled=yaml_config.get("anonymous_enabled", False), quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"), + supports_image_input=supports_image_input, ) diff --git a/surfsense_backend/app/agents/new_chat/tools/generate_image.py b/surfsense_backend/app/agents/new_chat/tools/generate_image.py index 3803fa39c..9e287ac51 100644 --- a/surfsense_backend/app/agents/new_chat/tools/generate_image.py +++ b/surfsense_backend/app/agents/new_chat/tools/generate_image.py @@ -31,6 +31,7 @@ from app.services.image_gen_router_service import ( ImageGenRouterService, is_image_gen_auto_mode, ) +from app.services.provider_api_base import resolve_api_base from app.utils.signed_image_urls import generate_image_token logger = logging.getLogger(__name__) @@ -49,12 +50,16 @@ _PROVIDER_MAP = { } +def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: + if custom_provider: + return custom_provider + return _PROVIDER_MAP.get(provider.upper(), provider.lower()) + + def _build_model_string( provider: str, model_name: str, custom_provider: str | None ) -> str: - if custom_provider: - return f"{custom_provider}/{model_name}" - prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower()) + prefix = _resolve_provider_prefix(provider, custom_provider) return f"{prefix}/{model_name}" @@ -146,14 +151,18 @@ def create_generate_image_tool( "error": f"Image generation config {config_id} not found" } - model_string = _build_model_string( - cfg.get("provider", ""), - cfg["model_name"], - cfg.get("custom_provider"), + provider_prefix = _resolve_provider_prefix( + cfg.get("provider", ""), cfg.get("custom_provider") ) + model_string = f"{provider_prefix}/{cfg['model_name']}" gen_kwargs["api_key"] = cfg.get("api_key") - if cfg.get("api_base"): - gen_kwargs["api_base"] = cfg["api_base"] + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=cfg.get("api_base"), + ) + if api_base: + gen_kwargs["api_base"] = api_base if cfg.get("api_version"): gen_kwargs["api_version"] = cfg["api_version"] if cfg.get("litellm_params"): @@ -175,14 +184,18 @@ def create_generate_image_tool( "error": f"Image generation config {config_id} not found" } - model_string = _build_model_string( - db_cfg.provider.value, - db_cfg.model_name, - db_cfg.custom_provider, + provider_prefix = _resolve_provider_prefix( + db_cfg.provider.value, db_cfg.custom_provider ) + model_string = f"{provider_prefix}/{db_cfg.model_name}" gen_kwargs["api_key"] = db_cfg.api_key - if db_cfg.api_base: - gen_kwargs["api_base"] = db_cfg.api_base + api_base = resolve_api_base( + provider=db_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=db_cfg.api_base, + ) + if api_base: + gen_kwargs["api_base"] = api_base if db_cfg.api_version: gen_kwargs["api_version"] = db_cfg.api_version if db_cfg.litellm_params: diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 2aeeafb34..97b4cf509 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -47,11 +47,37 @@ def load_global_llm_configs(): data = yaml.safe_load(f) configs = data.get("global_llm_configs", []) + # Lazy import keeps the `app.config` -> `app.services` edge one-way + # and matches the `provider_api_base` pattern used elsewhere. + from app.services.provider_capabilities import derive_supports_image_input + seen_slugs: dict[str, int] = {} for cfg in configs: cfg.setdefault("billing_tier", "free") cfg.setdefault("anonymous_enabled", False) cfg.setdefault("seo_enabled", False) + # Capability flag: explicit YAML override always wins. When the + # operator has not annotated the model, defer to LiteLLM's + # authoritative model map (`supports_vision`) which already + # knows GPT-5.x / GPT-4o / Claude 3.x / Gemini 2.x are + # vision-capable. Unknown / unmapped models default-allow so + # we don't lock the user out of a freshly added third-party + # entry; the streaming-task safety net (driven by + # `is_known_text_only_chat_model`) is the only place a False + # actually blocks a request. + if "supports_image_input" not in cfg: + litellm_params = cfg.get("litellm_params") or {} + base_model = ( + litellm_params.get("base_model") + if isinstance(litellm_params, dict) + else None + ) + cfg["supports_image_input"] = derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) if cfg.get("seo_enabled") and cfg.get("seo_slug"): slug = cfg["seo_slug"] diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 34ed80207..018234ad5 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -46,6 +46,7 @@ from app.services.image_gen_router_service import ( ImageGenRouterService, is_image_gen_auto_mode, ) +from app.services.provider_api_base import resolve_api_base from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -87,14 +88,18 @@ def _get_global_image_gen_config(config_id: int) -> dict | None: return None +def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: + """Resolve the LiteLLM provider prefix used in model strings.""" + if custom_provider: + return custom_provider + return _PROVIDER_MAP.get(provider.upper(), provider.lower()) + + def _build_model_string( provider: str, model_name: str, custom_provider: str | None ) -> str: """Build a litellm model string from provider + model_name.""" - if custom_provider: - return f"{custom_provider}/{model_name}" - prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower()) - return f"{prefix}/{model_name}" + return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}" async def _resolve_billing_for_image_gen( @@ -187,12 +192,18 @@ async def _execute_image_generation( if not cfg: raise ValueError(f"Global image generation config {config_id} not found") - model_string = _build_model_string( - cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider") + provider_prefix = _resolve_provider_prefix( + cfg.get("provider", ""), cfg.get("custom_provider") ) + model_string = f"{provider_prefix}/{cfg['model_name']}" gen_kwargs["api_key"] = cfg.get("api_key") - if cfg.get("api_base"): - gen_kwargs["api_base"] = cfg["api_base"] + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=cfg.get("api_base"), + ) + if api_base: + gen_kwargs["api_base"] = api_base if cfg.get("api_version"): gen_kwargs["api_version"] = cfg["api_version"] if cfg.get("litellm_params"): @@ -214,12 +225,18 @@ async def _execute_image_generation( if not db_cfg: raise ValueError(f"Image generation config {config_id} not found") - model_string = _build_model_string( - db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider + provider_prefix = _resolve_provider_prefix( + db_cfg.provider.value, db_cfg.custom_provider ) + model_string = f"{provider_prefix}/{db_cfg.model_name}" gen_kwargs["api_key"] = db_cfg.api_key - if db_cfg.api_base: - gen_kwargs["api_base"] = db_cfg.api_base + api_base = resolve_api_base( + provider=db_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=db_cfg.api_base, + ) + if api_base: + gen_kwargs["api_base"] = api_base if db_cfg.api_version: gen_kwargs["api_version"] = db_cfg.api_version if db_cfg.litellm_params: @@ -277,10 +294,12 @@ async def get_global_image_gen_configs( # Auto mode currently treated as free until per-deployment # billing-tier surfacing lands (see _resolve_billing_for_image_gen). "billing_tier": "free", + "is_premium": False, } ) for cfg in global_configs: + billing_tier = str(cfg.get("billing_tier", "free")).lower() safe_configs.append( { "id": cfg.get("id"), @@ -293,7 +312,11 @@ async def get_global_image_gen_configs( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, - "billing_tier": cfg.get("billing_tier", "free"), + "billing_tier": billing_tier, + # Mirror chat (``new_llm_config_routes``) so the new-chat + # selector's premium badge logic keys off the same + # field across chat / image / vision tabs. + "is_premium": billing_tier == "premium", "quota_reserve_micros": cfg.get("quota_reserve_micros"), } ) diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py index 20779a309..e090a1a7c 100644 --- a/surfsense_backend/app/routes/new_llm_config_routes.py +++ b/surfsense_backend/app/routes/new_llm_config_routes.py @@ -29,6 +29,7 @@ from app.schemas import ( NewLLMConfigUpdate, ) from app.services.llm_service import validate_llm_config +from app.services.provider_capabilities import derive_supports_image_input from app.users import current_active_user from app.utils.rbac import check_permission @@ -36,6 +37,39 @@ router = APIRouter() logger = logging.getLogger(__name__) +def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: + """Augment a BYOK chat config row with the derived ``supports_image_input``. + + There is no DB column for ``supports_image_input`` — the value is + resolved at the API boundary from LiteLLM's authoritative model map + (default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps + the response shape consistent across list / detail / create / update + endpoints without having to remember to set the field at every call + site. + """ + provider_value = ( + config.provider.value + if hasattr(config.provider, "value") + else str(config.provider) + ) + litellm_params = config.litellm_params or {} + base_model = ( + litellm_params.get("base_model") if isinstance(litellm_params, dict) else None + ) + supports_image_input = derive_supports_image_input( + provider=provider_value, + model_name=config.model_name, + base_model=base_model, + custom_provider=config.custom_provider, + ) + # ``model_validate`` runs the Pydantic conversion using the ORM + # attribute access path enabled by ``ConfigDict(from_attributes=True)``, + # then we layer the derived field on. ``model_copy(update=...)`` keeps + # the surface immutable from the caller's perspective. + base_read = NewLLMConfigRead.model_validate(config) + return base_read.model_copy(update={"supports_image_input": supports_image_input}) + + # ============================================================================= # Global Configs Routes # ============================================================================= @@ -84,11 +118,41 @@ async def get_global_new_llm_configs( "seo_title": None, "seo_description": None, "quota_reserve_tokens": None, + # Auto routes across the configured pool, which usually + # includes at least one vision-capable deployment, so + # treat Auto as image-capable. The router itself will + # still pick a vision-capable deployment for messages + # carrying image_url blocks (LiteLLM Router falls back + # on ``404`` per its ``allowed_fails`` policy). + "supports_image_input": True, } ) # Add individual global configs for cfg in global_configs: + # Capability resolution: explicit value (YAML override or OR + # `_supports_image_input(model)` payload baked in by the + # OpenRouter integration service) wins. Fall back to the + # LiteLLM-driven helper which default-allows on unknown so + # we don't hide vision-capable models that happen to lack a + # YAML annotation. The streaming task safety net is the + # only place a False ever blocks. + if "supports_image_input" in cfg: + supports_image_input = bool(cfg.get("supports_image_input")) + else: + cfg_litellm_params = cfg.get("litellm_params") or {} + cfg_base_model = ( + cfg_litellm_params.get("base_model") + if isinstance(cfg_litellm_params, dict) + else None + ) + supports_image_input = derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=cfg_base_model, + custom_provider=cfg.get("custom_provider"), + ) + safe_config = { "id": cfg.get("id"), "name": cfg.get("name"), @@ -113,6 +177,7 @@ async def get_global_new_llm_configs( "seo_title": cfg.get("seo_title"), "seo_description": cfg.get("seo_description"), "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), + "supports_image_input": supports_image_input, } safe_configs.append(safe_config) @@ -171,7 +236,7 @@ async def create_new_llm_config( await session.commit() await session.refresh(db_config) - return db_config + return _serialize_byok_config(db_config) except HTTPException: raise @@ -213,7 +278,7 @@ async def list_new_llm_configs( .limit(limit) ) - return result.scalars().all() + return [_serialize_byok_config(cfg) for cfg in result.scalars().all()] except HTTPException: raise @@ -268,7 +333,7 @@ async def get_new_llm_config( "You don't have permission to view LLM configurations in this search space", ) - return config + return _serialize_byok_config(config) except HTTPException: raise @@ -360,7 +425,7 @@ async def update_new_llm_config( await session.commit() await session.refresh(config) - return config + return _serialize_byok_config(config) except HTTPException: raise diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py index 4f7e9f725..e4f08f604 100644 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -85,10 +85,12 @@ async def get_global_vision_llm_configs( # Auto mode treated as free until per-deployment billing-tier # surfacing lands; see ``get_vision_llm`` for parity. "billing_tier": "free", + "is_premium": False, } ) for cfg in global_configs: + billing_tier = str(cfg.get("billing_tier", "free")).lower() safe_configs.append( { "id": cfg.get("id"), @@ -101,7 +103,11 @@ async def get_global_vision_llm_configs( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, - "billing_tier": cfg.get("billing_tier", "free"), + "billing_tier": billing_tier, + # Mirror chat (``new_llm_config_routes``) so the new-chat + # selector's premium badge logic keys off the same + # field across chat / image / vision tabs. + "is_premium": billing_tier == "premium", "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), "input_cost_per_token": cfg.get("input_cost_per_token"), "output_cost_per_token": cfg.get("output_cost_per_token"), diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py index facca7b86..4262b2b3f 100644 --- a/surfsense_backend/app/schemas/image_generation.py +++ b/surfsense_backend/app/schemas/image_generation.py @@ -241,6 +241,15 @@ class GlobalImageGenConfigRead(BaseModel): default="free", description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", ) + is_premium: bool = Field( + default=False, + description=( + "Convenience boolean derived server-side from " + "``billing_tier == 'premium'``. The new-chat model selector " + "keys its Free/Premium badge off this field for parity with " + "chat (`GlobalLLMConfigRead.is_premium`)." + ), + ) quota_reserve_micros: int | None = Field( default=None, description=( diff --git a/surfsense_backend/app/schemas/new_llm_config.py b/surfsense_backend/app/schemas/new_llm_config.py index 9cc1fce58..e64478d38 100644 --- a/surfsense_backend/app/schemas/new_llm_config.py +++ b/surfsense_backend/app/schemas/new_llm_config.py @@ -92,6 +92,20 @@ class NewLLMConfigRead(NewLLMConfigBase): created_at: datetime search_space_id: int user_id: uuid.UUID + # Capability flag derived at the API boundary (no DB column). Default + # True matches the conservative-allow stance — a BYOK row that the + # route forgot to augment is not pre-judged. The streaming-task + # safety net is the only place a False actually blocks a request. + supports_image_input: bool = Field( + default=True, + description=( + "Whether the BYOK chat config can accept image inputs. Derived " + "at the route boundary from LiteLLM's authoritative model map " + "(``litellm.supports_vision``) — there is no DB column. " + "Default True is the conservative-allow stance for unknown / " + "unmapped models." + ), + ) model_config = ConfigDict(from_attributes=True) @@ -121,6 +135,15 @@ class NewLLMConfigPublic(BaseModel): created_at: datetime search_space_id: int user_id: uuid.UUID + # Capability flag derived at the API boundary (see NewLLMConfigRead). + supports_image_input: bool = Field( + default=True, + description=( + "Whether the BYOK chat config can accept image inputs. Derived " + "at the route boundary from LiteLLM's authoritative model map. " + "Default True is the conservative-allow stance." + ), + ) model_config = ConfigDict(from_attributes=True) @@ -172,6 +195,19 @@ class GlobalNewLLMConfigRead(BaseModel): seo_title: str | None = None seo_description: str | None = None quota_reserve_tokens: int | None = None + supports_image_input: bool = Field( + default=True, + description=( + "Whether the model accepts image inputs (multimodal vision). " + "Derived server-side: OpenRouter dynamic configs use " + "``architecture.input_modalities``; YAML / BYOK use LiteLLM's " + "authoritative model map (``litellm.supports_vision``). The " + "new-chat selector hints with a 'No image' badge when this is " + "False and there are pending image attachments. The streaming " + "task fails fast only when LiteLLM *explicitly* marks a model " + "as text-only — unknown / unmapped models default-allow." + ), + ) # ============================================================================= diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py index e55333a9d..d0eeaf5c6 100644 --- a/surfsense_backend/app/schemas/vision_llm.py +++ b/surfsense_backend/app/schemas/vision_llm.py @@ -86,6 +86,15 @@ class GlobalVisionLLMConfigRead(BaseModel): default="free", description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", ) + is_premium: bool = Field( + default=False, + description=( + "Convenience boolean derived server-side from " + "``billing_tier == 'premium'``. The new-chat model selector " + "keys its Free/Premium badge off this field for parity with " + "chat (`GlobalLLMConfigRead.is_premium`)." + ), + ) quota_reserve_tokens: int | None = Field( default=None, description=( diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 3a2c681b7..4f045ba02 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -163,13 +163,47 @@ def clear_healthy(config_id: int | None = None) -> None: _healthy_until.pop(int(config_id), None) -def _global_candidates() -> list[dict]: +def _cfg_supports_image_input(cfg: dict) -> bool: + """True if the global cfg can accept image inputs. + + Prefers the explicit ``supports_image_input`` flag (set by the YAML + loader / OpenRouter integration). Falls back to a LiteLLM lookup so + a YAML entry whose flag was somehow stripped doesn't get wrongly + excluded. Default-allows on unknown — the streaming-task safety net + is the actual block, not this filter. + """ + if "supports_image_input" in cfg: + return bool(cfg.get("supports_image_input")) + # Lazy import: provider_capabilities -> llm_config -> services chain; + # importing at module load would create an init-order cycle through + # ``app.config``. + from app.services.provider_capabilities import derive_supports_image_input + + cfg_litellm_params = cfg.get("litellm_params") or {} + base_model = ( + cfg_litellm_params.get("base_model") + if isinstance(cfg_litellm_params, dict) + else None + ) + return derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) + + +def _global_candidates(*, requires_image_input: bool = False) -> list[dict]: """Return Auto-eligible global cfgs. Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers can't be picked as the thread's pin. Also excludes configs currently in runtime cooldown (e.g. temporary 429 bursts). + + When ``requires_image_input`` is True (image turn), additionally + filters out configs whose ``supports_image_input`` resolves to False + so a text-only deployment can't be pinned for an image request. """ candidates = [ cfg @@ -177,6 +211,7 @@ def _global_candidates() -> list[dict]: if _is_usable_global_config(cfg) and not cfg.get("health_gated") and not _is_runtime_cooled_down(int(cfg.get("id", 0))) + and (not requires_image_input or _cfg_supports_image_input(cfg)) ] return sorted(candidates, key=lambda c: int(c.get("id", 0))) @@ -237,11 +272,20 @@ async def resolve_or_get_pinned_llm_config_id( selected_llm_config_id: int, force_repin_free: bool = False, exclude_config_ids: set[int] | None = None, + requires_image_input: bool = False, ) -> AutoPinResolution: """Resolve Auto (Fastest) to one concrete config id and persist the pin. For non-auto selections, this function clears any existing pin and returns the selected id as-is. + + When ``requires_image_input`` is True (the current turn carries an + ``image_url`` block), the candidate pool is filtered to vision-capable + cfgs and any existing pin that can't accept image input is treated as + invalid (force re-pin). If no vision-capable cfg is available the + function raises ``ValueError`` so the streaming task surfaces the same + friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` error instead of + silently routing the image to a text-only deployment. """ thread = ( ( @@ -274,14 +318,24 @@ async def resolve_or_get_pinned_llm_config_id( excluded_ids = {int(cid) for cid in (exclude_config_ids or set())} candidates = [ - c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids + c + for c in _global_candidates(requires_image_input=requires_image_input) + if int(c.get("id", 0)) not in excluded_ids ] if not candidates: + if requires_image_input: + # Distinguish the "no vision-capable cfg" case from generic + # "no usable cfg" so the streaming task can map this to the + # MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error. + raise ValueError( + "No vision-capable global LLM configs are available for Auto mode" + ) raise ValueError("No usable global LLM configs are available for Auto mode") candidate_by_id = {int(c["id"]): c for c in candidates} # Reuse an existing valid pin without re-checking current quota (no silent - # tier switch), unless the caller explicitly requests a forced repin to free. + # tier switch), unless the caller explicitly requests a forced repin to free + # *or* the turn requires image input but the pin can't handle it. pinned_id = thread.pinned_llm_config_id if ( not force_repin_free @@ -311,6 +365,29 @@ async def resolve_or_get_pinned_llm_config_id( from_existing_pin=True, ) if pinned_id is not None: + # If the pin is *only* invalid because it can't handle the image + # turn (it's still a healthy, usable config in the broader pool), + # log that explicitly so operators can correlate the re-pin with + # the user's image attachment instead of suspecting a cooldown. + if requires_image_input: + try: + pinned_global = next( + c + for c in config.GLOBAL_LLM_CONFIGS + if int(c.get("id", 0)) == int(pinned_id) + ) + except StopIteration: + pinned_global = None + if pinned_global is not None and not _cfg_supports_image_input( + pinned_global + ): + logger.info( + "auto_pin_repinned_for_image thread_id=%s search_space_id=%s " + "previous_config_id=%s", + thread_id, + search_space_id, + pinned_id, + ) logger.info( "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s", thread_id, @@ -327,6 +404,10 @@ async def resolve_or_get_pinned_llm_config_id( eligible = [c for c in candidates if _tier_of(c) != "premium"] if not eligible: + if requires_image_input: + raise ValueError( + "Auto mode could not find a vision-capable LLM config for this user and quota state" + ) raise ValueError( "Auto mode could not find an eligible LLM config for this user and quota state" ) diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py index f5ca9818e..92ccd6a78 100644 --- a/surfsense_backend/app/services/billable_calls.py +++ b/surfsense_backend/app/services/billable_calls.py @@ -10,12 +10,14 @@ vision-LLM wrapper used during indexing) don't have to re-implement it. KEY DESIGN POINTS (issue A, B): -1. **Session isolation.** ``billable_call`` takes *no* ``db_session`` - argument. All ``TokenQuotaService.premium_*`` calls and the audit-row - insert each run inside their own ``shielded_async_session()``. This - guarantees that a quota commit/rollback can never accidentally flush or - roll back rows the caller has staged in the request's main session - (e.g. a freshly-created ``ImageGeneration`` row). +1. **Session isolation.** ``billable_call`` takes no caller transaction. + All ``TokenQuotaService.premium_*`` calls and the audit-row insert run + inside their own session context. Route callers use + ``shielded_async_session()`` by default; Celery callers can provide a + worker-loop-safe session factory. This guarantees that quota + commit/rollback can never accidentally flush or roll back rows the caller + has staged in its main session (e.g. a freshly-created + ``ImageGeneration`` row). 2. **ContextVar safety.** The accumulator is scoped via :func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a @@ -36,9 +38,10 @@ KEY DESIGN POINTS (issue A, B): from __future__ import annotations +import asyncio import logging -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Callable +from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress from typing import Any from uuid import UUID, uuid4 @@ -58,6 +61,12 @@ from app.services.token_tracking_service import ( logger = logging.getLogger(__name__) +AUDIT_TIMEOUT_SECONDS = 10.0 +BACKGROUND_ARTIFACT_USAGE_TYPES = frozenset( + {"video_presentation_generation", "podcast_generation"} +) +BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]] + class QuotaInsufficientError(Exception): """Raised when ``TokenQuotaService.premium_reserve`` denies a billable @@ -88,6 +97,124 @@ class QuotaInsufficientError(Exception): ) +class BillingSettlementError(Exception): + """Raised when a premium call completed but credit settlement failed.""" + + def __init__(self, *, usage_type: str, user_id: UUID, cause: Exception) -> None: + self.usage_type = usage_type + self.user_id = user_id + super().__init__( + f"Failed to settle premium credit for {usage_type} user={user_id}: {cause}" + ) + + +async def _rollback_safely(session: AsyncSession) -> None: + rollback = getattr(session, "rollback", None) + if rollback is not None: + with suppress(Exception): + await rollback() + + +async def _record_audit_best_effort( + *, + session_factory: BillableSessionFactory, + usage_type: str, + search_space_id: int, + user_id: UUID, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + cost_micros: int, + model_breakdown: dict[str, Any], + call_details: dict[str, Any] | None, + thread_id: int | None, + message_id: int | None, + audit_label: str, + timeout_seconds: float = AUDIT_TIMEOUT_SECONDS, +) -> None: + """Persist a TokenUsage row without letting audit failure block callers. + + Premium settlement is mandatory, but TokenUsage is an audit trail. If the + audit insert or commit hangs, user-facing artifacts such as videos and + podcasts must still be able to transition to READY after settlement. + """ + audit_thread_id = ( + None if usage_type in BACKGROUND_ARTIFACT_USAGE_TYPES else thread_id + ) + + async def _persist() -> None: + logger.info( + "[billable_call] audit start label=%s usage_type=%s user=%s thread=%s " + "total_tokens=%d cost_micros=%d", + audit_label, + usage_type, + user_id, + audit_thread_id, + total_tokens, + cost_micros, + ) + async with session_factory() as audit_session: + try: + await record_token_usage( + audit_session, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + cost_micros=cost_micros, + model_breakdown=model_breakdown, + call_details=call_details, + thread_id=audit_thread_id, + message_id=message_id, + ) + logger.info( + "[billable_call] audit row staged label=%s usage_type=%s user=%s thread=%s", + audit_label, + usage_type, + user_id, + audit_thread_id, + ) + await audit_session.commit() + logger.info( + "[billable_call] audit commit OK label=%s usage_type=%s user=%s thread=%s", + audit_label, + usage_type, + user_id, + audit_thread_id, + ) + except BaseException: + await _rollback_safely(audit_session) + raise + + try: + await asyncio.wait_for(_persist(), timeout=timeout_seconds) + except TimeoutError: + logger.warning( + "[billable_call] audit timed out label=%s usage_type=%s user=%s thread=%s " + "timeout=%.1fs total_tokens=%d cost_micros=%d", + audit_label, + usage_type, + user_id, + audit_thread_id, + timeout_seconds, + total_tokens, + cost_micros, + ) + except Exception: + logger.exception( + "[billable_call] audit failed label=%s usage_type=%s user=%s thread=%s " + "total_tokens=%d cost_micros=%d", + audit_label, + usage_type, + user_id, + audit_thread_id, + total_tokens, + cost_micros, + ) + + @asynccontextmanager async def billable_call( *, @@ -101,6 +228,8 @@ async def billable_call( thread_id: int | None = None, message_id: int | None = None, call_details: dict[str, Any] | None = None, + billable_session_factory: BillableSessionFactory | None = None, + audit_timeout_seconds: float = AUDIT_TIMEOUT_SECONDS, ) -> AsyncIterator[TurnTokenAccumulator]: """Wrap a single billable LLM/image call. @@ -124,6 +253,13 @@ async def billable_call( thread_id, message_id: Optional FK columns on ``TokenUsage``. call_details: Optional per-call metadata (model name, parameters) forwarded to ``record_token_usage``. + billable_session_factory: Optional async context factory used for + reserve/finalize/release/audit sessions. Defaults to + ``shielded_async_session`` for route callers; Celery callers pass + a worker-loop-safe session factory. + audit_timeout_seconds: Upper bound for TokenUsage audit persistence. + Audit failure is best-effort and does not undo successful + settlement. Yields: The ``TurnTokenAccumulator`` scoped to this call. The caller invokes @@ -134,6 +270,7 @@ async def billable_call( QuotaInsufficientError: when premium and ``premium_reserve`` denies. """ is_premium = billing_tier == "premium" + session_factory = billable_session_factory or shielded_async_session async with scoped_turn() as acc: # ---------- Free path: just audit ------------------------------- @@ -143,30 +280,22 @@ async def billable_call( finally: # Always audit, even on exception, so we capture cost when # provider returns successfully but the caller raises later. - try: - async with shielded_async_session() as audit_session: - await record_token_usage( - audit_session, - usage_type=usage_type, - search_space_id=search_space_id, - user_id=user_id, - prompt_tokens=acc.total_prompt_tokens, - completion_tokens=acc.total_completion_tokens, - total_tokens=acc.grand_total, - cost_micros=acc.total_cost_micros, - model_breakdown=acc.per_message_summary(), - call_details=call_details, - thread_id=thread_id, - message_id=message_id, - ) - await audit_session.commit() - except Exception: - logger.exception( - "[billable_call] free-path audit insert failed for " - "usage_type=%s user_id=%s", - usage_type, - user_id, - ) + await _record_audit_best_effort( + session_factory=session_factory, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=acc.total_prompt_tokens, + completion_tokens=acc.total_completion_tokens, + total_tokens=acc.grand_total, + cost_micros=acc.total_cost_micros, + model_breakdown=acc.per_message_summary(), + call_details=call_details, + thread_id=thread_id, + message_id=message_id, + audit_label="free", + timeout_seconds=audit_timeout_seconds, + ) return # ---------- Premium path: reserve → execute → finalize ---------- @@ -180,7 +309,7 @@ async def billable_call( request_id = str(uuid4()) - async with shielded_async_session() as quota_session: + async with session_factory() as quota_session: reserve_result = await TokenQuotaService.premium_reserve( db_session=quota_session, user_id=user_id, @@ -222,7 +351,7 @@ async def billable_call( # from a downstream call, asyncio cancellation, etc.). We use # BaseException so cancellation also releases. try: - async with shielded_async_session() as quota_session: + async with session_factory() as quota_session: await TokenQuotaService.premium_release( db_session=quota_session, user_id=user_id, @@ -241,7 +370,16 @@ async def billable_call( # ---------- Success: finalize + audit ---------------------------- actual_micros = acc.total_cost_micros try: - async with shielded_async_session() as quota_session: + logger.info( + "[billable_call] finalize start user=%s usage_type=%s actual=%d " + "reserved=%d thread=%s", + user_id, + usage_type, + actual_micros, + reserve_micros, + thread_id, + ) + async with session_factory() as quota_session: final_result = await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=user_id, @@ -260,7 +398,7 @@ async def billable_call( final_result.limit, final_result.remaining, ) - except Exception: + except Exception as finalize_exc: # Last-ditch: if finalize itself fails, we must at least release # so the reservation doesn't leak. logger.exception( @@ -269,7 +407,7 @@ async def billable_call( user_id, ) try: - async with shielded_async_session() as quota_session: + async with session_factory() as quota_session: await TokenQuotaService.premium_release( db_session=quota_session, user_id=user_id, @@ -281,31 +419,28 @@ async def billable_call( "for user=%s", user_id, ) + raise BillingSettlementError( + usage_type=usage_type, + user_id=user_id, + cause=finalize_exc, + ) from finalize_exc - try: - async with shielded_async_session() as audit_session: - await record_token_usage( - audit_session, - usage_type=usage_type, - search_space_id=search_space_id, - user_id=user_id, - prompt_tokens=acc.total_prompt_tokens, - completion_tokens=acc.total_completion_tokens, - total_tokens=acc.grand_total, - cost_micros=actual_micros, - model_breakdown=acc.per_message_summary(), - call_details=call_details, - thread_id=thread_id, - message_id=message_id, - ) - await audit_session.commit() - except Exception: - logger.exception( - "[billable_call] premium-path audit insert failed for " - "usage_type=%s user_id=%s (debit was applied)", - usage_type, - user_id, - ) + await _record_audit_best_effort( + session_factory=session_factory, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=acc.total_prompt_tokens, + completion_tokens=acc.total_completion_tokens, + total_tokens=acc.grand_total, + cost_micros=actual_micros, + model_breakdown=acc.per_message_summary(), + call_details=call_details, + thread_id=thread_id, + message_id=message_id, + audit_label="premium", + timeout_seconds=audit_timeout_seconds, + ) async def _resolve_agent_billing_for_search_space( @@ -419,6 +554,7 @@ async def _resolve_agent_billing_for_search_space( __all__ = [ + "BillingSettlementError", "QuotaInsufficientError", "_resolve_agent_billing_for_search_space", "billable_call", diff --git a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py index f45a6ab63..b4de2a0bf 100644 --- a/surfsense_backend/app/services/image_gen_router_service.py +++ b/surfsense_backend/app/services/image_gen_router_service.py @@ -20,6 +20,8 @@ from typing import Any from litellm import Router from litellm.utils import ImageResponse +from app.services.provider_api_base import resolve_api_base + logger = logging.getLogger(__name__) # Special ID for Auto mode - uses router for load balancing @@ -152,12 +154,12 @@ class ImageGenRouterService: return None # Build model string + provider = config.get("provider", "").upper() if config.get("custom_provider"): - model_string = f"{config['custom_provider']}/{config['model_name']}" + provider_prefix = config["custom_provider"] else: - provider = config.get("provider", "").upper() provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{config['model_name']}" + model_string = f"{provider_prefix}/{config['model_name']}" # Build litellm params litellm_params: dict[str, Any] = { @@ -165,9 +167,16 @@ class ImageGenRouterService: "api_key": config.get("api_key"), } - # Add optional api_base - if config.get("api_base"): - litellm_params["api_base"] = config["api_base"] + # Resolve ``api_base`` so deployments don't silently inherit + # ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against + # the wrong provider (see ``provider_api_base`` docstring). + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), + ) + if api_base: + litellm_params["api_base"] = api_base # Add api_version (required for Azure) if config.get("api_version"): diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 1e9d235c8..d220aa346 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -140,8 +140,6 @@ PROVIDER_MAP = { # 404-ing against an inherited Azure endpoint). Re-exported here for # backward compatibility with any external import. from app.services.provider_api_base import ( # noqa: E402 - PROVIDER_DEFAULT_API_BASE, - PROVIDER_KEY_DEFAULT_API_BASE, resolve_api_base, ) diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 72c10035d..ade202c72 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -16,6 +16,7 @@ from app.services.llm_router_service import ( get_auto_mode_llm, is_auto_mode, ) +from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import token_tracker # Configure litellm to automatically drop unsupported parameters @@ -556,22 +557,26 @@ async def get_vision_llm( return None if global_cfg.get("custom_provider"): - model_string = ( - f"{global_cfg['custom_provider']}/{global_cfg['model_name']}" - ) + provider_prefix = global_cfg["custom_provider"] + model_string = f"{provider_prefix}/{global_cfg['model_name']}" else: - prefix = VISION_PROVIDER_MAP.get( + provider_prefix = VISION_PROVIDER_MAP.get( global_cfg["provider"].upper(), global_cfg["provider"].lower(), ) - model_string = f"{prefix}/{global_cfg['model_name']}" + model_string = f"{provider_prefix}/{global_cfg['model_name']}" litellm_kwargs = { "model": model_string, "api_key": global_cfg["api_key"], } - if global_cfg.get("api_base"): - litellm_kwargs["api_base"] = global_cfg["api_base"] + api_base = resolve_api_base( + provider=global_cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=global_cfg.get("api_base"), + ) + if api_base: + litellm_kwargs["api_base"] = api_base if global_cfg.get("litellm_params"): litellm_kwargs.update(global_cfg["litellm_params"]) @@ -606,20 +611,26 @@ async def get_vision_llm( return None if vision_cfg.custom_provider: - model_string = f"{vision_cfg.custom_provider}/{vision_cfg.model_name}" + provider_prefix = vision_cfg.custom_provider + model_string = f"{provider_prefix}/{vision_cfg.model_name}" else: - prefix = VISION_PROVIDER_MAP.get( + provider_prefix = VISION_PROVIDER_MAP.get( vision_cfg.provider.value.upper(), vision_cfg.provider.value.lower(), ) - model_string = f"{prefix}/{vision_cfg.model_name}" + model_string = f"{provider_prefix}/{vision_cfg.model_name}" litellm_kwargs = { "model": model_string, "api_key": vision_cfg.api_key, } - if vision_cfg.api_base: - litellm_kwargs["api_base"] = vision_cfg.api_base + api_base = resolve_api_base( + provider=vision_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=vision_cfg.api_base, + ) + if api_base: + litellm_kwargs["api_base"] = api_base if vision_cfg.litellm_params: litellm_kwargs.update(vision_cfg.litellm_params) diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 0d030f04f..6454e2d58 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -122,6 +122,24 @@ def _is_vision_input_model(model: dict) -> bool: return "image" in input_mods and "text" in output_mods +def _supports_image_input(model: dict) -> bool: + """Return True if the model accepts ``image`` in its input modalities. + + Differs from :func:`_is_vision_input_model` in that it does NOT + require text output — chat-tab models always emit text already (the + chat catalog filters by ``_is_text_output_model``), so the only + extra capability we need to track per chat config is whether the + model can ingest user-attached images. The chat selector and the + streaming task both key off this flag to prevent hitting an + OpenRouter 404 ``"No endpoints found that support image input"`` + when the user uploads an image and selects a text-only model + (DeepSeek V3, Llama 3.x base, etc.). + """ + arch = model.get("architecture", {}) or {} + input_mods = arch.get("input_modalities", []) or [] + return "image" in input_mods + + def _supports_tool_calling(model: dict) -> bool: """Return True if the model supports function/tool calling.""" supported = model.get("supported_parameters") or [] @@ -321,6 +339,13 @@ def _generate_configs( # account-wide quota, so per-deployment routing can't spread load # there — it just drains the shared bucket faster. "router_pool_eligible": tier == "premium", + # Capability flag derived from ``architecture.input_modalities``. + # Read by the new-chat selector to dim image-incompatible models + # when the user has pending image attachments, and by + # ``stream_new_chat`` as a fail-fast safety net before the + # OpenRouter request would otherwise 404 with + # ``"No endpoints found that support image input"``. + "supports_image_input": _supports_image_input(model), _OPENROUTER_DYNAMIC_MARKER: True, # Auto (Fastest) ranking metadata. ``quality_score`` is initialised # to the static score and gets re-blended with health on the next @@ -398,7 +423,12 @@ def _generate_image_gen_configs( "provider": "OPENROUTER", "model_name": model_id, "api_key": api_key, - "api_base": "", + # Pin to OpenRouter's public base URL so a downstream call site + # that forgets ``resolve_api_base`` still doesn't inherit + # ``AZURE_OPENAI_ENDPOINT`` and 404 on + # ``image_generation/transformation`` (defense-in-depth, see + # ``provider_api_base`` docstring). + "api_base": "https://openrouter.ai/api/v1", "api_version": None, "rpm": free_rpm if tier == "free" else rpm, "litellm_params": dict(litellm_params), @@ -477,7 +507,11 @@ def _generate_vision_llm_configs( "provider": "OPENROUTER", "model_name": model_id, "api_key": api_key, - "api_base": "", + # Pin to OpenRouter's public base URL so a downstream call site + # that forgets ``resolve_api_base`` still doesn't inherit + # ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see + # ``provider_api_base`` docstring). + "api_base": "https://openrouter.ai/api/v1", "api_version": None, "rpm": free_rpm if tier == "free" else rpm, "tpm": free_tpm if tier == "free" else tpm, diff --git a/surfsense_backend/app/services/provider_api_base.py b/surfsense_backend/app/services/provider_api_base.py index 979d7d3a1..dca1f9462 100644 --- a/surfsense_backend/app/services/provider_api_base.py +++ b/surfsense_backend/app/services/provider_api_base.py @@ -17,7 +17,6 @@ source of truth without an inter-service circular import. from __future__ import annotations - PROVIDER_DEFAULT_API_BASE: dict[str, str] = { "openrouter": "https://openrouter.ai/api/v1", "groq": "https://api.groq.com/openai/v1", diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py new file mode 100644 index 000000000..e9a1c33e1 --- /dev/null +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -0,0 +1,280 @@ +"""Capability resolution shared by chat / image / vision call sites. + +Why this exists +--------------- +The chat catalog (YAML + dynamic OpenRouter + BYOK DB rows + Auto) needs a +single, authoritative answer to one question: *can this chat config accept +``image_url`` content blocks?* Without it, the new-chat selector can't badge +incompatible models and the streaming task can't fail fast with a friendly +error before sending an image to a text-only provider. + +Two functions, two intents: + +- :func:`derive_supports_image_input` — best-effort *True* for catalog and + UI surfacing. Default-allow: an unknown / unmapped model is treated as + capable so we never lock the user out of a freshly added or + third-party-hosted vision model. + +- :func:`is_known_text_only_chat_model` — strict opt-out for the streaming + task's safety net. Returns True only when LiteLLM's model map *explicitly* + sets ``supports_vision=False`` (or its bare-name variant does). Anything + else — missing key, lookup exception, ``supports_vision=True`` — returns + False so the request flows through to the provider. + +Implementation rule: only public LiteLLM symbols +------------------------------------------------ +``litellm.supports_vision`` and ``litellm.get_model_info`` are part of the +typed module surface (see ``litellm.__init__`` lazy stubs) and are stable +across releases. The private ``_is_explicitly_disabled_factory`` and +``_get_model_info_helper`` are intentionally avoided so a LiteLLM upgrade +can't silently break us. + +Why the previous round's strict YAML opt-in flag failed +------------------------------------------------------- +``supports_image_input: false`` was the YAML loader's setdefault. Operators +maintaining ``global_llm_config.yaml`` never set it, so every Azure / OpenAI +YAML chat model — including vision-capable GPT-5.x and GPT-4o — resolved to +False and the streaming gate rejected every image turn. Sourcing capability +from LiteLLM's authoritative model map (which already says +``azure/gpt-5.4 -> supports_vision=true``) removes that operator toil. +""" + +from __future__ import annotations + +import logging +from collections.abc import Iterable + +import litellm + +logger = logging.getLogger(__name__) + + +# Provider-name → LiteLLM model-prefix map. +# +# Owned here because ``app.services.provider_capabilities`` is the +# only edge that's safe to call from ``app.config``'s YAML loader at +# class-body init time. ``app.agents.new_chat.llm_config`` re-exports +# this constant under the historical ``PROVIDER_MAP`` name; placing the +# map there directly would re-introduce the +# ``app.config -> ... -> app.agents.new_chat.tools.generate_image -> +# app.config`` cycle that prompted the move. +_PROVIDER_PREFIX_MAP: dict[str, str] = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama_chat", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "OPENROUTER": "openrouter", + "XAI": "xai", + "BEDROCK": "bedrock", + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", + "GITHUB_MODELS": "github", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + "COMETAPI": "cometapi", + "HUGGINGFACE": "huggingface", + "MINIMAX": "openai", + "CUSTOM": "custom", +} + + +def _candidate_model_strings( + *, + provider: str | None, + model_name: str | None, + base_model: str | None, + custom_provider: str | None, +) -> list[tuple[str, str | None]]: + """Return ``[(model_string, custom_llm_provider), ...]`` lookup candidates. + + LiteLLM's capability lookup is keyed by ``model`` + (optional) + ``custom_llm_provider``. Different config sources give us different + levels of detail, so we try the most-specific keys first and fall back + to bare model names so unannotated entries (e.g. an Azure deployment + pointing at ``gpt-5.4`` via ``litellm_params.base_model``) still hit the + map. Order matters — the first lookup that returns a definitive answer + wins for both helpers. + """ + candidates: list[tuple[str, str | None]] = [] + seen: set[tuple[str, str | None]] = set() + + def _add(model: str | None, llm_provider: str | None) -> None: + if not model: + return + key = (model, llm_provider) + if key in seen: + return + seen.add(key) + candidates.append(key) + + provider_prefix: str | None = None + if provider: + provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower()) + if custom_provider: + # ``custom_provider`` overrides everything for CUSTOM/proxy setups. + provider_prefix = custom_provider + + primary_model = base_model or model_name + bare_model = model_name + + # Most-specific first: provider-prefixed identifier with explicit + # custom_llm_provider so LiteLLM won't have to guess the provider via + # ``get_llm_provider``. + if primary_model and provider_prefix: + # e.g. "azure/gpt-5.4" + custom_llm_provider="azure" + if "/" in primary_model: + _add(primary_model, provider_prefix) + else: + _add(f"{provider_prefix}/{primary_model}", provider_prefix) + + # Bare base_model (or model_name) with provider hint — handles entries + # the upstream map keys without a provider prefix (most ``gpt-*`` and + # ``claude-*`` entries do this). + if primary_model: + _add(primary_model, provider_prefix) + + # Fallback to model_name when base_model differs (e.g. an Azure + # deployment whose model_name is the deployment id but base_model is the + # canonical OpenAI sku). + if bare_model and bare_model != primary_model: + if provider_prefix and "/" not in bare_model: + _add(f"{provider_prefix}/{bare_model}", provider_prefix) + _add(bare_model, provider_prefix) + _add(bare_model, None) + + return candidates + + +def derive_supports_image_input( + *, + provider: str | None = None, + model_name: str | None = None, + base_model: str | None = None, + custom_provider: str | None = None, + openrouter_input_modalities: Iterable[str] | None = None, +) -> bool: + """Best-effort capability flag for the new-chat selector and catalog. + + Resolution order (first definitive answer wins): + + 1. ``openrouter_input_modalities`` (when provided as a non-empty + iterable). OpenRouter exposes ``architecture.input_modalities`` per + model and that's the authoritative source for OR dynamic configs. + 2. ``litellm.supports_vision`` against each candidate identifier from + :func:`_candidate_model_strings`. Returns True as soon as any + candidate confirms vision support. + 3. Default ``True`` — the conservative-allow stance. An unknown / + newly-added / third-party-hosted model is *not* pre-judged. The + streaming safety net (:func:`is_known_text_only_chat_model`) is the + only place a False ever blocks; everywhere else, a False here would + just hide a usable model from the user. + + Returns: + True if the model can plausibly accept image input, False only when + OpenRouter explicitly says it can't. + """ + if openrouter_input_modalities is not None: + modalities = list(openrouter_input_modalities) + if modalities: + return "image" in modalities + # Empty list explicitly published by OR — treat as "no image". + return False + + for model_string, custom_llm_provider in _candidate_model_strings( + provider=provider, + model_name=model_name, + base_model=base_model, + custom_provider=custom_provider, + ): + try: + if litellm.supports_vision( + model=model_string, custom_llm_provider=custom_llm_provider + ): + return True + except Exception as exc: + logger.debug( + "litellm.supports_vision raised for model=%s provider=%s: %s", + model_string, + custom_llm_provider, + exc, + ) + continue + + # Default-allow. ``is_known_text_only_chat_model`` is the strict gate. + return True + + +def is_known_text_only_chat_model( + *, + provider: str | None = None, + model_name: str | None = None, + base_model: str | None = None, + custom_provider: str | None = None, +) -> bool: + """Strict opt-out probe for the streaming-task safety net. + + Returns True only when LiteLLM's model map *explicitly* sets + ``supports_vision=False`` for at least one candidate identifier. Missing + key, lookup exception, or ``supports_vision=True`` all return False so + the streaming task lets the request through. This is the inverse-default + of :func:`derive_supports_image_input`. + + Why two functions + ----------------- + The selector wants "show me everything that's plausibly capable" — + default-allow. The safety net wants "block only when I'm certain it + can't" — default-pass. Mixing the two intents in a single function + leads to the regression we're fixing here. + """ + for model_string, custom_llm_provider in _candidate_model_strings( + provider=provider, + model_name=model_name, + base_model=base_model, + custom_provider=custom_provider, + ): + try: + info = litellm.get_model_info( + model=model_string, custom_llm_provider=custom_llm_provider + ) + except Exception as exc: + logger.debug( + "litellm.get_model_info raised for model=%s provider=%s: %s", + model_string, + custom_llm_provider, + exc, + ) + continue + + # ``ModelInfo`` is a TypedDict (dict at runtime). ``supports_vision`` + # may be missing, None, True, or False. We only fire on explicit + # False — None / missing / True all mean "don't block". + try: + value = info.get("supports_vision") # type: ignore[union-attr] + except AttributeError: + value = None + if value is False: + return True + + return False + + +__all__ = [ + "derive_supports_image_input", + "is_known_text_only_chat_model", +] diff --git a/surfsense_backend/app/tasks/celery_tasks/__init__.py b/surfsense_backend/app/tasks/celery_tasks/__init__.py index 5b1f2cd13..b23359f36 100644 --- a/surfsense_backend/app/tasks/celery_tasks/__init__.py +++ b/surfsense_backend/app/tasks/celery_tasks/__init__.py @@ -1,10 +1,25 @@ -"""Celery tasks package.""" +"""Celery tasks package. + +Also hosts the small helpers every async celery task should use to +spin up its event loop. See :func:`run_async_celery_task` for the +canonical pattern. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +from collections.abc import Awaitable, Callable +from typing import TypeVar from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.pool import NullPool from app.config import config +logger = logging.getLogger(__name__) + _celery_engine = None _celery_session_maker = None @@ -26,3 +41,86 @@ def get_celery_session_maker() -> async_sessionmaker: _celery_engine, expire_on_commit=False ) return _celery_session_maker + + +def _dispose_shared_db_engine(loop: asyncio.AbstractEventLoop) -> None: + """Drop the shared ``app.db.engine`` connection pool synchronously. + + The shared engine (used by ``shielded_async_session`` and most + routes / services) is a module-level singleton with a real pool. + Each celery task creates a fresh ``asyncio`` event loop; asyncpg + connections cache a reference to whichever loop opened them. When + a subsequent task's loop pulls a stale connection from the pool, + SQLAlchemy's ``pool_pre_ping`` checkout crashes with:: + + AttributeError: 'NoneType' object has no attribute 'send' + File ".../asyncio/proactor_events.py", line 402, in _loop_writing + self._write_fut = self._loop._proactor.send(self._sock, data) + + or hangs forever inside the asyncpg ``Connection._cancel`` cleanup + coroutine that can never run because its loop is gone. + + Disposing the engine forces the pool to drop every cached + connection so the next checkout opens a fresh one on the current + loop. Safe to call from a task's finally block; failure is logged + but never propagated. + """ + try: + from app.db import engine as shared_engine + + loop.run_until_complete(shared_engine.dispose()) + except Exception: + logger.warning("Shared DB engine dispose() failed", exc_info=True) + + +T = TypeVar("T") + + +def run_async_celery_task[T](coro_factory: Callable[[], Awaitable[T]]) -> T: + """Run an async coroutine inside a fresh event loop with proper + DB-engine cleanup. + + This is the canonical entry point for every async celery task. + It performs three responsibilities that were previously copy-pasted + (incorrectly) across each task module: + + 1. Create a fresh ``asyncio`` loop and install it on the current + thread (celery's ``--pool=solo`` runs every task on the main + thread, but other pool types don't). + 2. Dispose the shared ``app.db.engine`` BEFORE the task runs so + any stale connections left over from a previous task's loop + are dropped — defends against tasks that crashed without + cleaning up. + 3. Dispose the shared engine AFTER the task runs so the + connections we opened on this loop are released before the + loop closes (avoids ``coroutine 'Connection._cancel' was + never awaited`` warnings and the next-task hang). + + Use as:: + + @celery_app.task(name="my_task", bind=True) + def my_task(self, *args): + return run_async_celery_task(lambda: _my_task_impl(*args)) + """ + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + # Defense-in-depth: prior task may have crashed before + # disposing. Idempotent — no-op if pool is already empty. + _dispose_shared_db_engine(loop) + return loop.run_until_complete(coro_factory()) + finally: + # Drop any connections this task opened so they don't leak + # into the next task's loop. + _dispose_shared_db_engine(loop) + with contextlib.suppress(Exception): + loop.run_until_complete(loop.shutdown_asyncgens()) + with contextlib.suppress(Exception): + asyncio.set_event_loop(None) + loop.close() + + +__all__ = [ + "get_celery_session_maker", + "run_async_celery_task", +] diff --git a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py index fe1ac19d3..08d96cfa0 100644 --- a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py @@ -4,7 +4,7 @@ import logging import traceback from app.celery_app import celery_app -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -49,22 +49,15 @@ def index_notion_pages_task( end_date: str, ): """Celery task to index Notion pages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_notion_pages( + return run_async_celery_task( + lambda: _index_notion_pages( connector_id, search_space_id, user_id, start_date, end_date ) ) except Exception as e: _handle_greenlet_error(e, "index_notion_pages", connector_id) raise - finally: - loop.close() async def _index_notion_pages( @@ -95,19 +88,11 @@ def index_github_repos_task( end_date: str, ): """Celery task to index GitHub repositories.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_github_repos( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_github_repos( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_github_repos( @@ -138,19 +123,11 @@ def index_confluence_pages_task( end_date: str, ): """Celery task to index Confluence pages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_confluence_pages( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_confluence_pages( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_confluence_pages( @@ -181,22 +158,15 @@ def index_google_calendar_events_task( end_date: str, ): """Celery task to index Google Calendar events.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_google_calendar_events( + return run_async_celery_task( + lambda: _index_google_calendar_events( connector_id, search_space_id, user_id, start_date, end_date ) ) except Exception as e: _handle_greenlet_error(e, "index_google_calendar_events", connector_id) raise - finally: - loop.close() async def _index_google_calendar_events( @@ -227,19 +197,11 @@ def index_google_gmail_messages_task( end_date: str, ): """Celery task to index Google Gmail messages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_google_gmail_messages( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_google_gmail_messages( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_google_gmail_messages( @@ -269,22 +231,14 @@ def index_google_drive_files_task( items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options' ): """Celery task to index Google Drive folders and files.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_google_drive_files( - connector_id, - search_space_id, - user_id, - items_dict, - ) + return run_async_celery_task( + lambda: _index_google_drive_files( + connector_id, + search_space_id, + user_id, + items_dict, ) - finally: - loop.close() + ) async def _index_google_drive_files( @@ -317,22 +271,14 @@ def index_onedrive_files_task( items_dict: dict, ): """Celery task to index OneDrive folders and files.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_onedrive_files( - connector_id, - search_space_id, - user_id, - items_dict, - ) + return run_async_celery_task( + lambda: _index_onedrive_files( + connector_id, + search_space_id, + user_id, + items_dict, ) - finally: - loop.close() + ) async def _index_onedrive_files( @@ -365,22 +311,14 @@ def index_dropbox_files_task( items_dict: dict, ): """Celery task to index Dropbox folders and files.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_dropbox_files( - connector_id, - search_space_id, - user_id, - items_dict, - ) + return run_async_celery_task( + lambda: _index_dropbox_files( + connector_id, + search_space_id, + user_id, + items_dict, ) - finally: - loop.close() + ) async def _index_dropbox_files( @@ -414,19 +352,11 @@ def index_elasticsearch_documents_task( end_date: str, ): """Celery task to index Elasticsearch documents.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_elasticsearch_documents( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_elasticsearch_documents( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_elasticsearch_documents( @@ -457,22 +387,15 @@ def index_crawled_urls_task( end_date: str, ): """Celery task to index Web page Urls.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_crawled_urls( + return run_async_celery_task( + lambda: _index_crawled_urls( connector_id, search_space_id, user_id, start_date, end_date ) ) except Exception as e: _handle_greenlet_error(e, "index_crawled_urls", connector_id) raise - finally: - loop.close() async def _index_crawled_urls( @@ -503,19 +426,11 @@ def index_bookstack_pages_task( end_date: str, ): """Celery task to index BookStack pages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_bookstack_pages( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_bookstack_pages( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_bookstack_pages( @@ -546,19 +461,11 @@ def index_composio_connector_task( end_date: str | None, ): """Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio).""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_composio_connector( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_composio_connector( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_composio_connector( diff --git a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py index c2dbe7700..5d6bde6c1 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py @@ -11,7 +11,7 @@ from app.db import Document from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -25,15 +25,7 @@ def reindex_document_task(self, document_id: int, user_id: str): document_id: ID of document to reindex user_id: ID of user who edited the document """ - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_reindex_document(document_id, user_id)) - finally: - loop.close() + return run_async_celery_task(lambda: _reindex_document(document_id, user_id)) async def _reindex_document(document_id: int, user_id: str): diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index 9d12f91f6..c78e376bd 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -11,7 +11,7 @@ from app.celery_app import celery_app from app.config import config from app.services.notification_service import NotificationService from app.services.task_logging_service import TaskLoggingService -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task from app.tasks.connector_indexers.local_folder_indexer import ( index_local_folder, index_uploaded_files, @@ -105,12 +105,7 @@ async def _run_heartbeat_loop(notification_id: int): ) def delete_document_task(self, document_id: int): """Celery task to delete a document and its chunks in batches.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(_delete_document_background(document_id)) - finally: - loop.close() + return run_async_celery_task(lambda: _delete_document_background(document_id)) async def _delete_document_background(document_id: int) -> None: @@ -153,14 +148,9 @@ def delete_folder_documents_task( folder_subtree_ids: list[int] | None = None, ): """Celery task to delete documents first, then the folder rows.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _delete_folder_documents(document_ids, folder_subtree_ids) - ) - finally: - loop.close() + return run_async_celery_task( + lambda: _delete_folder_documents(document_ids, folder_subtree_ids) + ) async def _delete_folder_documents( @@ -209,12 +199,9 @@ async def _delete_folder_documents( ) def delete_search_space_task(self, search_space_id: int): """Celery task to delete a search space and heavy child rows in batches.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(_delete_search_space_background(search_space_id)) - finally: - loop.close() + return run_async_celery_task( + lambda: _delete_search_space_background(search_space_id) + ) async def _delete_search_space_background(search_space_id: int) -> None: @@ -269,18 +256,11 @@ def process_extension_document_task( search_space_id: ID of the search space user_id: ID of the user """ - # Create a new event loop for this task - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _process_extension_document( - individual_document_dict, search_space_id, user_id - ) + return run_async_celery_task( + lambda: _process_extension_document( + individual_document_dict, search_space_id, user_id ) - finally: - loop.close() + ) async def _process_extension_document( @@ -419,13 +399,9 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st search_space_id: ID of the search space user_id: ID of the user """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_process_youtube_video(url, search_space_id, user_id)) - finally: - loop.close() + return run_async_celery_task( + lambda: _process_youtube_video(url, search_space_id, user_id) + ) async def _process_youtube_video(url: str, search_space_id: int, user_id: str): @@ -573,12 +549,9 @@ def process_file_upload_task( except Exception as e: logger.warning(f"[process_file_upload] Could not get file size: {e}") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _process_file_upload(file_path, filename, search_space_id, user_id) + run_async_celery_task( + lambda: _process_file_upload(file_path, filename, search_space_id, user_id) ) logger.info( f"[process_file_upload] Task completed successfully for: {filename}" @@ -589,8 +562,6 @@ def process_file_upload_task( f"Traceback:\n{traceback.format_exc()}" ) raise - finally: - loop.close() async def _process_file_upload( @@ -811,25 +782,17 @@ def process_file_upload_with_document_task( "File may have been removed before syncing could start." ) # Mark document as failed since file is missing - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _mark_document_failed( - document_id, - "File not found. Please re-upload the file.", - ) + run_async_celery_task( + lambda: _mark_document_failed( + document_id, + "File not found. Please re-upload the file.", ) - finally: - loop.close() + ) return - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _process_file_with_document( + run_async_celery_task( + lambda: _process_file_with_document( document_id, temp_path, filename, @@ -849,8 +812,6 @@ def process_file_upload_with_document_task( f"Traceback:\n{traceback.format_exc()}" ) raise - finally: - loop.close() async def _mark_document_failed(document_id: int, reason: str): @@ -1119,22 +1080,16 @@ def process_circleback_meeting_task( search_space_id: ID of the search space connector_id: ID of the Circleback connector (for deletion support) """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _process_circleback_meeting( - meeting_id, - meeting_name, - markdown_content, - metadata, - search_space_id, - connector_id, - ) + return run_async_celery_task( + lambda: _process_circleback_meeting( + meeting_id, + meeting_name, + markdown_content, + metadata, + search_space_id, + connector_id, ) - finally: - loop.close() + ) async def _process_circleback_meeting( @@ -1291,25 +1246,19 @@ def index_local_folder_task( target_file_paths: list[str] | None = None, ): """Celery task to index a local folder. Config is passed directly — no connector row.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_local_folder_async( - search_space_id=search_space_id, - user_id=user_id, - folder_path=folder_path, - folder_name=folder_name, - exclude_patterns=exclude_patterns, - file_extensions=file_extensions, - root_folder_id=root_folder_id, - enable_summary=enable_summary, - target_file_paths=target_file_paths, - ) + return run_async_celery_task( + lambda: _index_local_folder_async( + search_space_id=search_space_id, + user_id=user_id, + folder_path=folder_path, + folder_name=folder_name, + exclude_patterns=exclude_patterns, + file_extensions=file_extensions, + root_folder_id=root_folder_id, + enable_summary=enable_summary, + target_file_paths=target_file_paths, ) - finally: - loop.close() + ) async def _index_local_folder_async( @@ -1441,23 +1390,18 @@ def index_uploaded_folder_files_task( processing_mode: str = "basic", ): """Celery task to index files uploaded from the desktop app.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_uploaded_folder_files_async( - search_space_id=search_space_id, - user_id=user_id, - folder_name=folder_name, - root_folder_id=root_folder_id, - enable_summary=enable_summary, - file_mappings=file_mappings, - use_vision_llm=use_vision_llm, - processing_mode=processing_mode, - ) + return run_async_celery_task( + lambda: _index_uploaded_folder_files_async( + search_space_id=search_space_id, + user_id=user_id, + folder_name=folder_name, + root_folder_id=root_folder_id, + enable_summary=enable_summary, + file_mappings=file_mappings, + use_vision_llm=use_vision_llm, + processing_mode=processing_mode, ) - finally: - loop.close() + ) async def _index_uploaded_folder_files_async( @@ -1584,12 +1528,9 @@ def _ai_sort_lock_key(search_space_id: int) -> str: @celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1) def ai_sort_search_space_task(self, search_space_id: int, user_id: str): """Full AI sort for all documents in a search space.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id)) - finally: - loop.close() + return run_async_celery_task( + lambda: _ai_sort_search_space_async(search_space_id, user_id) + ) async def _ai_sort_search_space_async(search_space_id: int, user_id: str): @@ -1639,14 +1580,9 @@ async def _ai_sort_search_space_async(search_space_id: int, user_id: str): ) def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int): """Incremental AI sort for a single document after indexing.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _ai_sort_document_async(search_space_id, user_id, document_id) - ) - finally: - loop.close() + return run_async_celery_task( + lambda: _ai_sort_document_async(search_space_id, user_id, document_id) + ) async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int): diff --git a/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py b/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py index 98b107af3..c6c8666f5 100644 --- a/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py @@ -2,14 +2,13 @@ from __future__ import annotations -import asyncio import logging from app.celery_app import celery_app from app.db import SearchSourceConnector from app.schemas.obsidian_plugin import NotePayload from app.services.obsidian_plugin_indexer import upsert_note -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -22,18 +21,13 @@ def index_obsidian_attachment_task( user_id: str, ) -> None: """Process one Obsidian non-markdown attachment asynchronously.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_obsidian_attachment( - connector_id=connector_id, - payload_data=payload_data, - user_id=user_id, - ) + return run_async_celery_task( + lambda: _index_obsidian_attachment( + connector_id=connector_id, + payload_data=payload_data, + user_id=user_id, ) - finally: - loop.close() + ) async def _index_obsidian_attachment( diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py index 937877473..8b311576e 100644 --- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py @@ -3,6 +3,7 @@ import asyncio import logging import sys +from contextlib import asynccontextmanager from sqlalchemy import select @@ -12,11 +13,12 @@ from app.celery_app import celery_app from app.config import config as app_config from app.db import Podcast, PodcastStatus from app.services.billable_calls import ( + BillingSettlementError, QuotaInsufficientError, _resolve_agent_billing_for_search_space, billable_call, ) -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -34,6 +36,13 @@ if sys.platform.startswith("win"): # ============================================================================= +@asynccontextmanager +async def _celery_billable_session(): + """Session factory used by billable_call inside the Celery worker loop.""" + async with get_celery_session_maker()() as session: + yield session + + @celery_app.task(name="generate_content_podcast", bind=True) def generate_content_podcast_task( self, @@ -46,27 +55,22 @@ def generate_content_podcast_task( Celery task to generate podcast from source content. Updates existing podcast record created by the tool. """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete( - _generate_content_podcast( + return run_async_celery_task( + lambda: _generate_content_podcast( podcast_id, source_content, search_space_id, user_prompt, ) ) - loop.run_until_complete(loop.shutdown_asyncgens()) - return result except Exception as e: logger.error(f"Error generating content podcast: {e!s}") - loop.run_until_complete(_mark_podcast_failed(podcast_id)) + try: + run_async_celery_task(lambda: _mark_podcast_failed(podcast_id)) + except Exception: + logger.exception("Failed to mark podcast %s as failed", podcast_id) return {"status": "failed", "podcast_id": podcast_id} - finally: - asyncio.set_event_loop(None) - loop.close() async def _mark_podcast_failed(podcast_id: int) -> None: @@ -148,11 +152,12 @@ async def _generate_content_podcast( base_model=base_model, quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS, usage_type="podcast_generation", - thread_id=podcast.thread_id, call_details={ "podcast_id": podcast.id, "title": podcast.title, + "thread_id": podcast.thread_id, }, + billable_session_factory=_celery_billable_session, ): graph_result = await podcaster_graph.ainvoke( initial_state, config=graph_config @@ -173,6 +178,18 @@ async def _generate_content_podcast( "podcast_id": podcast.id, "reason": "premium_quota_exhausted", } + except BillingSettlementError: + logger.exception( + "Podcast %s: premium billing settlement failed", + podcast.id, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "billing_settlement_failed", + } podcast_transcript = graph_result.get("podcast_transcript", []) file_path = graph_result.get("final_podcast_file_path", "") @@ -194,7 +211,14 @@ async def _generate_content_podcast( podcast.podcast_transcript = serializable_transcript podcast.file_location = file_path podcast.status = PodcastStatus.READY + logger.info( + "Podcast %s: committing READY transcript_entries=%d file=%s", + podcast.id, + len(serializable_transcript), + file_path, + ) await session.commit() + logger.info("Podcast %s: READY commit complete", podcast.id) logger.info(f"Successfully generated podcast: {podcast.id}") diff --git a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py index 373f04b48..e41251407 100644 --- a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py @@ -7,7 +7,7 @@ from sqlalchemy.future import select from app.celery_app import celery_app from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task from app.utils.indexing_locks import is_connector_indexing_locked logger = logging.getLogger(__name__) @@ -20,15 +20,7 @@ def check_periodic_schedules_task(): This task runs every minute and triggers indexing for any connector whose next_scheduled_at time has passed. """ - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_check_and_trigger_schedules()) - finally: - loop.close() + return run_async_celery_task(_check_and_trigger_schedules) async def _check_and_trigger_schedules(): diff --git a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py index e05ae9435..d51c85dee 100644 --- a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py @@ -34,7 +34,7 @@ from sqlalchemy.future import select from app.celery_app import celery_app from app.config import config from app.db import Document, DocumentStatus, Notification -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -69,16 +69,12 @@ def cleanup_stale_indexing_notifications_task(): Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task. Also marks associated pending/processing documents as failed. """ - import asyncio - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + async def _both() -> None: + await _cleanup_stale_notifications() + await _cleanup_stale_document_processing_notifications() - try: - loop.run_until_complete(_cleanup_stale_notifications()) - loop.run_until_complete(_cleanup_stale_document_processing_notifications()) - finally: - loop.close() + return run_async_celery_task(_both) async def _cleanup_stale_notifications(): diff --git a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py index 3aee1a360..ace6ef7ca 100644 --- a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import logging from datetime import UTC, datetime, timedelta @@ -18,7 +17,7 @@ from app.db import ( PremiumTokenPurchaseStatus, ) from app.routes import stripe_routes -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -36,13 +35,7 @@ def get_stripe_client() -> StripeClient | None: @celery_app.task(name="reconcile_pending_stripe_page_purchases") def reconcile_pending_stripe_page_purchases_task(): """Recover paid purchases that were left pending due to missed webhook handling.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_reconcile_pending_page_purchases()) - finally: - loop.close() + return run_async_celery_task(_reconcile_pending_page_purchases) async def _reconcile_pending_page_purchases() -> None: @@ -141,13 +134,7 @@ async def _reconcile_pending_page_purchases() -> None: @celery_app.task(name="reconcile_pending_stripe_token_purchases") def reconcile_pending_stripe_token_purchases_task(): """Recover paid token purchases that were left pending due to missed webhook handling.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_reconcile_pending_token_purchases()) - finally: - loop.close() + return run_async_celery_task(_reconcile_pending_token_purchases) async def _reconcile_pending_token_purchases() -> None: diff --git a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py index 4f0c427d9..08f22140c 100644 --- a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py @@ -3,6 +3,7 @@ import asyncio import logging import sys +from contextlib import asynccontextmanager from sqlalchemy import select @@ -12,11 +13,12 @@ from app.celery_app import celery_app from app.config import config as app_config from app.db import VideoPresentation, VideoPresentationStatus from app.services.billable_calls import ( + BillingSettlementError, QuotaInsufficientError, _resolve_agent_billing_for_search_space, billable_call, ) -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -29,6 +31,13 @@ if sys.platform.startswith("win"): ) +@asynccontextmanager +async def _celery_billable_session(): + """Session factory used by billable_call inside the Celery worker loop.""" + async with get_celery_session_maker()() as session: + yield session + + @celery_app.task(name="generate_video_presentation", bind=True) def generate_video_presentation_task( self, @@ -41,27 +50,30 @@ def generate_video_presentation_task( Celery task to generate video presentation from source content. Updates existing video presentation record created by the tool. """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete( - _generate_video_presentation( + return run_async_celery_task( + lambda: _generate_video_presentation( video_presentation_id, source_content, search_space_id, user_prompt, ) ) - loop.run_until_complete(loop.shutdown_asyncgens()) - return result except Exception as e: logger.error(f"Error generating video presentation: {e!s}") - loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id)) + # Mark FAILED in a fresh loop — the previous loop is closed. + # Swallow secondary failures; the row will simply stay in + # GENERATING and be flushed by the periodic stale cleanup. + try: + run_async_celery_task( + lambda: _mark_video_presentation_failed(video_presentation_id) + ) + except Exception: + logger.exception( + "Failed to mark video presentation %s as failed", + video_presentation_id, + ) return {"status": "failed", "video_presentation_id": video_presentation_id} - finally: - asyncio.set_event_loop(None) - loop.close() async def _mark_video_presentation_failed(video_presentation_id: int) -> None: @@ -150,11 +162,12 @@ async def _generate_video_presentation( base_model=base_model, quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS, usage_type="video_presentation_generation", - thread_id=video_pres.thread_id, call_details={ "video_presentation_id": video_pres.id, "title": video_pres.title, + "thread_id": video_pres.thread_id, }, + billable_session_factory=_celery_billable_session, ): graph_result = await video_presentation_graph.ainvoke( initial_state, config=graph_config @@ -175,6 +188,18 @@ async def _generate_video_presentation( "video_presentation_id": video_pres.id, "reason": "premium_quota_exhausted", } + except BillingSettlementError: + logger.exception( + "VideoPresentation %s: premium billing settlement failed", + video_pres.id, + ) + video_pres.status = VideoPresentationStatus.FAILED + await session.commit() + return { + "status": "failed", + "video_presentation_id": video_pres.id, + "reason": "billing_settlement_failed", + } # Serialize slides (parsed content + audio info merged) slides_raw = graph_result.get("slides", []) @@ -205,7 +230,14 @@ async def _generate_video_presentation( video_pres.slides = serializable_slides video_pres.scene_codes = serializable_scene_codes video_pres.status = VideoPresentationStatus.READY + logger.info( + "VideoPresentation %s: committing READY slides=%d scene_codes=%d", + video_pres.id, + len(serializable_slides), + len(serializable_scene_codes), + ) await session.commit() + logger.info("VideoPresentation %s: READY commit complete", video_pres.id) logger.info(f"Successfully generated video presentation: {video_pres.id}") diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 31c0d7d6d..c6ac3311a 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1506,10 +1506,10 @@ async def _stream_agent_events( if isinstance(tool_output, dict) else "Podcast" ) - if podcast_status == "processing": + if podcast_status in ("pending", "generating", "processing"): completed_items = [ f"Title: {podcast_title}", - "Audio generation started", + "Podcast generation started", "Processing in background...", ] elif podcast_status == "already_generating": @@ -1518,7 +1518,7 @@ async def _stream_agent_events( "Podcast already in progress", "Please wait for it to complete", ] - elif podcast_status == "error": + elif podcast_status in ("failed", "error"): error_msg = ( tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) @@ -1528,6 +1528,11 @@ async def _stream_agent_events( f"Title: {podcast_title}", f"Error: {error_msg[:50]}", ] + elif podcast_status in ("ready", "success"): + completed_items = [ + f"Title: {podcast_title}", + "Podcast ready", + ] else: completed_items = last_active_step_items yield streaming_service.format_thinking_step( @@ -1710,20 +1715,28 @@ async def _stream_agent_events( if isinstance(tool_output, dict) else {"result": tool_output}, ) - if ( - isinstance(tool_output, dict) - and tool_output.get("status") == "success" + if isinstance(tool_output, dict) and tool_output.get("status") in ( + "pending", + "generating", + "processing", + ): + yield streaming_service.format_terminal_info( + f"Podcast queued: {tool_output.get('title', 'Podcast')}", + "success", + ) + elif isinstance(tool_output, dict) and tool_output.get("status") in ( + "ready", + "success", ): yield streaming_service.format_terminal_info( f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}", "success", ) - else: - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) + elif isinstance(tool_output, dict) and tool_output.get("status") in ( + "failed", + "error", + ): + error_msg = tool_output.get("error", "Unknown error") yield streaming_service.format_terminal_info( f"Podcast generation failed: {error_msg}", "error", @@ -2292,6 +2305,11 @@ async def stream_new_chat( ) _t0 = time.perf_counter() + # Image-bearing turns force the Auto-pin resolver to filter the + # candidate pool to vision-capable cfgs (and force-repin a + # text-only existing pin). For explicit selections this flag is + # a no-op — the resolver returns the user's chosen id unchanged. + _requires_image_input = bool(user_image_data_urls) try: llm_config_id = ( await resolve_or_get_pinned_llm_config_id( @@ -2300,13 +2318,29 @@ async def stream_new_chat( search_space_id=search_space_id, user_id=user_id, selected_llm_config_id=llm_config_id, + requires_image_input=_requires_image_input, ) ).resolved_llm_config_id except ValueError as pin_error: + # Auto-pin's "no vision-capable cfg" path raises a ValueError + # whose message we map to the friendly image-input SSE error + # so the user sees the same message regardless of whether + # the gate fired in Auto-mode or in the agent_config check + # below. + error_code = ( + "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" + if _requires_image_input and "vision-capable" in str(pin_error) + else "SERVER_ERROR" + ) + error_kind = ( + "user_error" + if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" + else "server_error" + ) yield _emit_stream_error( message=str(pin_error), - error_kind="server_error", - error_code="SERVER_ERROR", + error_kind=error_kind, + error_code=error_code, ) yield streaming_service.format_done() return @@ -2326,6 +2360,50 @@ async def stream_new_chat( llm_config_id, ) + # Capability safety net: a turn carrying user-uploaded images + # cannot be routed to a chat config that LiteLLM's authoritative + # model map *explicitly* marks as text-only (``supports_vision`` + # set to False). The check is intentionally narrow — it only + # fires when LiteLLM is *certain* the model can't accept image + # input. Unknown / unmapped / vision-capable models pass + # through. Without this guard a known-text-only model would 404 + # at the provider with ``"No endpoints found that support image + # input"``, surfacing as an opaque ``SERVER_ERROR`` SSE chunk; + # failing here lets us return a friendly message that tells the + # user what to change. + if user_image_data_urls and agent_config is not None: + from app.services.provider_capabilities import ( + is_known_text_only_chat_model, + ) + + agent_litellm_params = agent_config.litellm_params or {} + agent_base_model = ( + agent_litellm_params.get("base_model") + if isinstance(agent_litellm_params, dict) + else None + ) + if is_known_text_only_chat_model( + provider=agent_config.provider, + model_name=agent_config.model_name, + base_model=agent_base_model, + custom_provider=agent_config.custom_provider, + ): + model_label = ( + agent_config.config_name or agent_config.model_name or "model" + ) + yield _emit_stream_error( + message=( + f"The selected model ({model_label}) does not support " + "image input. Switch to a vision-capable model " + "(e.g. GPT-4o, Claude, Gemini) or remove the image " + "attachment and try again." + ), + error_kind="user_error", + error_code="MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT", + ) + yield streaming_service.format_done() + return + # Premium quota reservation for pinned premium model only. _needs_premium_quota = ( agent_config is not None and user_id and agent_config.is_premium @@ -2366,6 +2444,7 @@ async def stream_new_chat( user_id=user_id, selected_llm_config_id=0, force_repin_free=True, + requires_image_input=_requires_image_input, ) ).resolved_llm_config_id except ValueError as pin_error: @@ -2470,6 +2549,7 @@ async def stream_new_chat( user_id=user_id, selected_llm_config_id=0, exclude_config_ids={previous_config_id}, + requires_image_input=_requires_image_input, ) ).resolved_llm_config_id except ValueError as pin_error: @@ -2804,6 +2884,7 @@ async def stream_new_chat( from litellm import acompletion from app.services.llm_router_service import LLMRouterService + from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import _turn_accumulator _turn_accumulator.set(None) @@ -2824,11 +2905,32 @@ async def stream_new_chat( model="auto", messages=messages ) else: + # Apply the same ``api_base`` cascade chat / vision / + # image-gen call sites use so we never inherit + # ``litellm.api_base`` (commonly set by + # ``AZURE_OPENAI_ENDPOINT``) when the chat config + # itself ships an empty ``api_base``. Without this + # the title-gen on an OpenRouter chat config would + # 404 against the inherited Azure endpoint — see + # ``provider_api_base`` docstring for the same + # bug repro on the image-gen / vision paths. + raw_model = getattr(llm, "model", "") or "" + provider_prefix = ( + raw_model.split("/", 1)[0] if "/" in raw_model else None + ) + provider_value = ( + agent_config.provider if agent_config is not None else None + ) + title_api_base = resolve_api_base( + provider=provider_value, + provider_prefix=provider_prefix, + config_api_base=getattr(llm, "api_base", None), + ) response = await acompletion( - model=llm.model, + model=raw_model, messages=messages, api_key=getattr(llm, "api_key", None), - api_base=getattr(llm, "api_base", None), + api_base=title_api_base, ) usage_info = None @@ -2953,6 +3055,7 @@ async def stream_new_chat( user_id=user_id, selected_llm_config_id=0, exclude_config_ids={previous_config_id}, + requires_image_input=_requires_image_input, ) ).resolved_llm_config_id diff --git a/surfsense_backend/scripts/verify_chat_image_capability.py b/surfsense_backend/scripts/verify_chat_image_capability.py new file mode 100644 index 000000000..a49d4eab2 --- /dev/null +++ b/surfsense_backend/scripts/verify_chat_image_capability.py @@ -0,0 +1,558 @@ +"""End-to-end smoke test for vision / image config wiring. + +Loads the live ``global_llm_config.yaml`` (no mocking, no fixtures) and +exercises every chat / vision / image-generation config + the OpenRouter +dynamic catalog. For each config the script: + +1. Reports the resolver classification (catalog-allow vs strict-block). +2. Optionally fires a tiny live API call against the provider: + - Chat configs: ``litellm.acompletion`` with a 1x1 PNG and the prompt + ``"reply with one word: ok"``. + - Vision configs: same, against the dedicated vision router pool. + - Image-gen configs: ``litellm.aimage_generation`` with a single tiny + prompt and ``n=1``. + - OpenRouter integration: samples one chat, one vision, one image-gen + model from the dynamically fetched catalog. + +Usage:: + + python -m scripts.verify_chat_image_capability # capability + connectivity + python -m scripts.verify_chat_image_capability --no-live # capability resolver only + +The script is meant to be runnable from the repository root or from +``surfsense_backend/`` and prints a short PASS/FAIL/SKIP summary at the +end so it's usable as a CI smoke check too. + +Live-mode caveat: each successful call costs a small amount of provider +credit (a few tokens or one tiny generated image per config). The +default size for image generation is ``1024x1024`` because Azure +GPT-image deployments reject smaller sizes; OpenRouter image-gen models +generally accept the same size. +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import os +import sys +import time +from dataclasses import dataclass, field +from typing import Any + +# Bootstrap the surfsense_backend package on sys.path so the script runs +# from the repo root or from `surfsense_backend/` interchangeably. +_HERE = os.path.dirname(os.path.abspath(__file__)) +_BACKEND_ROOT = os.path.dirname(_HERE) +if _BACKEND_ROOT not in sys.path: + sys.path.insert(0, _BACKEND_ROOT) + +import litellm # noqa: E402 + +from app.config import config # noqa: E402 +from app.services.openrouter_integration_service import ( # noqa: E402 + _OPENROUTER_DYNAMIC_MARKER, + OpenRouterIntegrationService, +) +from app.services.provider_api_base import resolve_api_base # noqa: E402 +from app.services.provider_capabilities import ( # noqa: E402 + derive_supports_image_input, + is_known_text_only_chat_model, +) + +logging.basicConfig( + level=logging.WARNING, + format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s", +) +# Quiet down LiteLLM's verbose router/cost logs so the script output is +# scannable. +logging.getLogger("LiteLLM").setLevel(logging.ERROR) +logging.getLogger("litellm").setLevel(logging.ERROR) +logging.getLogger("httpx").setLevel(logging.ERROR) + +# 1x1 transparent PNG — used as the cheapest possible vision payload. +_TINY_PNG_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" +_TINY_PNG_DATA_URL = f"data:image/png;base64,{_TINY_PNG_B64}" + + +# --------------------------------------------------------------------------- +# Result accounting +# --------------------------------------------------------------------------- + + +@dataclass +class ProbeResult: + label: str + surface: str + config_id: int | str + capability_ok: bool | None = None + capability_note: str = "" + live_ok: bool | None = None + live_note: str = "" + duration_s: float = 0.0 + + +@dataclass +class Report: + results: list[ProbeResult] = field(default_factory=list) + + def add(self, r: ProbeResult) -> None: + self.results.append(r) + + def render(self) -> int: + passed = failed = skipped = 0 + print() + print("=" * 92) + print( + f"{'Surface':<14}{'ID':>8} {'Cap':>5} {'Live':>5} {'Time':>6} Label / notes" + ) + print("-" * 92) + for r in self.results: + + def _flag(value: bool | None) -> str: + if value is None: + return "skip" + return "ok" if value else "fail" + + cap = _flag(r.capability_ok) + live = _flag(r.live_ok) + if r.capability_ok is False or r.live_ok is False: + failed += 1 + elif r.capability_ok is None and r.live_ok is None: + skipped += 1 + else: + passed += 1 + print( + f"{r.surface:<14}{r.config_id!s:>8} {cap:>5} {live:>5} " + f"{r.duration_s:>5.2f}s {r.label}" + ) + if r.capability_note: + print(f" cap: {r.capability_note}") + if r.live_note: + print(f" live: {r.live_note}") + print("-" * 92) + print( + f"Total: {passed} ok / {failed} fail / {skipped} skip " + f"(of {len(self.results)} probes)" + ) + print("=" * 92) + return failed + + +# --------------------------------------------------------------------------- +# Capability probes (no network) +# --------------------------------------------------------------------------- + + +def _probe_chat_capability(cfg: dict) -> tuple[bool, str]: + """For chat configs the catalog flag is *expected* True (vision-capable + pool). The probe reports both the resolver value and the strict + safety-net value to surface any drift between them.""" + litellm_params = cfg.get("litellm_params") or {} + base_model = ( + litellm_params.get("base_model") if isinstance(litellm_params, dict) else None + ) + cap = derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) + block = is_known_text_only_chat_model( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) + note = f"derive={cap} strict_block={block}" + if not cap and not block: + # Resolver said False but strict gate is also False — that means + # OR modalities published [text] explicitly. Surface it. + note += " (OR modality says text-only)" + # We accept a True derive *or* (False derive AND False block) as + # 'capability ok' — either way, the streaming task will flow through. + ok = cap or not block + return ok, note + + +def _build_chat_model_string(cfg: dict) -> str: + if cfg.get("custom_provider"): + return f"{cfg['custom_provider']}/{cfg['model_name']}" + from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP + + prefix = _PROVIDER_PREFIX_MAP.get( + (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() + ) + return f"{prefix}/{cfg['model_name']}" + + +# --------------------------------------------------------------------------- +# Live probes (network calls) +# --------------------------------------------------------------------------- + + +async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]: + """Send a 1x1 PNG + `reply with one word: ok` to the chat config.""" + model_string = _build_chat_model_string(cfg) + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=model_string.split("/", 1)[0], + config_api_base=cfg.get("api_base") or None, + ) + kwargs: dict[str, Any] = { + "model": model_string, + "api_key": cfg.get("api_key"), + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "reply with one word: ok"}, + { + "type": "image_url", + "image_url": {"url": _TINY_PNG_DATA_URL}, + }, + ], + } + ], + "max_tokens": 16, + "timeout": 60, + } + if api_base: + kwargs["api_base"] = api_base + if cfg.get("litellm_params"): + # Strip pricing keys — they're tracking-only and confuse some + # provider validators (e.g. azure/openai reject unknown kwargs + # in strict mode). + merged = { + k: v + for k, v in dict(cfg["litellm_params"]).items() + if k + not in { + "input_cost_per_token", + "output_cost_per_token", + "input_cost_per_pixel", + "output_cost_per_pixel", + } + } + kwargs.update(merged) + try: + resp = await litellm.acompletion(**kwargs) + except Exception as exc: + return False, f"{type(exc).__name__}: {exc}" + text = resp.choices[0].message.content if resp.choices else "" + return True, f"got reply ({(text or '').strip()[:40]!r})" + + +# Gemini image models occasionally return zero-length ``data`` for the +# minimal "red dot on white" prompt (provider-side safety / empty-output +# quirk reproducible against ``google/gemini-2.5-flash-image`` even when +# the request itself succeeds). Use a more naturalistic prompt and +# retry once with a different one before giving up. +_IMAGE_GEN_PROMPTS: tuple[str, ...] = ( + "A simple icon of a coffee cup, flat illustration", + "A small green leaf on a white background", +) + + +async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]: + """Generate one tiny image to verify the deployment is reachable.""" + from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP + + if cfg.get("custom_provider"): + prefix = cfg["custom_provider"] + else: + prefix = _PROVIDER_PREFIX_MAP.get( + (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() + ) + model_string = f"{prefix}/{cfg['model_name']}" + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=prefix, + config_api_base=cfg.get("api_base") or None, + ) + base_kwargs: dict[str, Any] = { + "model": model_string, + "api_key": cfg.get("api_key"), + "n": 1, + "size": "1024x1024", + "timeout": 120, + } + if api_base: + base_kwargs["api_base"] = api_base + if cfg.get("api_version"): + base_kwargs["api_version"] = cfg["api_version"] + if cfg.get("litellm_params"): + base_kwargs.update( + { + k: v + for k, v in dict(cfg["litellm_params"]).items() + if k + not in { + "input_cost_per_token", + "output_cost_per_token", + "input_cost_per_pixel", + "output_cost_per_pixel", + } + } + ) + + last_note = "" + for attempt, prompt in enumerate(_IMAGE_GEN_PROMPTS, start=1): + try: + resp = await litellm.aimage_generation(prompt=prompt, **base_kwargs) + except Exception as exc: + last_note = f"{type(exc).__name__}: {exc}" + continue + data_count = len(getattr(resp, "data", None) or []) + if data_count > 0: + return True, ( + f"received {data_count} image(s) on attempt {attempt} " + f"(prompt={prompt!r})" + ) + last_note = ( + f"call ok but received 0 images on attempt {attempt} (prompt={prompt!r})" + ) + return False, last_note + + +# --------------------------------------------------------------------------- +# Probe drivers +# --------------------------------------------------------------------------- + + +def _is_or_dynamic(cfg: dict) -> bool: + return bool(cfg.get(_OPENROUTER_DYNAMIC_MARKER)) + + +async def probe_chat_configs(report: Report, *, live: bool) -> None: + print("\n[chat configs from global_llm_configs (YAML-static)]") + for cfg in config.GLOBAL_LLM_CONFIGS: + # Skip OR dynamic entries here — handled in the OR section so + # the YAML / OR split stays clear in the report. + if _is_or_dynamic(cfg): + continue + result = ProbeResult( + label=str(cfg.get("name") or cfg.get("model_name")), + surface="chat-yaml", + config_id=cfg.get("id"), + ) + cap_ok, cap_note = _probe_chat_capability(cfg) + result.capability_ok = cap_ok + result.capability_note = cap_note + if live: + t0 = time.perf_counter() + ok, note = await _live_chat_image_call(cfg) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +async def probe_vision_configs(report: Report, *, live: bool) -> None: + print("\n[vision configs from global_vision_llm_configs (YAML-static)]") + for cfg in config.GLOBAL_VISION_LLM_CONFIGS: + if _is_or_dynamic(cfg): + continue + result = ProbeResult( + label=str(cfg.get("name") or cfg.get("model_name")), + surface="vision", + config_id=cfg.get("id"), + ) + # For vision configs, capability is implied — they're in the + # dedicated vision pool. Run the same resolver to flag any + # surprise disagreement. + cap_ok, cap_note = _probe_chat_capability(cfg) + result.capability_ok = cap_ok + result.capability_note = cap_note + if live: + t0 = time.perf_counter() + ok, note = await _live_chat_image_call(cfg) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +async def probe_image_gen_configs(report: Report, *, live: bool) -> None: + print( + "\n[image generation configs from global_image_generation_configs (YAML-static)]" + ) + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if _is_or_dynamic(cfg): + continue + result = ProbeResult( + label=str(cfg.get("name") or cfg.get("model_name")), + surface="image-gen", + config_id=cfg.get("id"), + ) + # Image gen configs don't have a "supports_image_input" flag; + # the catalog tracks output, not input. Mark capability as None + # (skip) for the report. + if live: + t0 = time.perf_counter() + ok, note = await _live_image_gen_call(cfg) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: + """Sample one chat (vision-capable), one vision, one image-gen model + from the live OpenRouter catalogue. Doesn't iterate the full pool + (would be hundreds of probes); just validates the integration end- + to-end on a representative model from each surface.""" + print("\n[OpenRouter integration: sampled probes]") + settings = config.OPENROUTER_INTEGRATION_SETTINGS + if not settings: + report.add( + ProbeResult( + label="OpenRouter integration", + surface="openrouter", + config_id="settings", + capability_ok=None, + capability_note="openrouter_integration disabled in YAML — skipping", + live_ok=None, + ) + ) + return + + service = OpenRouterIntegrationService.get_instance() + or_chat = [ + c + for c in config.GLOBAL_LLM_CONFIGS + if c.get("provider") == "OPENROUTER" and c.get("supports_image_input") + ] + or_vision = [ + c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER" + ] + or_image_gen = [ + c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER" + ] + + # Pick one representative per provider family per surface so a single + # broken vendor (e.g. Anthropic key revoked, Google quota exceeded) + # surfaces independently of the others. Each needle matches the + # OpenRouter ``model_name`` prefix; the first match wins. + def _pick_first(pool: list[dict], needle: str) -> dict | None: + for c in pool: + if (c.get("model_name") or "").lower().startswith(needle): + return c + return None + + chat_picks = [ + ("or-chat", _pick_first(or_chat, "openai/gpt-4o")), + ("or-chat", _pick_first(or_chat, "anthropic/claude")), + ("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")), + ] + vision_picks = [ + ("or-vision", _pick_first(or_vision, "openai/gpt-4o")), + ("or-vision", _pick_first(or_vision, "anthropic/claude")), + ("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")), + ] + image_picks = [ + ("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")), + # OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*`` + # / ``openai/gpt-5.4-image-2`` (no ``gpt-image`` literal). Match + # the actual prefix. + ("or-image", _pick_first(or_image_gen, "openai/gpt-5-image")), + ] + + print( + f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} " + f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})" + ) + + for surface, picked in chat_picks + vision_picks + image_picks: + if not picked: + report.add( + ProbeResult( + label=f"<no candidate for {surface}>", + surface=surface, + config_id="-", + capability_ok=None, + capability_note="no candidate found in OR catalog", + ) + ) + continue + runner = ( + _live_image_gen_call if surface == "or-image" else _live_chat_image_call + ) + result = ProbeResult( + label=str(picked.get("model_name")), + surface=surface, + config_id=picked.get("id"), + ) + if surface != "or-image": + cap_ok, cap_note = _probe_chat_capability(picked) + result.capability_ok = cap_ok + result.capability_note = cap_note + if live: + t0 = time.perf_counter() + ok, note = await runner(picked) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +async def main(args: argparse.Namespace) -> int: + print("Loaded global configs:") + print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries") + print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries") + print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries") + print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}") + + # Initialize the OpenRouter integration so the catalog is populated + # (this is what main.py does at startup). It's idempotent. + if config.OPENROUTER_INTEGRATION_SETTINGS: + try: + from app.config import initialize_openrouter_integration + + initialize_openrouter_integration() + except Exception as exc: + print(f" WARNING: OpenRouter integration init failed: {exc}") + + print( + f"\nMode: {'LIVE (will hit providers)' if args.live else 'DRY (capability only)'}" + ) + + report = Report() + if not args.skip_chat: + await probe_chat_configs(report, live=args.live) + if not args.skip_vision: + await probe_vision_configs(report, live=args.live) + if not args.skip_image_gen: + await probe_image_gen_configs(report, live=args.live) + if not args.skip_openrouter: + await probe_openrouter_catalog(report, live=args.live) + + failed = report.render() + return 1 if failed else 0 + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--no-live", + dest="live", + action="store_false", + help="Skip live API calls — capability resolver only.", + ) + parser.set_defaults(live=True) + parser.add_argument("--skip-chat", action="store_true") + parser.add_argument("--skip-vision", action="store_true") + parser.add_argument("--skip-image-gen", action="store_true") + parser.add_argument("--skip-openrouter", action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + sys.exit(asyncio.run(main(args))) diff --git a/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py new file mode 100644 index 000000000..c9f18d77d --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py @@ -0,0 +1,110 @@ +"""Unit tests for ``supports_image_input`` derivation on BYOK chat config +endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``). + +There is no DB column for ``supports_image_input`` on +``NewLLMConfig`` — the value is resolved at the API boundary by +``derive_supports_image_input`` so the new-chat selector / streaming +task can read the same field shape regardless of source (BYOK vs YAML +vs OpenRouter dynamic). Default-allow on unknown so we don't lock the +user out of their own model choice. +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from uuid import uuid4 + +import pytest + +from app.db import LiteLLMProvider +from app.routes import new_llm_config_routes + +pytestmark = pytest.mark.unit + + +def _byok_row( + *, + id_: int, + model_name: str, + base_model: str | None = None, + provider: LiteLLMProvider = LiteLLMProvider.OPENAI, + custom_provider: str | None = None, +) -> object: + """Mimic the SQLAlchemy row's attribute surface; ``model_validate`` + walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough. + + ``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's + enum validator accepts it — same as the ORM row would carry.""" + return SimpleNamespace( + id=id_, + name=f"BYOK-{id_}", + description=None, + provider=provider, + custom_provider=custom_provider, + model_name=model_name, + api_key="sk-byok", + api_base=None, + litellm_params={"base_model": base_model} if base_model else None, + system_instructions="", + use_default_system_instructions=True, + citations_enabled=True, + created_at=datetime.now(tz=UTC), + search_space_id=42, + user_id=uuid4(), + ) + + +def test_serialize_byok_known_vision_model_resolves_true(): + """The catalog resolver consults LiteLLM's map for ``gpt-4o`` -> + True. The serialized row carries that value through to the + ``NewLLMConfigRead`` schema.""" + row = _byok_row(id_=1, model_name="gpt-4o") + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + assert serialized.id == 1 + assert serialized.model_name == "gpt-4o" + + +def test_serialize_byok_unknown_model_default_allows(): + """Unknown / unmapped: default-allow. The streaming-task safety net + is the actual block, and it requires LiteLLM to *explicitly* say + text-only — so a brand new BYOK model should not be pre-judged.""" + row = _byok_row( + id_=2, + model_name="brand-new-model-x9-unmapped", + provider=LiteLLMProvider.CUSTOM, + custom_provider="brand_new_proxy", + ) + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + + +def test_serialize_byok_uses_base_model_when_present(): + """Azure-style: ``model_name`` is the deployment id, ``base_model`` + inside ``litellm_params`` is the canonical sku LiteLLM knows. The + helper must consult ``base_model`` first or unrecognised deployment + ids would shadow the real capability.""" + row = _byok_row( + id_=3, + model_name="my-azure-deployment-id-no-litellm-knows-this", + base_model="gpt-4o", + provider=LiteLLMProvider.AZURE_OPENAI, + ) + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + + +def test_serialize_byok_returns_pydantic_read_model(): + """The route now returns ``NewLLMConfigRead`` (not the raw ORM) so + the schema additions are guaranteed to be present in the API + surface. This guards against a future regression where someone + deletes the augmentation step and falls back to ORM passthrough.""" + from app.schemas import NewLLMConfigRead + + row = _byok_row(id_=4, model_name="gpt-4o") + serialized = new_llm_config_routes._serialize_byok_config(row) + assert isinstance(serialized, NewLLMConfigRead) diff --git a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py new file mode 100644 index 000000000..2b6c76485 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py @@ -0,0 +1,184 @@ +"""Unit tests for ``is_premium`` derivation on the global image-gen and +vision-LLM list endpoints. + +Chat globals (``GET /global-llm-configs``) already emit +``is_premium = (billing_tier == "premium")``. Image and vision did not, +which made the new-chat ``model-selector`` render the Free/Premium badge +on the Chat tab but skip it on the Image and Vision tabs (the selector +keys its badge logic off ``is_premium``). These tests pin parity: + +* YAML free entry → ``is_premium=False`` +* YAML premium entry → ``is_premium=True`` +* OpenRouter dynamic premium entry → ``is_premium=True`` +* Auto stub (always emitted when at least one config is present) + → ``is_premium=False`` +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +_IMAGE_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "DALL-E 3", + "provider": "OPENAI", + "model_name": "dall-e-3", + "api_key": "sk-test", + "billing_tier": "free", + }, + { + "id": -2, + "name": "GPT-Image 1 (premium)", + "provider": "OPENAI", + "model_name": "gpt-image-1", + "api_key": "sk-test", + "billing_tier": "premium", + }, + { + "id": -20_001, + "name": "google/gemini-2.5-flash-image (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash-image", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "premium", + }, +] + + +_VISION_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "GPT-4o Vision", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + }, + { + "id": -2, + "name": "Claude 3.5 Sonnet (premium)", + "provider": "ANTHROPIC", + "model_name": "claude-3-5-sonnet", + "api_key": "sk-ant-test", + "billing_tier": "premium", + }, + { + "id": -30_001, + "name": "openai/gpt-4o (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "premium", + }, +] + + +# ============================================================================= +# Image generation +# ============================================================================= + + +@pytest.mark.asyncio +async def test_global_image_gen_configs_emit_is_premium(monkeypatch): + """Each emitted config must carry ``is_premium`` derived server-side + from ``billing_tier``. The Auto stub is always free. + """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr( + config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False + ) + + payload = await image_generation_routes.get_global_image_gen_configs(user=None) + + by_id = {c["id"]: c for c in payload} + + # Auto stub is always emitted when at least one global config exists, + # and it must always declare itself free (Auto-mode billing-tier + # surfacing is a separate follow-up). + assert 0 in by_id, "Auto stub should be emitted when at least one config exists" + assert by_id[0]["is_premium"] is False + assert by_id[0]["billing_tier"] == "free" + + # YAML free entry — ``is_premium=False`` + assert by_id[-1]["is_premium"] is False + assert by_id[-1]["billing_tier"] == "free" + + # YAML premium entry — ``is_premium=True`` + assert by_id[-2]["is_premium"] is True + assert by_id[-2]["billing_tier"] == "premium" + + # OpenRouter dynamic premium entry — same field, same derivation + assert by_id[-20_001]["is_premium"] is True + assert by_id[-20_001]["billing_tier"] == "premium" + + # Every emitted dict (including Auto) must have the field — never missing. + for cfg in payload: + assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" + assert isinstance(cfg["is_premium"], bool) + + +@pytest.mark.asyncio +async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch): + """When there are no global configs at all, the endpoint emits an + empty list (no Auto stub) — Auto mode would have nothing to route to. + """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False) + payload = await image_generation_routes.get_global_image_gen_configs(user=None) + assert payload == [] + + +# ============================================================================= +# Vision LLM +# ============================================================================= + + +@pytest.mark.asyncio +async def test_global_vision_llm_configs_emit_is_premium(monkeypatch): + from app.config import config + from app.routes import vision_llm_routes + + monkeypatch.setattr( + config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False + ) + + payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) + + by_id = {c["id"]: c for c in payload} + + assert 0 in by_id, "Auto stub should be emitted when at least one config exists" + assert by_id[0]["is_premium"] is False + assert by_id[0]["billing_tier"] == "free" + + assert by_id[-1]["is_premium"] is False + assert by_id[-1]["billing_tier"] == "free" + + assert by_id[-2]["is_premium"] is True + assert by_id[-2]["billing_tier"] == "premium" + + assert by_id[-30_001]["is_premium"] is True + assert by_id[-30_001]["billing_tier"] == "premium" + + for cfg in payload: + assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" + assert isinstance(cfg["is_premium"], bool) + + +@pytest.mark.asyncio +async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch): + from app.config import config + from app.routes import vision_llm_routes + + monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False) + payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) + assert payload == [] diff --git a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py new file mode 100644 index 000000000..b47d9134b --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py @@ -0,0 +1,106 @@ +"""Unit tests for ``supports_image_input`` derivation on the chat global +config endpoint (``GET /global-new-llm-configs``). + +Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``): + +1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML + loader for operator overrides, or by the OpenRouter integration from + ``architecture.input_modalities``) — wins. +2. ``derive_supports_image_input`` helper — default-allow on unknown + models, only False when LiteLLM / OR modalities are definitive. + +The flag is purely informational at the API boundary. The streaming +task safety net (``is_known_text_only_chat_model``) is the actual block, +and it requires LiteLLM to *explicitly* mark the model as text-only. +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "GPT-4o (explicit true)", + "description": "vision-capable, explicit YAML override", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + "supports_image_input": True, + }, + { + "id": -2, + "name": "DeepSeek V3 (explicit false)", + "description": "OpenRouter dynamic — modality-derived false", + "provider": "OPENROUTER", + "model_name": "deepseek/deepseek-v3.2-exp", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "free", + "supports_image_input": False, + }, + { + "id": -10_010, + "name": "Unannotated GPT-4o", + "description": "no flag set — resolver should derive True via LiteLLM", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + # supports_image_input intentionally absent + }, + { + "id": -10_011, + "name": "Unannotated unknown model", + "description": "unmapped — default-allow True", + "provider": "CUSTOM", + "custom_provider": "brand_new_proxy", + "model_name": "brand-new-model-x9", + "api_key": "sk-test", + "billing_tier": "free", + }, +] + + +@pytest.mark.asyncio +async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch): + """Each emitted chat config carries ``supports_image_input`` as a + bool. Explicit values win; unannotated entries are resolved via the + helper (default-allow True).""" + from app.config import config + from app.routes import new_llm_config_routes + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False) + + payload = await new_llm_config_routes.get_global_new_llm_configs(user=None) + by_id = {c["id"]: c for c in payload} + + # Auto stub: optimistic True so the user can keep Auto selected with + # vision-capable deployments somewhere in the pool. + assert 0 in by_id, "Auto stub should be emitted when configs exist" + assert by_id[0]["supports_image_input"] is True + assert by_id[0]["is_auto_mode"] is True + + # Explicit True is preserved. + assert by_id[-1]["supports_image_input"] is True + + # Explicit False is preserved (the exact failure mode the safety net + # guards against — DeepSeek V3 over OpenRouter would 404 with "No + # endpoints found that support image input"). + assert by_id[-2]["supports_image_input"] is False + + # Unannotated GPT-4o: resolver consults LiteLLM, which says vision. + assert by_id[-10_010]["supports_image_input"] is True + + # Unknown / unmapped model: default-allow rather than pre-judge. + assert by_id[-10_011]["supports_image_input"] is True + + for cfg in payload: + assert "supports_image_input" in cfg, ( + f"supports_image_input missing from {cfg.get('id')}" + ) + assert isinstance(cfg["supports_image_input"], bool) diff --git a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py new file mode 100644 index 000000000..0e19b80e4 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py @@ -0,0 +1,286 @@ +"""Image-aware extension of the Auto-pin resolver. + +When the current chat turn carries an ``image_url`` block, the pin +resolver must: + +1. Filter the candidate pool to vision-capable cfgs so a freshly + selected pin can never be text-only. +2. Treat any existing pin whose capability is False as invalid (force + re-pin), even when it would otherwise be reused as the thread's + stable model. +3. Raise ``ValueError`` (mapped to the friendly + ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error in the streaming + task) when no vision-capable cfg is available — instead of silently + pinning text-only and 404-ing at the provider. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace + +import pytest + +from app.services.auto_model_pin_service import ( + clear_healthy, + clear_runtime_cooldown, + resolve_or_get_pinned_llm_config_id, +) + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def _reset_caches(): + clear_runtime_cooldown() + clear_healthy() + yield + clear_runtime_cooldown() + clear_healthy() + + +@dataclass +class _FakeQuotaResult: + allowed: bool + + +class _FakeExecResult: + def __init__(self, thread): + self._thread = thread + + def unique(self): + return self + + def scalar_one_or_none(self): + return self._thread + + +class _FakeSession: + def __init__(self, thread): + self.thread = thread + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self.thread) + + async def commit(self): + self.commit_count += 1 + + +def _thread(*, pinned: int | None = None): + return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned) + + +def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: + return { + "id": id_, + "provider": "OPENAI", + "model_name": f"vision-{id_}", + "api_key": "k", + "billing_tier": tier, + "supports_image_input": True, + "auto_pin_tier": "A", + "quality_score": quality, + } + + +def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict: + return { + "id": id_, + "provider": "OPENAI", + "model_name": f"text-{id_}", + "api_key": "k", + "billing_tier": tier, + # Higher quality than the vision cfgs — so a bug that ignores + # the image flag would surface as the resolver picking this one. + "supports_image_input": False, + "auto_pin_tier": "A", + "quality_score": quality, + } + + +async def _premium_allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + +@pytest.mark.asyncio +async def test_image_turn_filters_out_text_only_candidates(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + assert result.resolved_llm_config_id == -2 + # The thread should be pinned to the vision cfg even though the + # text-only cfg has a higher quality score. + assert session.thread.pinned_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch): + """An existing text-only pin must be invalidated when the next turn + requires image input. The non-image path would happily reuse it.""" + from app.config import config + + session = _FakeSession(_thread(pinned=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + assert session.thread.pinned_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_image_turn_reuses_existing_vision_pin(monkeypatch): + """If the thread is already pinned to a vision-capable cfg, reuse it + — same as the non-image path. Image-aware filtering must not force + spurious re-pins.""" + from app.config import config + + session = _FakeSession(_thread(pinned=-2)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is True + + +@pytest.mark.asyncio +async def test_image_turn_with_no_vision_candidates_raises(monkeypatch): + """The friendly-error path: no vision-capable cfg in the pool -> raise + ``ValueError`` whose message contains ``vision-capable`` so the + streaming task can map it to ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT``.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _text_only_cfg(-2)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + with pytest.raises(ValueError, match="vision-capable"): + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + +@pytest.mark.asyncio +async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch): + """Regression guard: the image flag must default False and not affect + a normal text-only turn — text-only cfgs remain selectable.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + + +@pytest.mark.asyncio +async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch): + """A YAML cfg that omits ``supports_image_input`` falls through to + ``derive_supports_image_input`` (LiteLLM-driven). For ``gpt-4o`` + that returns True, so the cfg should be a valid candidate.""" + from app.config import config + + session = _FakeSession(_thread()) + cfg_unannotated_vision = { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-4o", # known vision model in LiteLLM map + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "A", + "quality_score": 80, + # NOTE: no supports_image_input key + } + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision]) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + assert result.resolved_llm_config_id == -2 diff --git a/surfsense_backend/tests/unit/services/test_billable_call.py b/surfsense_backend/tests/unit/services/test_billable_call.py index 86de5f23d..c820724ed 100644 --- a/surfsense_backend/tests/unit/services/test_billable_call.py +++ b/surfsense_backend/tests/unit/services/test_billable_call.py @@ -15,6 +15,7 @@ vision LLM extraction: from __future__ import annotations +import asyncio import contextlib from typing import Any from uuid import uuid4 @@ -57,6 +58,9 @@ class _FakeSession: async def commit(self) -> None: self.committed = True + async def rollback(self) -> None: + pass + async def close(self) -> None: pass @@ -71,7 +75,9 @@ async def _fake_shielded_session(): _SESSIONS_USED: list[_FakeSession] = [] -def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None): +def _patch_isolation_layer( + monkeypatch, *, reserve_result, finalize_result=None, finalize_exc=None +): """Wire fake reserve/finalize/release/session helpers.""" _SESSIONS_USED.clear() reserve_calls: list[dict[str, Any]] = [] @@ -91,6 +97,8 @@ def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None) async def _fake_finalize( *, db_session, user_id, request_id, actual_micros, reserved_micros ): + if finalize_exc is not None: + raise finalize_exc finalize_calls.append( { "user_id": user_id, @@ -343,6 +351,125 @@ async def test_premium_uses_estimator_when_no_micros_override(monkeypatch): assert spies["reserve"][0]["reserve_micros"] == 12_345 +@pytest.mark.asyncio +async def test_premium_finalize_failure_propagates_and_releases(monkeypatch): + from app.services.billable_calls import BillingSettlementError, billable_call + + class _FinalizeError(RuntimeError): + pass + + spies = _patch_isolation_layer( + monkeypatch, + reserve_result=_FakeQuotaResult(allowed=True), + finalize_exc=_FinalizeError("db finalize failed"), + ) + user_id = uuid4() + + with pytest.raises(BillingSettlementError): + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + ) as acc: + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=40_000, + call_kind="image_generation", + ) + + assert len(spies["reserve"]) == 1 + assert len(spies["release"]) == 1 + assert spies["record"] == [] + + +@pytest.mark.asyncio +async def test_premium_audit_commit_hang_times_out_after_finalize(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + class _HangingCommitSession(_FakeSession): + async def commit(self) -> None: + await asyncio.sleep(60) + + @contextlib.asynccontextmanager + async def _hanging_session_factory(): + s = _HangingCommitSession() + _SESSIONS_USED.append(s) + yield s + + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + billable_session_factory=_hanging_session_factory, + audit_timeout_seconds=0.01, + ) as acc: + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=40_000, + call_kind="image_generation", + ) + + assert len(spies["reserve"]) == 1 + assert len(spies["finalize"]) == 1 + assert len(spies["record"]) == 1 + assert spies["release"] == [] + + +@pytest.mark.asyncio +async def test_free_audit_failure_is_best_effort(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + + async def _failing_record(_session, **_kwargs): + raise RuntimeError("audit insert failed") + + monkeypatch.setattr( + "app.services.billable_calls.record_token_usage", + _failing_record, + raising=False, + ) + + async with billable_call( + user_id=uuid4(), + search_space_id=42, + billing_tier="free", + base_model="openai/gpt-image-1", + usage_type="image_generation", + audit_timeout_seconds=0.01, + ) as acc: + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=37_000, + call_kind="image_generation", + ) + + assert spies["reserve"] == [] + assert spies["finalize"] == [] + + # --------------------------------------------------------------------------- # Podcast / video-presentation usage_type coverage # --------------------------------------------------------------------------- @@ -387,7 +514,7 @@ async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch): assert len(spies["record"]) == 1 row = spies["record"][0] assert row["usage_type"] == "podcast_generation" - assert row["thread_id"] == 99 + assert row["thread_id"] is None assert row["search_space_id"] == 42 assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"} diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py new file mode 100644 index 000000000..9d5fdb190 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py @@ -0,0 +1,177 @@ +"""Defense-in-depth: image-gen call sites must not let an empty +``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``. + +The bug repro: an OpenRouter image-gen config ships +``api_base=""``. The pre-fix call site in +``image_generation_routes._execute_image_generation`` did +``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which +silently dropped the empty string. LiteLLM then fell back to +``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``) +and OpenRouter's ``image_generation/transformation`` appended +``/chat/completions`` to it → 404 ``Resource not found``. + +This test pins the post-fix behaviour: with an empty ``api_base`` in +the config, the call site MUST set ``api_base`` to OpenRouter's public +URL instead of leaving it unset. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_global_openrouter_image_gen_sets_api_base_when_config_empty(): + """The global-config branch (``config_id < 0``) of + ``_execute_image_generation`` must apply the resolver and pin + ``api_base`` to OpenRouter when the config ships an empty string. + """ + from app.routes import image_generation_routes + + cfg = { + "id": -20_001, + "name": "GPT Image 1 (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-image-1", + "api_key": "sk-or-test", + "api_base": "", # the original bug shape + "api_version": None, + "litellm_params": {}, + } + + captured: dict = {} + + async def fake_aimage_generation(**kwargs): + captured.update(kwargs) + return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={}) + + image_gen = MagicMock() + image_gen.image_generation_config_id = cfg["id"] + image_gen.prompt = "test" + image_gen.n = 1 + image_gen.quality = None + image_gen.size = None + image_gen.style = None + image_gen.response_format = None + image_gen.model = None + + search_space = MagicMock() + search_space.image_generation_config_id = cfg["id"] + session = MagicMock() + + with ( + patch.object( + image_generation_routes, + "_get_global_image_gen_config", + return_value=cfg, + ), + patch.object( + image_generation_routes, + "aimage_generation", + side_effect=fake_aimage_generation, + ), + ): + await image_generation_routes._execute_image_generation( + session=session, image_gen=image_gen, search_space=search_space + ) + + # The whole point of the fix: even with empty ``api_base`` in the + # config, we forward OpenRouter's public URL so the call doesn't + # inherit an Azure endpoint. + assert captured.get("api_base") == "https://openrouter.ai/api/v1" + assert captured["model"] == "openrouter/openai/gpt-image-1" + + +@pytest.mark.asyncio +async def test_generate_image_tool_global_sets_api_base_when_config_empty(): + """Same defense at the agent tool entry point — both surfaces share + the same OpenRouter config payloads.""" + from app.agents.new_chat.tools import generate_image as gi_module + + cfg = { + "id": -20_001, + "name": "GPT Image 1 (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-image-1", + "api_key": "sk-or-test", + "api_base": "", + "api_version": None, + "litellm_params": {}, + } + + captured: dict = {} + + async def fake_aimage_generation(**kwargs): + captured.update(kwargs) + response = MagicMock() + response.model_dump.return_value = { + "data": [{"url": "https://example.com/x.png"}] + } + response._hidden_params = {"model": "openrouter/openai/gpt-image-1"} + return response + + search_space = MagicMock() + search_space.id = 1 + search_space.image_generation_config_id = cfg["id"] + + session_cm = AsyncMock() + session = AsyncMock() + session_cm.__aenter__.return_value = session + + scalars = MagicMock() + scalars.first.return_value = search_space + exec_result = MagicMock() + exec_result.scalars.return_value = scalars + session.execute.return_value = exec_result + session.add = MagicMock() + session.commit = AsyncMock() + session.refresh = AsyncMock() + + # ``refresh(db_image_gen)`` needs to populate ``id`` for token URL fallback. + async def _refresh(obj): + obj.id = 1 + + session.refresh.side_effect = _refresh + + with ( + patch.object(gi_module, "shielded_async_session", return_value=session_cm), + patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg), + patch.object( + gi_module, "aimage_generation", side_effect=fake_aimage_generation + ), + patch.object( + gi_module, "is_image_gen_auto_mode", side_effect=lambda cid: cid == 0 + ), + ): + tool = gi_module.create_generate_image_tool( + search_space_id=1, db_session=MagicMock() + ) + await tool.ainvoke({"prompt": "a cat", "n": 1}) + + assert captured.get("api_base") == "https://openrouter.ai/api/v1" + assert captured["model"] == "openrouter/openai/gpt-image-1" + + +def test_image_gen_router_deployment_sets_api_base_when_config_empty(): + """The Auto-mode router pool must also resolve ``api_base`` when an + OpenRouter config ships an empty string. The deployment dict is fed + straight to ``litellm.Router``, so a missing ``api_base`` would + leak the same way as the direct call sites. + """ + from app.services.image_gen_router_service import ImageGenRouterService + + deployment = ImageGenRouterService._config_to_deployment( + { + "model_name": "openai/gpt-image-1", + "provider": "OPENROUTER", + "api_key": "sk-or-test", + "api_base": "", + } + ) + assert deployment is not None + assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1" + assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-image-1" diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index b635b4fe8..88fcf2db3 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -265,6 +265,10 @@ def test_generate_image_gen_configs_filters_by_image_output(): assert c["billing_tier"] in {"free", "premium"} assert c["provider"] == "OPENROUTER" assert c[_OPENROUTER_DYNAMIC_MARKER] is True + # Defense-in-depth: emit the OpenRouter base URL at source so a + # downstream call site that forgets ``resolve_api_base`` still + # doesn't 404 against an inherited Azure endpoint. + assert c["api_base"] == "https://openrouter.ai/api/v1" def test_generate_image_gen_configs_assigns_image_id_offset(): @@ -342,6 +346,10 @@ def test_generate_vision_llm_configs_filters_by_image_input_text_output(): assert cfg["input_cost_per_token"] == pytest.approx(5e-6) assert cfg["output_cost_per_token"] == pytest.approx(15e-6) assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True + # Defense-in-depth: emit the OpenRouter base URL at source so a + # downstream call site that forgets ``resolve_api_base`` still + # doesn't inherit an Azure endpoint. + assert cfg["api_base"] == "https://openrouter.ai/api/v1" def test_generate_vision_llm_configs_drops_chat_only_filters(): diff --git a/surfsense_backend/tests/unit/services/test_provider_api_base.py b/surfsense_backend/tests/unit/services/test_provider_api_base.py new file mode 100644 index 000000000..12cd0a3d5 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_provider_api_base.py @@ -0,0 +1,107 @@ +"""Unit tests for the shared ``api_base`` resolver. + +The cascade exists so vision and image-gen call sites can't silently +inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``) +when an OpenRouter / Groq / etc. config ships an empty string. See +``provider_api_base`` module docstring for the original repro +(OpenRouter image-gen 404-ing against an Azure endpoint). +""" + +from __future__ import annotations + +import pytest + +from app.services.provider_api_base import ( + PROVIDER_DEFAULT_API_BASE, + PROVIDER_KEY_DEFAULT_API_BASE, + resolve_api_base, +) + +pytestmark = pytest.mark.unit + + +def test_config_value_wins_over_defaults(): + """A non-empty config value is always returned verbatim, even when the + provider has a default — the operator gets the last word.""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base="https://my-openrouter-mirror.example.com/v1", + ) + assert result == "https://my-openrouter-mirror.example.com/v1" + + +def test_provider_key_default_when_config_missing(): + """``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own + base URL — the provider-key map must take precedence over the prefix + map so DeepSeek requests don't go to OpenAI.""" + result = resolve_api_base( + provider="DEEPSEEK", + provider_prefix="openai", + config_api_base=None, + ) + assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] + + +def test_provider_prefix_default_when_no_key_default(): + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base=None, + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_unknown_provider_returns_none(): + """When neither map matches we return ``None`` so the caller can let + LiteLLM apply its own provider-integration default (Azure deployment + URL, custom-provider URL, etc.).""" + result = resolve_api_base( + provider="SOMETHING_NEW", + provider_prefix="something_new", + config_api_base=None, + ) + assert result is None + + +def test_empty_string_config_treated_as_missing(): + """The original bug: OpenRouter dynamic configs ship ``api_base=""`` + and downstream call sites use ``if cfg.get("api_base"):`` — empty + strings are falsy in Python but the cascade has to step in anyway.""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base="", + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_whitespace_only_config_treated_as_missing(): + """A config value of ``" "`` is a configuration mistake — treat it + as missing instead of forwarding whitespace to LiteLLM (which would + almost certainly 404).""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base=" ", + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_provider_case_insensitive(): + """Some call sites pass the provider lowercase (DB enum value), others + uppercase (YAML key). Both must resolve.""" + upper = resolve_api_base( + provider="DEEPSEEK", provider_prefix="openai", config_api_base=None + ) + lower = resolve_api_base( + provider="deepseek", provider_prefix="openai", config_api_base=None + ) + assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] + + +def test_all_inputs_none_returns_none(): + assert ( + resolve_api_base(provider=None, provider_prefix=None, config_api_base=None) + is None + ) diff --git a/surfsense_backend/tests/unit/services/test_provider_capabilities.py b/surfsense_backend/tests/unit/services/test_provider_capabilities.py new file mode 100644 index 000000000..aac88977f --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_provider_capabilities.py @@ -0,0 +1,244 @@ +"""Unit tests for the shared chat-image capability resolver. + +Two resolvers, two intents: + +- ``derive_supports_image_input`` — best-effort True for the catalog and + selector. Default-allow on unknown / unmapped models. The streaming + task safety net never sees this value directly. + +- ``is_known_text_only_chat_model`` — strict opt-out for the safety net. + Returns True only when LiteLLM's model map *explicitly* sets + ``supports_vision=False``. Anything else (missing key, exception, + True) returns False so the request flows through to the provider. +""" + +from __future__ import annotations + +import pytest + +from app.services.provider_capabilities import ( + derive_supports_image_input, + is_known_text_only_chat_model, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# derive_supports_image_input — OpenRouter modalities path (authoritative) +# --------------------------------------------------------------------------- + + +def test_or_modalities_with_image_returns_true(): + assert ( + derive_supports_image_input( + provider="OPENROUTER", + model_name="openai/gpt-4o", + openrouter_input_modalities=["text", "image"], + ) + is True + ) + + +def test_or_modalities_text_only_returns_false(): + assert ( + derive_supports_image_input( + provider="OPENROUTER", + model_name="deepseek/deepseek-v3.2-exp", + openrouter_input_modalities=["text"], + ) + is False + ) + + +def test_or_modalities_empty_list_returns_false(): + """OR explicitly publishing an empty modality list is a definitive + 'no inputs at all' signal — treat as False rather than falling back + to LiteLLM.""" + assert ( + derive_supports_image_input( + provider="OPENROUTER", + model_name="weird/empty-modalities", + openrouter_input_modalities=[], + ) + is False + ) + + +def test_or_modalities_none_falls_through_to_litellm(): + """``None`` (missing key) is *not* a definitive signal — fall through + to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map.""" + assert ( + derive_supports_image_input( + provider="OPENAI", + model_name="gpt-4o", + openrouter_input_modalities=None, + ) + is True + ) + + +# --------------------------------------------------------------------------- +# derive_supports_image_input — LiteLLM model-map path +# --------------------------------------------------------------------------- + + +def test_litellm_known_vision_model_returns_true(): + assert ( + derive_supports_image_input( + provider="OPENAI", + model_name="gpt-4o", + ) + is True + ) + + +def test_litellm_base_model_wins_over_model_name(): + """Azure-style entries pass model_name=deployment_id and put the + canonical sku in litellm_params.base_model. The resolver must + consult base_model first or the deployment id (which LiteLLM + doesn't know) would shadow the real capability.""" + assert ( + derive_supports_image_input( + provider="AZURE_OPENAI", + model_name="my-azure-deployment-id", + base_model="gpt-4o", + ) + is True + ) + + +def test_litellm_unknown_model_default_allows(): + """Default-allow on unknown — the safety net is the actual block.""" + assert ( + derive_supports_image_input( + provider="CUSTOM", + model_name="brand-new-model-x9-unmapped", + custom_provider="brand_new_proxy", + ) + is True + ) + + +def test_litellm_known_text_only_returns_false(): + """A model that LiteLLM explicitly knows is text-only resolves to + False even via the catalog resolver. ``deepseek-chat`` (the + DeepSeek-V3 chat sku) is in the map without supports_vision and + LiteLLM's `supports_vision` returns False.""" + # Sanity: confirm the helper's negative path. We use a small model + # known not to support vision per the map. + result = derive_supports_image_input( + provider="DEEPSEEK", + model_name="deepseek-chat", + ) + # We accept either False (LiteLLM said explicit no) or True + # (default-allow if the entry isn't mapped on this version) — the + # invariant is that the resolver never *raises* on a known-text-only + # provider/model. The behaviour-binding assertion lives in + # ``test_is_known_text_only_chat_model_explicit_false`` below. + assert isinstance(result, bool) + + +# --------------------------------------------------------------------------- +# is_known_text_only_chat_model — strict opt-out semantics +# --------------------------------------------------------------------------- + + +def test_is_known_text_only_returns_false_for_vision_model(): + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="gpt-4o", + ) + is False + ) + + +def test_is_known_text_only_returns_false_for_unknown_model(): + """Strict opt-out: missing from the map ≠ text-only. The safety net + must NOT fire for an unmapped model — that's the regression we're + fixing.""" + assert ( + is_known_text_only_chat_model( + provider="CUSTOM", + model_name="brand-new-model-x9-unmapped", + custom_provider="brand_new_proxy", + ) + is False + ) + + +def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch): + """LiteLLM's ``get_model_info`` raises freely on parse errors. The + helper swallows the exception and returns False so the safety net + doesn't fire on a transient lookup failure.""" + import app.services.provider_capabilities as pc + + def _raise(**_kwargs): + raise ValueError("intentional test failure") + + monkeypatch.setattr(pc.litellm, "get_model_info", _raise) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="gpt-4o", + ) + is False + ) + + +def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch): + """Stub LiteLLM's ``get_model_info`` to return an explicit False so + we exercise the opt-out path deterministically. Using a stub keeps + the test stable across LiteLLM map updates.""" + import app.services.provider_capabilities as pc + + def _info(**_kwargs): + return {"supports_vision": False, "max_input_tokens": 8192} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="any-model", + ) + is True + ) + + +def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch): + import app.services.provider_capabilities as pc + + def _info(**_kwargs): + return {"supports_vision": True} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="any-model", + ) + is False + ) + + +def test_is_known_text_only_returns_false_on_missing_key(monkeypatch): + """A model entry without ``supports_vision`` at all is treated as + 'unknown' — strict opt-out means False.""" + import app.services.provider_capabilities as pc + + def _info(**_kwargs): + return {"max_input_tokens": 8192} # no supports_vision + + monkeypatch.setattr(pc.litellm, "get_model_info", _info) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="any-model", + ) + is False + ) diff --git a/surfsense_backend/tests/unit/services/test_supports_image_input.py b/surfsense_backend/tests/unit/services/test_supports_image_input.py new file mode 100644 index 000000000..71fdee1c7 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_supports_image_input.py @@ -0,0 +1,281 @@ +"""Unit tests for the chat-catalog ``supports_image_input`` capability flag. + +Capability is sourced from two places, in order of preference: + +1. ``architecture.input_modalities`` for dynamic OpenRouter chat configs + (authoritative — OpenRouter publishes per-model modalities directly). +2. LiteLLM's authoritative model map (``litellm.supports_vision``) for + YAML / BYOK configs that don't carry an explicit operator override. + +The catalog default is *True* (conservative-allow): an unknown / unmapped +model is not pre-judged. The streaming-task safety net +(``is_known_text_only_chat_model``) is the only place a False actually +blocks a request — and it requires LiteLLM to *explicitly* mark the model +as text-only. +""" + +from __future__ import annotations + +import pytest + +from app.services.openrouter_integration_service import ( + _OPENROUTER_DYNAMIC_MARKER, + _generate_configs, + _supports_image_input, +) + +pytestmark = pytest.mark.unit + + +_SETTINGS_BASE: dict = { + "api_key": "sk-or-test", + "id_offset": -10_000, + "rpm": 200, + "tpm": 1_000_000, + "free_rpm": 20, + "free_tpm": 100_000, + "anonymous_enabled_paid": False, + "anonymous_enabled_free": True, + "quota_reserve_tokens": 4000, +} + + +# --------------------------------------------------------------------------- +# _supports_image_input helper (OpenRouter modalities) +# --------------------------------------------------------------------------- + + +def test_supports_image_input_true_for_multimodal(): + assert ( + _supports_image_input( + { + "id": "openai/gpt-4o", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + } + ) + is True + ) + + +def test_supports_image_input_false_for_text_only(): + """The exact failure mode the safety net guards against — DeepSeek V3 + is a text-in/text-out model and would 404 if forwarded image_url.""" + assert ( + _supports_image_input( + { + "id": "deepseek/deepseek-v3.2-exp", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["text"], + }, + } + ) + is False + ) + + +def test_supports_image_input_false_when_modalities_missing(): + """Defensive: missing architecture is treated as text-only at the + OpenRouter helper level. The wider catalog resolver + (`derive_supports_image_input`) only consults modalities when they + are non-empty, otherwise it falls back to LiteLLM.""" + assert _supports_image_input({"id": "weird/model"}) is False + assert _supports_image_input({"id": "weird/model", "architecture": {}}) is False + assert ( + _supports_image_input( + {"id": "weird/model", "architecture": {"input_modalities": None}} + ) + is False + ) + + +# --------------------------------------------------------------------------- +# _generate_configs threads the flag onto every emitted chat config +# --------------------------------------------------------------------------- + + +def test_generate_configs_emits_supports_image_input(): + raw = [ + { + "id": "openai/gpt-4o", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": {"prompt": "0.000005", "completion": "0.000015"}, + }, + { + "id": "deepseek/deepseek-v3.2-exp", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["text"], + }, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": {"prompt": "0.000003", "completion": "0.000015"}, + }, + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + by_model = {c["model_name"]: c for c in cfgs} + + gpt = by_model["openai/gpt-4o"] + assert gpt["supports_image_input"] is True + assert gpt[_OPENROUTER_DYNAMIC_MARKER] is True + + deepseek = by_model["deepseek/deepseek-v3.2-exp"] + assert deepseek["supports_image_input"] is False + assert deepseek[_OPENROUTER_DYNAMIC_MARKER] is True + + +# --------------------------------------------------------------------------- +# YAML loader: defer to derive_supports_image_input on unannotated entries +# --------------------------------------------------------------------------- + + +def test_yaml_loader_resolves_unannotated_vision_model_to_true(tmp_path, monkeypatch): + """The regression case: an Azure GPT-5.x YAML entry without a + ``supports_image_input`` override should resolve to True via LiteLLM's + model map (which says ``supports_vision: true``). Previously this + defaulted to False, blocking every image turn for vision-capable + YAML configs.""" + yaml_dir = tmp_path / "app" / "config" + yaml_dir.mkdir(parents=True) + (yaml_dir / "global_llm_config.yaml").write_text( + """ +global_llm_configs: + - id: -2 + name: Azure GPT-4o + provider: AZURE_OPENAI + model_name: gpt-4o + api_key: sk-test +""", + encoding="utf-8", + ) + + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + configs = config_module.load_global_llm_configs() + assert len(configs) == 1 + assert configs[0]["supports_image_input"] is True + + +def test_yaml_loader_respects_explicit_supports_image_input(tmp_path, monkeypatch): + yaml_dir = tmp_path / "app" / "config" + yaml_dir.mkdir(parents=True) + (yaml_dir / "global_llm_config.yaml").write_text( + """ +global_llm_configs: + - id: -1 + name: GPT-4o + provider: OPENAI + model_name: gpt-4o + api_key: sk-test + supports_image_input: false +""", + encoding="utf-8", + ) + + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + configs = config_module.load_global_llm_configs() + assert len(configs) == 1 + # Operator override always wins, even against LiteLLM's True. + assert configs[0]["supports_image_input"] is False + + +def test_yaml_loader_unknown_model_default_allows(tmp_path, monkeypatch): + """Unknown / unmapped model in YAML: default-allow. The streaming + safety net (which requires an explicit-False from LiteLLM) is the + only place a real block happens, so we don't lock the user out of + a freshly added third-party entry the catalog can't introspect.""" + yaml_dir = tmp_path / "app" / "config" + yaml_dir.mkdir(parents=True) + (yaml_dir / "global_llm_config.yaml").write_text( + """ +global_llm_configs: + - id: -1 + name: Some Brand New Model + provider: CUSTOM + custom_provider: brand_new_proxy + model_name: brand-new-model-x9 + api_key: sk-test +""", + encoding="utf-8", + ) + + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + configs = config_module.load_global_llm_configs() + assert len(configs) == 1 + assert configs[0]["supports_image_input"] is True + + +# --------------------------------------------------------------------------- +# AgentConfig threads the flag through both YAML and Auto / BYOK +# --------------------------------------------------------------------------- + + +def test_agent_config_from_yaml_explicit_overrides_resolver(): + from app.agents.new_chat.llm_config import AgentConfig + + cfg_text_only = AgentConfig.from_yaml_config( + { + "id": -1, + "name": "Text Only Override", + "provider": "openai", + "model_name": "gpt-4o", # Capable per LiteLLM, but operator says no. + "api_key": "sk-test", + "supports_image_input": False, + } + ) + cfg_explicit_vision = AgentConfig.from_yaml_config( + { + "id": -2, + "name": "GPT-4o", + "provider": "openai", + "model_name": "gpt-4o", + "api_key": "sk-test", + "supports_image_input": True, + } + ) + assert cfg_text_only.supports_image_input is False + assert cfg_explicit_vision.supports_image_input is True + + +def test_agent_config_from_yaml_unannotated_uses_resolver(): + """Without an explicit YAML key, AgentConfig defers to the catalog + resolver — for ``gpt-4o`` LiteLLM's map says supports_vision=True.""" + from app.agents.new_chat.llm_config import AgentConfig + + cfg = AgentConfig.from_yaml_config( + { + "id": -1, + "name": "GPT-4o (no override)", + "provider": "openai", + "model_name": "gpt-4o", + "api_key": "sk-test", + } + ) + assert cfg.supports_image_input is True + + +def test_agent_config_auto_mode_supports_image_input(): + """Auto routes across the pool. We optimistically allow image input + so users can keep their selection on Auto with a vision-capable + deployment somewhere in the pool. The router's own `allowed_fails` + handles non-vision deployments via fallback.""" + from app.agents.new_chat.llm_config import AgentConfig + + auto = AgentConfig.from_auto_mode() + assert auto.supports_image_input is True diff --git a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py new file mode 100644 index 000000000..b8ba9d80c --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py @@ -0,0 +1,89 @@ +"""Defense-in-depth: vision-LLM resolution must not leak ``api_base`` +defaults from ``litellm.api_base`` either. + +Vision shares the same shape as image-gen — global YAML / OpenRouter +dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm`` +call sites would silently drop the empty string and inherit +``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on +construction so we test the kwargs we hand to it instead. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_get_vision_llm_global_openrouter_sets_api_base(): + """Global negative-ID branch: an OpenRouter vision config with + ``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with + ``api_base="https://openrouter.ai/api/v1"`` — never an empty string, + never silently absent.""" + from app.services import llm_service + + cfg = { + "id": -30_001, + "name": "GPT-4o Vision (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "api_key": "sk-or-test", + "api_base": "", + "api_version": None, + "litellm_params": {}, + "billing_tier": "free", + } + + search_space = MagicMock() + search_space.id = 1 + search_space.user_id = "user-x" + search_space.vision_llm_config_id = cfg["id"] + + session = AsyncMock() + scalars = MagicMock() + scalars.first.return_value = search_space + result = MagicMock() + result.scalars.return_value = scalars + session.execute.return_value = result + + captured: dict = {} + + class FakeSanitized: + def __init__(self, **kwargs): + captured.update(kwargs) + + with ( + patch( + "app.services.vision_llm_router_service.get_global_vision_llm_config", + return_value=cfg, + ), + patch( + "app.agents.new_chat.llm_config.SanitizedChatLiteLLM", + new=FakeSanitized, + ), + ): + await llm_service.get_vision_llm(session=session, search_space_id=1) + + assert captured.get("api_base") == "https://openrouter.ai/api/v1" + assert captured["model"] == "openrouter/openai/gpt-4o" + + +def test_vision_router_deployment_sets_api_base_when_config_empty(): + """Auto-mode vision router: deployments are fed to ``litellm.Router``, + so the resolver has to apply at deployment construction time too.""" + from app.services.vision_llm_router_service import VisionLLMRouterService + + deployment = VisionLLMRouterService._config_to_deployment( + { + "model_name": "openai/gpt-4o", + "provider": "OPENROUTER", + "api_key": "sk-or-test", + "api_base": "", + } + ) + assert deployment is not None + assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1" + assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o" diff --git a/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py new file mode 100644 index 000000000..a5bb3f58a --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py @@ -0,0 +1,318 @@ +"""Regression tests for ``run_async_celery_task``. + +These tests pin down the production bug observed on 2026-05-02 where +the video-presentation Celery task hung at ``[billable_call] finalize`` +because the shared ``app.db.engine`` had pooled asyncpg connections +bound to a *previous* task's now-closed event loop. Reusing such a +connection on a fresh loop crashes inside ``pool_pre_ping`` with:: + + AttributeError: 'NoneType' object has no attribute 'send' + +(the proactor is None because the loop is gone) and can hang forever +inside the asyncpg ``Connection._cancel`` cleanup coroutine. + +The fix is ``run_async_celery_task``: a small helper that runs every +async celery task body inside a fresh event loop and disposes the +shared engine pool both before (defends against a previous task that +crashed) and after (releases connections we opened on this loop). + +Tests here exercise the helper with a stub engine that records +``dispose()`` calls and panics if a coroutine produced by one loop is +awaited on another — mirroring the real asyncpg behaviour. +""" + +from __future__ import annotations + +import asyncio +import gc +import sys +from collections.abc import Iterator +from contextlib import contextmanager +from unittest.mock import patch + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Stub engine that emulates the asyncpg-on-stale-loop crash +# --------------------------------------------------------------------------- + + +class _StaleLoopEngine: + """Tiny stand-in for ``app.db.engine`` that tracks dispose() calls. + + ``dispose()`` is async (matches ``AsyncEngine.dispose``) and records + the running event loop id so tests can assert it ran on *each* + fresh loop. + """ + + def __init__(self) -> None: + self.dispose_loop_ids: list[int] = [] + + async def dispose(self) -> None: + loop = asyncio.get_running_loop() + self.dispose_loop_ids.append(id(loop)) + + +@contextmanager +def _patch_shared_engine(stub: _StaleLoopEngine) -> Iterator[None]: + """Patch ``from app.db import engine as shared_engine`` lookup. + + The helper imports lazily inside the function body, so we have to + patch the attribute on the already-loaded ``app.db`` module. + """ + import app.db as app_db + + original = getattr(app_db, "engine", None) + app_db.engine = stub # type: ignore[attr-defined] + try: + yield + finally: + if original is None: + with pytest.raises(AttributeError): + _ = app_db.engine + else: + app_db.engine = original # type: ignore[attr-defined] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_runner_returns_value_and_disposes_engine_around_call() -> None: + """Happy path: the coroutine result is returned, and the shared + engine is disposed both before and after the task body runs. + """ + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + async def _body() -> str: + # Engine should already have been disposed once before we run. + assert len(stub.dispose_loop_ids) == 1 + return "ok" + + with _patch_shared_engine(stub): + result = run_async_celery_task(_body) + + assert result == "ok" + # Once before the body, once after (in finally). + assert len(stub.dispose_loop_ids) == 2 + # Both disposes ran on the SAME (fresh) loop the task body used. + assert stub.dispose_loop_ids[0] == stub.dispose_loop_ids[1] + + +def test_runner_creates_fresh_loop_per_invocation() -> None: + """Each call must spin its own loop. Without this guarantee a + previous task's loop would be reused and the asyncpg-stale-loop + crash would never be avoided. + """ + import app.tasks.celery_tasks as celery_tasks_pkg + + stub = _StaleLoopEngine() + new_loop_calls = 0 + closed_loops: list[bool] = [] + + real_new_event_loop = asyncio.new_event_loop + + def _counting_new_loop() -> asyncio.AbstractEventLoop: + nonlocal new_loop_calls + new_loop_calls += 1 + loop = real_new_event_loop() + # Hook close() so we can verify each loop was closed properly + # before the next one was created. + original_close = loop.close + + def _tracked_close() -> None: + closed_loops.append(True) + original_close() + + loop.close = _tracked_close # type: ignore[method-assign] + return loop + + async def _body() -> None: + # Loop is alive and current at body execution time. + running = asyncio.get_running_loop() + assert not running.is_closed() + + with ( + _patch_shared_engine(stub), + patch.object(asyncio, "new_event_loop", _counting_new_loop), + ): + for _ in range(3): + celery_tasks_pkg.run_async_celery_task(_body) + + assert new_loop_calls == 3 + assert closed_loops == [True, True, True] + # Each invocation disposed twice (before + after). + assert len(stub.dispose_loop_ids) == 6 + + +def test_runner_disposes_engine_even_when_body_raises() -> None: + """Cleanup MUST run on the failure path too — otherwise stale + connections leak into the next task and cause the original hang. + """ + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + class _BoomError(RuntimeError): + pass + + async def _body() -> None: + raise _BoomError("kaboom") + + with _patch_shared_engine(stub), pytest.raises(_BoomError): + run_async_celery_task(_body) + + assert len(stub.dispose_loop_ids) == 2 # before + after still ran + + +def test_runner_swallows_dispose_errors() -> None: + """A flaky engine.dispose() must NEVER take down a celery task. + + Production scenario: the very first dispose (before the body runs) + might hit a partially-initialised engine; the helper logs and + moves on. The task body still runs; the result is still returned. + """ + from app.tasks.celery_tasks import run_async_celery_task + + class _AngryEngine: + def __init__(self) -> None: + self.calls = 0 + + async def dispose(self) -> None: + self.calls += 1 + raise RuntimeError("dispose() blew up") + + stub = _AngryEngine() + + async def _body() -> int: + return 42 + + with _patch_shared_engine(stub): + assert run_async_celery_task(_body) == 42 + + assert stub.calls == 2 # before + after both attempted + + +def test_runner_propagates_value_from_async_body() -> None: + """Sanity: pass-through of any pickleable celery return value.""" + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + async def _body() -> dict[str, object]: + return {"status": "ready", "video_presentation_id": 19} + + with _patch_shared_engine(stub): + out = run_async_celery_task(_body) + + assert out == {"status": "ready", "video_presentation_id": 19} + + +def test_video_presentation_task_uses_runner_helper() -> None: + """Defence-in-depth: confirm the celery task module imports + ``run_async_celery_task``. If a future refactor inlines a + ``loop = asyncio.new_event_loop(); ... loop.close()`` block again, + the original hang will return. + """ + # The module's task body should not contain a manual new_event_loop + # call — that's exactly what the helper exists to centralise. + import inspect + + from app.tasks.celery_tasks import video_presentation_tasks + + src = inspect.getsource(video_presentation_tasks) + assert "run_async_celery_task" in src, ( + "video_presentation_tasks.py must use run_async_celery_task; " + "manual asyncio.new_event_loop() in a celery task hangs on the " + "shared SQLAlchemy pool when reused across tasks." + ) + assert "asyncio.new_event_loop" not in src, ( + "video_presentation_tasks.py contains a raw asyncio.new_event_loop " + "call — route every async task through run_async_celery_task to " + "avoid the stale-pool hang." + ) + + +def test_podcast_task_uses_runner_helper() -> None: + """Symmetric assertion for the podcast task — same root cause, same + fix, same regression risk. + """ + import inspect + + from app.tasks.celery_tasks import podcast_tasks + + src = inspect.getsource(podcast_tasks) + assert "run_async_celery_task" in src + assert "asyncio.new_event_loop" not in src + + +def test_runner_runs_shutdown_asyncgens_before_close() -> None: + """If the task body created any async generators that didn't get + fully iterated, we must still call ``loop.shutdown_asyncgens()`` + before closing — otherwise we leak event-loop bound resources + that re-emerge as ``RuntimeError: Event loop is closed`` later. + """ + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + async def _agen(): + try: + yield 1 + yield 2 + finally: + pass + + async def _body() -> None: + # Iterate the agen partially, then leave it dangling — exactly + # the situation shutdown_asyncgens() is designed to clean up. + async for v in _agen(): + if v == 1: + break + + with _patch_shared_engine(stub): + run_async_celery_task(_body) + + # By the time the helper returns, garbage collection + shutdown_asyncgens + # should have ensured no live async-gen references remain. We don't + # assert agen.closed directly (it depends on GC ordering); the real + # contract is "no warnings, no event-loop-closed errors". A successful + # second invocation proves the loop was cleaned up properly. + with _patch_shared_engine(stub): + run_async_celery_task(_body) + + # Force a GC pass to surface any 'coroutine was never awaited' + # warnings that would indicate the cleanup is broken. + gc.collect() + + +def test_runner_uses_proactor_loop_on_windows() -> None: + """On Windows the celery worker preselects a Proactor policy so + subprocess (ffmpeg) calls work. The helper must not silently fall + back to a Selector loop and re-break video/podcast generation. + """ + if not sys.platform.startswith("win"): + pytest.skip("Windows-specific event-loop policy assertion") + + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + # Mirror the policy set at the top of every Windows celery task. + asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) + + observed: list[str] = [] + + async def _body() -> None: + observed.append(type(asyncio.get_running_loop()).__name__) + + with _patch_shared_engine(stub): + run_async_celery_task(_body) + + assert observed == ["ProactorEventLoop"] diff --git a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py index 38d6ba2ca..699297df1 100644 --- a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py +++ b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py @@ -113,6 +113,19 @@ async def _denying_billable_call(**kwargs): yield SimpleNamespace() # pragma: no cover — for grammar only +@contextlib.asynccontextmanager +async def _settlement_failing_billable_call(**kwargs): + from app.services.billable_calls import BillingSettlementError + + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + raise BillingSettlementError( + usage_type=kwargs.get("usage_type", "?"), + user_id=kwargs["user_id"], + cause=RuntimeError("finalize failed"), + ) + + # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- @@ -187,8 +200,15 @@ async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeyp call["quota_reserve_micros_override"] == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS ) - assert call["thread_id"] == 99 - assert call["call_details"] == {"podcast_id": 7, "title": "Test Podcast"} + # Background artifact audit rows intentionally omit the TokenUsage.thread_id + # FK to avoid coupling Celery audit commits to an active chat transaction. + assert "thread_id" not in call + assert call["call_details"] == { + "podcast_id": 7, + "title": "Test Podcast", + "thread_id": 99, + } + assert callable(call["billable_session_factory"]) @pytest.mark.asyncio @@ -279,6 +299,49 @@ async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypat assert graph_invoked == [] # Graph never ran on denied reservation. +@pytest.mark.asyncio +async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch): + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=10) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr( + podcast_tasks, "billable_call", _settlement_failing_billable_call + ) + + async def _fake_graph_invoke(state, config): + return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=10, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 10, + "reason": "billing_settlement_failed", + } + assert podcast.status == PodcastStatus.FAILED + + @pytest.mark.asyncio async def test_resolver_failure_marks_podcast_failed(monkeypatch): """If the resolver raises (e.g. search-space deleted), the task fails diff --git a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py new file mode 100644 index 000000000..792d059b0 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py @@ -0,0 +1,119 @@ +"""Predicate-level test for the chat streaming safety net. + +The safety net in ``stream_new_chat`` rejects an image turn early with +a friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error when the +selected model is *known* to be text-only. The earlier round of this +work used a strict opt-in flag (``supports_image_input`` defaulting to +False on every YAML entry) which blocked vision-capable Azure GPT-5.x +deployments — this is the regression we're fixing. + +The new predicate is :func:`is_known_text_only_chat_model`, which +returns True only when LiteLLM's authoritative model map *explicitly* +sets ``supports_vision=False``. Anything else (vision True, missing +key, exception) returns False so the request flows through to the +provider. + +We exercise the predicate directly here rather than driving the full +``stream_new_chat`` generator — covering the gate in isolation keeps +the test focused on the regression while the generator's wider behavior +is exercised by the integration suite. +""" + +from __future__ import annotations + +import pytest + +from app.services.provider_capabilities import is_known_text_only_chat_model + +pytestmark = pytest.mark.unit + + +def test_safety_net_does_not_fire_for_azure_gpt_4o(): + """Regression: ``azure/gpt-4o`` (and the GPT-5.x variants) is + vision-capable per LiteLLM's model map. The previous round's + blanket-False default blocked it; the new predicate must NOT mark + it text-only.""" + assert ( + is_known_text_only_chat_model( + provider="AZURE_OPENAI", + model_name="my-azure-deployment", + base_model="gpt-4o", + ) + is False + ) + + +def test_safety_net_does_not_fire_for_unknown_model(): + """Default-pass on unknown — the safety net only blocks definitive + text-only confirmations. A freshly added third-party model that + LiteLLM doesn't know about must flow through to the provider.""" + assert ( + is_known_text_only_chat_model( + provider="CUSTOM", + custom_provider="brand_new_proxy", + model_name="brand-new-model-x9", + ) + is False + ) + + +def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch): + """Transient ``litellm.get_model_info`` exception ≠ block. The + helper swallows the error and treats it as 'unknown' → False.""" + import app.services.provider_capabilities as pc + + def _raise(**_kwargs): + raise RuntimeError("intentional test failure") + + monkeypatch.setattr(pc.litellm, "get_model_info", _raise) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="gpt-4o", + ) + is False + ) + + +def test_safety_net_fires_only_on_explicit_false(monkeypatch): + """Stub LiteLLM to assert the only path that returns True is the + explicit ``supports_vision=False`` case. Anything else (True, + None, missing key) returns False from the predicate.""" + import app.services.provider_capabilities as pc + + def _info_explicit_false(**_kwargs): + return {"supports_vision": False, "max_input_tokens": 8192} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false) + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="text-only-stub", + ) + is True + ) + + def _info_true(**_kwargs): + return {"supports_vision": True} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info_true) + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="vision-stub", + ) + is False + ) + + def _info_missing(**_kwargs): + return {"max_input_tokens": 8192} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing) + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="missing-key-stub", + ) + is False + ) diff --git a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py index 671f57ae4..423b64ddb 100644 --- a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py +++ b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py @@ -105,6 +105,19 @@ async def _denying_billable_call(**kwargs): yield SimpleNamespace() # pragma: no cover +@contextlib.asynccontextmanager +async def _settlement_failing_billable_call(**kwargs): + from app.services.billable_calls import BillingSettlementError + + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + raise BillingSettlementError( + usage_type=kwargs.get("usage_type", "?"), + user_id=kwargs["user_id"], + cause=RuntimeError("finalize failed"), + ) + + # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- @@ -176,11 +189,15 @@ async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeyp call["quota_reserve_micros_override"] == app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS ) - assert call["thread_id"] == 99 + # Background artifact audit rows intentionally omit the TokenUsage.thread_id + # FK to avoid coupling Celery audit commits to an active chat transaction. + assert "thread_id" not in call assert call["call_details"] == { "video_presentation_id": 11, "title": "Test Presentation", + "thread_id": 99, } + assert callable(call["billable_session_factory"]) @pytest.mark.asyncio @@ -280,6 +297,57 @@ async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch assert graph_invoked == [] +@pytest.mark.asyncio +async def test_billing_settlement_failure_marks_video_failed(monkeypatch): + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=14) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr( + video_presentation_tasks, + "billable_call", + _settlement_failing_billable_call, + ) + + async def _fake_graph_invoke(state, config): + return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=14, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "video_presentation_id": 14, + "reason": "billing_settlement_failed", + } + assert video.status == VideoPresentationStatus.FAILED + + @pytest.mark.asyncio async def test_resolver_failure_marks_video_failed(monkeypatch): from app.db import VideoPresentationStatus diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index ffb0e4dc8..3b9d9a526 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -477,9 +477,7 @@ const MessageInfoDropdown: FC = () => { </span> <span className="text-xs text-muted-foreground"> {counts.total_tokens.toLocaleString()} tokens - {costMicros && costMicros > 0 - ? ` · ${formatTurnCost(costMicros)}` - : ""} + {costMicros && costMicros > 0 ? ` · ${formatTurnCost(costMicros)}` : ""} </span> </ActionBarMorePrimitive.Item> ); diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 1a0f8c5ba..44f3feb7a 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -19,6 +19,7 @@ import { import type React from "react"; import { Fragment, useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; +import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; import { globalImageGenConfigsAtom, imageGenConfigsAtom, @@ -461,6 +462,18 @@ export function ModelSelector({ const { data: visionUserConfigs, isLoading: visionUserLoading } = useAtomValue(visionLLMConfigsAtom); + // Pending image attachments on the composer. Used to surface an + // amber "No image" hint on chat models the catalog reports as + // non-vision (`supports_image_input=false`) when the next message + // will carry an image. The hint is purely advisory: selection, + // focus, and click handling are unaffected. The backend's safety + // net (`is_known_text_only_chat_model`) is the actual block, and + // it only fires when LiteLLM *explicitly* marks a model as + // text-only — so a model that's secretly capable but hasn't been + // annotated will still flow through to the provider. + const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom); + const hasPendingImages = pendingUserImageUrls.length > 0; + const isLoading = llmUserLoading || llmGlobalLoading || @@ -984,6 +997,21 @@ export function ModelSelector({ const isSelected = getSelectedId() === config.id; const isFocused = focusedIndex === index; const hasCitations = "citations_enabled" in config && !!config.citations_enabled; + // Chat-tab only: surface an amber "No image" hint when the + // composer carries images and the catalog reports the model as + // non-vision. This is purely advisory — selection is *not* + // blocked. The backend's narrow safety net + // (`is_known_text_only_chat_model`) is the source of truth for + // rejecting image turns, and it only fires when LiteLLM + // explicitly marks the model as text-only. A model surfaced as + // `supports_image_input=false` here may still be capable in + // practice (unknown / unmapped LiteLLM entry), so we let the + // user pick it and the provider response decide. + const isImageIncompatibleChatModel = + activeTab === "llm" && + hasPendingImages && + "supports_image_input" in config && + (config as Record<string, unknown>).supports_image_input === false; return ( <div @@ -992,6 +1020,11 @@ export function ModelSelector({ role="option" tabIndex={isMobile ? -1 : 0} aria-selected={isSelected} + title={ + isImageIncompatibleChatModel + ? "This model is reported as text-only. You can still pick it; the provider may reject image turns." + : undefined + } onClick={() => handleSelectItem(item)} onKeyDown={ isMobile @@ -1005,9 +1038,8 @@ export function ModelSelector({ } onMouseEnter={() => setFocusedIndex(index)} className={cn( - "group flex items-center gap-2.5 px-3 py-2 rounded-xl cursor-pointer", - "transition-all duration-150 mx-2", - "hover:bg-accent/40", + "group flex items-center gap-2.5 px-3 py-2 rounded-xl", + "transition-all duration-150 mx-2 cursor-pointer hover:bg-accent/40", isSelected && "bg-primary/6 dark:bg-primary/8", isFocused && "bg-accent/50" )} @@ -1053,6 +1085,14 @@ export function ModelSelector({ Free </Badge> ) : null} + {isImageIncompatibleChatModel && ( + <Badge + variant="secondary" + className="text-[9px] px-1 py-0 h-3.5 bg-amber-100 text-amber-700 dark:bg-amber-900/50 dark:text-amber-300 border-0" + > + No image + </Badge> + )} </div> <div className="flex items-center gap-1.5 mt-0.5"> <span className="text-xs text-muted-foreground truncate"> diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx index 127b79167..156ef9134 100644 --- a/surfsense_web/components/pricing/pricing-section.tsx +++ b/surfsense_web/components/pricing/pricing-section.tsx @@ -250,8 +250,8 @@ function PricingFAQ() { Frequently Asked Questions </h2> <p className="mx-auto mt-4 max-w-2xl text-lg text-muted-foreground"> - Everything you need to know about SurfSense pages, premium credit, and billing. - Can't find what you need? Reach out at{" "} + Everything you need to know about SurfSense pages, premium credit, and billing. Can't + find what you need? Reach out at{" "} <a href="mailto:rohan@surfsense.com" className="text-blue-500 underline"> rohan@surfsense.com </a> diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx index ced97464e..d4afa698b 100644 --- a/surfsense_web/components/settings/image-model-manager.tsx +++ b/surfsense_web/components/settings/image-model-manager.tsx @@ -22,6 +22,7 @@ import { AlertDialogTitle, } from "@/components/ui/alert-dialog"; import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; +import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Card, CardContent } from "@/components/ui/card"; import { Skeleton } from "@/components/ui/skeleton"; @@ -190,8 +191,7 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { ? "model" : "models"} </span>{" "} - available from your administrator.{" "} - {(() => { + available from your administrator. {(() => { const nonAuto = globalConfigs.filter( (g) => !("is_auto_mode" in g && g.is_auto_mode) ); @@ -214,6 +214,75 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { </Alert> )} + {/* Global Image Models — read-only cards with per-model Free/Premium + badges. Mirrors the badge palette used by the chat role selector + (`llm-role-manager.tsx`) so the meaning is consistent across + every model-configuration surface (chat / image / vision). */} + {!isLoading && + globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( + <div className="space-y-3"> + <h3 className="text-xs md:text-sm font-semibold text-muted-foreground"> + Global Image Models + </h3> + <div className="grid gap-3 grid-cols-1 sm:grid-cols-2 xl:grid-cols-3"> + {globalConfigs + .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) + .map((cfg) => { + const billingTier = + ("billing_tier" in cfg && + typeof (cfg as { billing_tier?: string }).billing_tier === "string" && + (cfg as { billing_tier?: string }).billing_tier) || + "free"; + const isPremium = billingTier === "premium"; + return ( + <Card + key={cfg.id} + className="border-border/60 bg-muted/20 overflow-hidden h-full" + > + <CardContent className="p-4 flex flex-col gap-3 h-full"> + <div className="flex items-center gap-2 min-w-0"> + <div className="shrink-0"> + {getProviderIcon(cfg.provider, { className: "size-4" })} + </div> + <div className="min-w-0 flex-1 flex items-center gap-1.5"> + <h4 className="text-sm font-semibold tracking-tight truncate"> + {cfg.name} + </h4> + {isPremium ? ( + <Badge + variant="secondary" + className="text-[8px] md:text-[9px] shrink-0 bg-purple-100 text-purple-700 dark:bg-purple-900/50 dark:text-purple-300 border-0" + > + Premium + </Badge> + ) : ( + <Badge + variant="secondary" + className="text-[8px] md:text-[9px] shrink-0 bg-emerald-100 text-emerald-700 dark:bg-emerald-900/50 dark:text-emerald-300 border-0" + > + Free + </Badge> + )} + </div> + </div> + {cfg.description && ( + <p className="text-[11px] text-muted-foreground/70 line-clamp-2"> + {cfg.description} + </p> + )} + <div className="flex items-center pt-2 border-t border-border/40 mt-auto"> + <span className="text-[11px] text-muted-foreground/60 truncate"> + {cfg.model_name} + </span> + </div> + </CardContent> + </Card> + ); + })} + </div> + </div> + )} + {/* Loading Skeleton */} {isLoading && ( <div className="space-y-4 md:space-y-6"> diff --git a/surfsense_web/components/settings/more-pages-content.tsx b/surfsense_web/components/settings/more-pages-content.tsx index 8de61b0c7..5635c3314 100644 --- a/surfsense_web/components/settings/more-pages-content.tsx +++ b/surfsense_web/components/settings/more-pages-content.tsx @@ -70,9 +70,7 @@ export function MorePagesContent() { <div className="w-full space-y-5"> <div className="text-center"> <h2 className="text-xl font-bold tracking-tight">Get Free Pages</h2> - <p className="mt-1 text-sm text-muted-foreground"> - Earn bonus pages by completing tasks - </p> + <p className="mt-1 text-sm text-muted-foreground">Earn bonus pages by completing tasks</p> </div> <div className="space-y-2"> diff --git a/surfsense_web/components/settings/vision-model-manager.tsx b/surfsense_web/components/settings/vision-model-manager.tsx index 886d71008..34aa531fd 100644 --- a/surfsense_web/components/settings/vision-model-manager.tsx +++ b/surfsense_web/components/settings/vision-model-manager.tsx @@ -22,6 +22,7 @@ import { AlertDialogTitle, } from "@/components/ui/alert-dialog"; import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; +import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Card, CardContent } from "@/components/ui/card"; import { Skeleton } from "@/components/ui/skeleton"; @@ -191,8 +192,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { ? "model" : "models"} </span>{" "} - available from your administrator.{" "} - {(() => { + available from your administrator. {(() => { const nonAuto = globalConfigs.filter( (g) => !("is_auto_mode" in g && g.is_auto_mode) ); @@ -215,6 +215,75 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { </Alert> )} + {/* Global Vision Models — read-only cards with per-model Free/Premium + badges. Mirrors the badge palette used by the chat role selector + (`llm-role-manager.tsx`) so the meaning is consistent across + every model-configuration surface (chat / image / vision). */} + {!isLoading && + globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( + <div className="space-y-3"> + <h3 className="text-xs md:text-sm font-semibold text-muted-foreground"> + Global Vision Models + </h3> + <div className="grid gap-3 grid-cols-1 sm:grid-cols-2 xl:grid-cols-3"> + {globalConfigs + .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) + .map((cfg) => { + const billingTier = + ("billing_tier" in cfg && + typeof (cfg as { billing_tier?: string }).billing_tier === "string" && + (cfg as { billing_tier?: string }).billing_tier) || + "free"; + const isPremium = billingTier === "premium"; + return ( + <Card + key={cfg.id} + className="border-border/60 bg-muted/20 overflow-hidden h-full" + > + <CardContent className="p-4 flex flex-col gap-3 h-full"> + <div className="flex items-center gap-2 min-w-0"> + <div className="shrink-0"> + {getProviderIcon(cfg.provider, { className: "size-4" })} + </div> + <div className="min-w-0 flex-1 flex items-center gap-1.5"> + <h4 className="text-sm font-semibold tracking-tight truncate"> + {cfg.name} + </h4> + {isPremium ? ( + <Badge + variant="secondary" + className="text-[8px] md:text-[9px] shrink-0 bg-purple-100 text-purple-700 dark:bg-purple-900/50 dark:text-purple-300 border-0" + > + Premium + </Badge> + ) : ( + <Badge + variant="secondary" + className="text-[8px] md:text-[9px] shrink-0 bg-emerald-100 text-emerald-700 dark:bg-emerald-900/50 dark:text-emerald-300 border-0" + > + Free + </Badge> + )} + </div> + </div> + {cfg.description && ( + <p className="text-[11px] text-muted-foreground/70 line-clamp-2"> + {cfg.description} + </p> + )} + <div className="flex items-center pt-2 border-t border-border/40 mt-auto"> + <span className="text-[11px] text-muted-foreground/60 truncate"> + {cfg.model_name} + </span> + </div> + </CardContent> + </Card> + ); + })} + </div> + </div> + )} + {isLoading && ( <div className="space-y-4 md:space-y-6"> <div className="space-y-4"> diff --git a/surfsense_web/components/tool-ui/generate-podcast.tsx b/surfsense_web/components/tool-ui/generate-podcast.tsx index 02f53efad..e8fff2873 100644 --- a/surfsense_web/components/tool-ui/generate-podcast.tsx +++ b/surfsense_web/components/tool-ui/generate-podcast.tsx @@ -416,9 +416,19 @@ export const GeneratePodcastToolUI = ({ return <PodcastErrorState title={title} error={result.error || "Generation failed"} />; } - // Already generating - show simple warning, don't create another poller - // The FIRST tool call will display the podcast when ready - // (new: "generating", legacy: "already_generating") + // Pending/generating rows have a stable podcast_id, so the card can poll + // independently while the chat stream finishes. + if ( + (result.status === "pending" || + result.status === "generating" || + result.status === "processing") && + result.podcast_id + ) { + return <PodcastStatusPoller podcastId={result.podcast_id} title={result.title || title} />; + } + + // Legacy duplicate/no-ID result - show a simple warning, don't create + // another poller. The first tool call will display the podcast when ready. if (result.status === "generating" || result.status === "already_generating") { return ( <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> @@ -432,11 +442,6 @@ export const GeneratePodcastToolUI = ({ ); } - // Pending - poll for completion (new: "pending" with podcast_id) - if (result.status === "pending" && result.podcast_id) { - return <PodcastStatusPoller podcastId={result.podcast_id} title={result.title || title} />; - } - // Ready with podcast_id (new: "ready", legacy: "success") if ((result.status === "ready" || result.status === "success") && result.podcast_id) { return <PodcastPlayer podcastId={result.podcast_id} title={result.title || title} />; diff --git a/surfsense_web/contexts/login-gate.tsx b/surfsense_web/contexts/login-gate.tsx index 790e5c00e..f72cb3a42 100644 --- a/surfsense_web/contexts/login-gate.tsx +++ b/surfsense_web/contexts/login-gate.tsx @@ -44,8 +44,8 @@ export function LoginGateProvider({ children }: { children: ReactNode }) { <DialogHeader> <DialogTitle>Create a free account to {feature}</DialogTitle> <DialogDescription> - Get $5 of premium credit, save chat history, upload documents, use all AI tools, - and connect 30+ integrations. + Get $5 of premium credit, save chat history, upload documents, use all AI tools, and + connect 30+ integrations. </DialogDescription> </DialogHeader> <DialogFooter className="flex flex-col gap-2 sm:flex-row"> diff --git a/surfsense_web/contracts/types/new-llm-config.types.ts b/surfsense_web/contracts/types/new-llm-config.types.ts index 2d6b70eda..b52b98ae4 100644 --- a/surfsense_web/contracts/types/new-llm-config.types.ts +++ b/surfsense_web/contracts/types/new-llm-config.types.ts @@ -65,6 +65,13 @@ export const newLLMConfig = z.object({ created_at: z.string(), search_space_id: z.number(), user_id: z.string(), + + // Capability flag — derived server-side at the route boundary from + // LiteLLM's authoritative model map. There is no DB column. Default + // `true` is the conservative-allow stance for unknown / unmapped + // BYOK rows; the streaming-task safety net is the only place a + // `false` actually blocks a request. + supports_image_input: z.boolean().default(true), }); /** @@ -74,11 +81,16 @@ export const newLLMConfigPublic = newLLMConfig.omit({ api_key: true }); /** * Create NewLLMConfig + * + * `supports_image_input` is omitted because it is derived server-side + * from LiteLLM's model map at read time — there is no DB column to + * persist a client-supplied value into. */ export const createNewLLMConfigRequest = newLLMConfig.omit({ id: true, created_at: true, user_id: true, + supports_image_input: true, }); export const createNewLLMConfigResponse = newLLMConfig; @@ -114,6 +126,8 @@ export const updateNewLLMConfigRequest = z.object({ created_at: true, search_space_id: true, user_id: true, + // Derived server-side; not part of the writable surface. + supports_image_input: true, }) .partial(), }); @@ -172,6 +186,16 @@ export const globalNewLLMConfig = z.object({ seo_title: z.string().nullable().optional(), seo_description: z.string().nullable().optional(), quota_reserve_tokens: z.number().nullable().optional(), + // Capability flag — true when the model can accept image inputs. + // Resolved server-side (OpenRouter dynamic configs use the OR + // `architecture.input_modalities` field; YAML / BYOK use LiteLLM's + // authoritative `supports_vision` map). The chat selector renders + // an amber "No image" hint when this is false and there are + // pending image attachments, but does not block selection — the + // backend safety net only rejects when LiteLLM *explicitly* marks + // the model as text-only, so unknown / new models still flow + // through. Default `true` matches that conservative-allow stance. + supports_image_input: z.boolean().default(true), }); export const getGlobalNewLLMConfigsResponse = z.array(globalNewLLMConfig); @@ -259,6 +283,9 @@ export const globalImageGenConfig = z.object({ is_global: z.literal(true), is_auto_mode: z.boolean().optional().default(false), billing_tier: z.string().default("free"), + // Mirrors `globalNewLLMConfig.is_premium` so the new-chat selector's + // Free/Premium badge logic lights up automatically for image-gen too. + is_premium: z.boolean().default(false), quota_reserve_micros: z.number().nullable().optional(), }); @@ -341,6 +368,9 @@ export const globalVisionLLMConfig = z.object({ is_global: z.literal(true), is_auto_mode: z.boolean().optional().default(false), billing_tier: z.string().default("free"), + // Mirrors `globalNewLLMConfig.is_premium` so the new-chat selector's + // Free/Premium badge logic lights up automatically for vision too. + is_premium: z.boolean().default(false), quota_reserve_tokens: z.number().nullable().optional(), input_cost_per_token: z.number().nullable().optional(), output_cost_per_token: z.number().nullable().optional(), diff --git a/surfsense_web/next.config.ts b/surfsense_web/next.config.ts index 5414d548d..6cfcb5187 100644 --- a/surfsense_web/next.config.ts +++ b/surfsense_web/next.config.ts @@ -18,6 +18,12 @@ const nextConfig: NextConfig = { }, images: { remotePatterns: [ + { + protocol: "http", + hostname: "localhost", + port: "8000", + pathname: "/api/v1/image-generations/**", + }, { protocol: "https", hostname: "**", From cea8618aed74840fe7d48a669dfd4fc07e0039cc Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Sat, 2 May 2026 21:16:03 -0700 Subject: [PATCH 295/299] fix: fixed composio issues --- .../new_chat/tools/gmail/composio_helpers.py | 41 ++ .../new_chat/tools/gmail/create_draft.py | 61 ++- .../agents/new_chat/tools/gmail/read_email.py | 48 +++ .../new_chat/tools/gmail/search_emails.py | 72 +++- .../agents/new_chat/tools/gmail/send_email.py | 61 ++- .../new_chat/tools/gmail/trash_email.py | 50 ++- .../new_chat/tools/gmail/update_draft.py | 116 ++++-- .../tools/google_calendar/create_event.py | 68 +++- .../tools/google_calendar/delete_event.py | 50 ++- .../tools/google_calendar/search_events.py | 99 +++-- .../tools/google_calendar/update_event.py | 84 +++-- .../tools/google_drive/create_file.py | 60 ++- .../new_chat/tools/google_drive/trash_file.py | 32 +- .../app/agents/new_chat/tools/hitl.py | 1 + .../app/routes/composio_routes.py | 41 +- .../app/services/composio_service.py | 105 +++++- .../services/gmail/tool_metadata_service.py | 122 +++++- .../google_calendar/kb_sync_service.py | 65 +++- .../google_calendar/tool_metadata_service.py | 282 ++++++++++++-- .../google_drive/tool_metadata_service.py | 96 ++++- .../app/tasks/chat/stream_new_chat.py | 46 ++- .../google_calendar_indexer.py | 59 ++- .../google_drive_indexer.py | 356 +++++++++++++----- .../google_gmail_indexer.py | 146 ++++++- .../tasks/chat/test_tool_input_streaming.py | 56 ++- 25 files changed, 1756 insertions(+), 461 deletions(-) create mode 100644 surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py b/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py new file mode 100644 index 000000000..0ca1191a4 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/composio_helpers.py @@ -0,0 +1,41 @@ +from typing import Any + +from app.db import SearchSourceConnector +from app.services.composio_service import ComposioService + + +def split_recipients(value: str | None) -> list[str]: + if not value: + return [] + return [recipient.strip() for recipient in value.split(",") if recipient.strip()] + + +def unwrap_composio_data(data: Any) -> Any: + if isinstance(data, dict): + inner = data.get("data", data) + if isinstance(inner, dict): + return inner.get("response_data", inner) + return inner + return data + + +async def execute_composio_gmail_tool( + connector: SearchSourceConnector, + user_id: str, + tool_name: str, + params: dict[str, Any], +) -> tuple[Any, str | None]: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return None, "Composio connected account ID not found for this Gmail connector." + + result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name=tool_name, + params=params, + entity_id=f"surfsense_{user_id}", + ) + if not result.get("success"): + return None, result.get("error", "Unknown Composio Gmail error") + + return unwrap_composio_data(result.get("data")), None diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py index 0bd044695..7e9ddf7d3 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py @@ -157,16 +157,13 @@ def create_create_gmail_draft_tool( f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" ) - if ( + is_composio_gmail = ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - + ) + if is_composio_gmail: cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not cca_id: return { "status": "error", "message": "Composio connected account ID not found for this Gmail connector.", @@ -208,10 +205,6 @@ def create_create_gmail_draft_tool( expiry=datetime.fromisoformat(exp) if exp else None, ) - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - message = MIMEText(final_body) message["to"] = final_to message["subject"] = final_subject @@ -222,15 +215,43 @@ def create_create_gmail_draft_tool( raw = base64.urlsafe_b64encode(message.as_bytes()).decode() try: - created = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .drafts() - .create(userId="me", body={"message": {"raw": raw}}) - .execute() - ), - ) + if is_composio_gmail: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + split_recipients, + ) + + created, error = await execute_composio_gmail_tool( + connector, + user_id, + "GMAIL_CREATE_EMAIL_DRAFT", + { + "user_id": "me", + "recipient_email": final_to, + "subject": final_subject, + "body": final_body, + "cc": split_recipients(final_cc), + "bcc": split_recipients(final_bcc), + "is_html": False, + }, + ) + if error: + raise RuntimeError(error) + if not isinstance(created, dict): + created = {} + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + created = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .drafts() + .create(userId="me", body={"message": {"raw": raw}}) + .execute() + ), + ) except Exception as api_err: from googleapiclient.errors import HttpError diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py index deec1627c..1964181e4 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py @@ -50,6 +50,54 @@ def create_read_gmail_email_tool( "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", } + if ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ): + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found.", + } + + from app.agents.new_chat.tools.gmail.search_emails import ( + _format_gmail_summary, + ) + from app.services.composio_service import ComposioService + + service = ComposioService() + detail, error = await service.get_gmail_message_detail( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + message_id=message_id, + ) + if error: + return {"status": "error", "message": error} + if not detail: + return { + "status": "not_found", + "message": f"Email with ID '{message_id}' not found.", + } + + summary = _format_gmail_summary(detail) + content = ( + f"# {summary['subject']}\n\n" + f"**From:** {summary['from']}\n" + f"**To:** {summary['to']}\n" + f"**Date:** {summary['date']}\n\n" + f"## Message Content\n\n" + f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n" + f"## Message Details\n\n" + f"- **Message ID:** {summary['message_id']}\n" + f"- **Thread ID:** {summary['thread_id']}\n" + ) + return { + "status": "success", + "message_id": summary["message_id"] or message_id, + "content": content, + } + from app.agents.new_chat.tools.gmail.search_emails import _build_credentials creds = _build_credentials(connector) diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py index 2e363609e..59886159a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py @@ -39,12 +39,7 @@ def _build_credentials(connector: SearchSourceConnector): from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: - from app.utils.google_credentials import build_composio_credentials - - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - raise ValueError("Composio connected account ID not found.") - return build_composio_credentials(cca_id) + raise ValueError("Composio connectors must use Composio tool execution.") from google.oauth2.credentials import Credentials @@ -67,6 +62,63 @@ def _build_credentials(connector: SearchSourceConnector): ) +def _gmail_headers(message: dict[str, Any]) -> dict[str, str]: + headers = message.get("payload", {}).get("headers", []) + return { + header.get("name", "").lower(): header.get("value", "") + for header in headers + if isinstance(header, dict) + } + + +def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]: + headers = _gmail_headers(message) + return { + "message_id": message.get("id") or message.get("messageId"), + "thread_id": message.get("threadId"), + "subject": message.get("subject") or headers.get("subject", "No Subject"), + "from": message.get("sender") or headers.get("from", "Unknown"), + "to": message.get("to") or headers.get("to", ""), + "date": message.get("messageTimestamp") or headers.get("date", ""), + "snippet": message.get("snippet") or message.get("messageText", "")[:300], + "labels": message.get("labelIds", []), + } + + +async def _search_composio_gmail( + connector: SearchSourceConnector, + user_id: str, + query: str, + max_results: int, +) -> dict[str, Any]: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found.", + } + + from app.services.composio_service import ComposioService + + service = ComposioService() + messages, _next_token, _estimate, error = await service.get_gmail_messages( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + query=query, + max_results=max_results, + ) + if error: + return {"status": "error", "message": error} + + emails = [_format_gmail_summary(message) for message in messages] + return { + "status": "success", + "emails": emails, + "total": len(emails), + "message": "No emails found." if not emails else None, + } + + def create_search_gmail_tool( db_session: AsyncSession | None = None, search_space_id: int | None = None, @@ -110,6 +162,14 @@ def create_search_gmail_tool( "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", } + if ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ): + return await _search_composio_gmail( + connector, str(user_id), query, max_results + ) + creds = _build_credentials(connector) from app.connectors.google_gmail_connector import GoogleGmailConnector diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py index c3f0999f4..79ff2d9c7 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py @@ -158,16 +158,13 @@ def create_send_gmail_email_tool( f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" ) - if ( + is_composio_gmail = ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - + ) + if is_composio_gmail: cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not cca_id: return { "status": "error", "message": "Composio connected account ID not found for this Gmail connector.", @@ -209,10 +206,6 @@ def create_send_gmail_email_tool( expiry=datetime.fromisoformat(exp) if exp else None, ) - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - message = MIMEText(final_body) message["to"] = final_to message["subject"] = final_subject @@ -223,15 +216,43 @@ def create_send_gmail_email_tool( raw = base64.urlsafe_b64encode(message.as_bytes()).decode() try: - sent = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .messages() - .send(userId="me", body={"raw": raw}) - .execute() - ), - ) + if is_composio_gmail: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + split_recipients, + ) + + sent, error = await execute_composio_gmail_tool( + connector, + user_id, + "GMAIL_SEND_EMAIL", + { + "user_id": "me", + "recipient_email": final_to, + "subject": final_subject, + "body": final_body, + "cc": split_recipients(final_cc), + "bcc": split_recipients(final_bcc), + "is_html": False, + }, + ) + if error: + raise RuntimeError(error) + if not isinstance(sent, dict): + sent = {} + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + sent = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .messages() + .send(userId="me", body={"raw": raw}) + .execute() + ), + ) except Exception as api_err: from googleapiclient.errors import HttpError diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py index 1f1f6227a..4e710dc72 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py @@ -158,16 +158,13 @@ def create_trash_gmail_email_tool( f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}" ) - if ( + is_composio_gmail = ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - + ) + if is_composio_gmail: cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not cca_id: return { "status": "error", "message": "Composio connected account ID not found for this Gmail connector.", @@ -209,20 +206,33 @@ def create_trash_gmail_email_tool( expiry=datetime.fromisoformat(exp) if exp else None, ) - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - try: - await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .messages() - .trash(userId="me", id=final_message_id) - .execute() - ), - ) + if is_composio_gmail: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + ) + + _trashed, error = await execute_composio_gmail_tool( + connector, + user_id, + "GMAIL_MOVE_TO_TRASH", + {"user_id": "me", "message_id": final_message_id}, + ) + if error: + raise RuntimeError(error) + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .messages() + .trash(userId="me", id=final_message_id) + .execute() + ), + ) except Exception as api_err: from googleapiclient.errors import HttpError diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py index 91178cd21..50956f03a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py @@ -188,16 +188,13 @@ def create_update_gmail_draft_tool( f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}" ) - if ( + is_composio_gmail = ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - + ) + if is_composio_gmail: cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not cca_id: return { "status": "error", "message": "Composio connected account ID not found for this Gmail connector.", @@ -239,18 +236,22 @@ def create_update_gmail_draft_tool( expiry=datetime.fromisoformat(exp) if exp else None, ) - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - # Resolve draft_id if not already available if not final_draft_id: logger.info( f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}" ) - final_draft_id = await _find_draft_id_by_message( - gmail_service, message_id - ) + if is_composio_gmail: + final_draft_id = await _find_composio_draft_id_by_message( + connector, user_id, message_id + ) + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + final_draft_id = await _find_draft_id_by_message( + gmail_service, message_id + ) if not final_draft_id: return { @@ -272,19 +273,48 @@ def create_update_gmail_draft_tool( raw = base64.urlsafe_b64encode(message.as_bytes()).decode() try: - updated = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .drafts() - .update( - userId="me", - id=final_draft_id, - body={"message": {"raw": raw}}, - ) - .execute() - ), - ) + if is_composio_gmail: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + split_recipients, + ) + + updated, error = await execute_composio_gmail_tool( + connector, + user_id, + "GMAIL_UPDATE_DRAFT", + { + "user_id": "me", + "draft_id": final_draft_id, + "recipient_email": final_to, + "subject": final_subject, + "body": final_body, + "cc": split_recipients(final_cc), + "bcc": split_recipients(final_bcc), + "is_html": False, + }, + ) + if error: + raise RuntimeError(error) + if not isinstance(updated, dict): + updated = {} + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + updated = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .drafts() + .update( + userId="me", + id=final_draft_id, + body={"message": {"raw": raw}}, + ) + .execute() + ), + ) except Exception as api_err: from googleapiclient.errors import HttpError @@ -408,3 +438,35 @@ async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str except Exception as e: logger.warning(f"Failed to look up draft by message_id: {e}") return None + + +async def _find_composio_draft_id_by_message( + connector: Any, user_id: str, message_id: str +) -> str | None: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + ) + + page_token = "" + while True: + params: dict[str, Any] = { + "user_id": "me", + "max_results": 100, + "verbose": False, + } + if page_token: + params["page_token"] = page_token + + data, error = await execute_composio_gmail_tool( + connector, user_id, "GMAIL_LIST_DRAFTS", params + ) + if error or not isinstance(data, dict): + return None + + for draft in data.get("drafts", []): + if draft.get("message", {}).get("id") == message_id: + return draft.get("id") + + page_token = data.get("nextPageToken") or data.get("next_page_token") or "" + if not page_token: + return None diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py index 37bcf083e..0a4720f6f 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py @@ -168,16 +168,13 @@ def create_create_calendar_event_tool( f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}" ) - if ( + is_composio_calendar = ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - + ) + if is_composio_calendar: cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not cca_id: return { "status": "error", "message": "Composio connected account ID not found for this connector.", @@ -211,10 +208,6 @@ def create_create_calendar_event_tool( expiry=datetime.fromisoformat(exp) if exp else None, ) - service = await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - tz = context.get("timezone", "UTC") event_body: dict[str, Any] = { "summary": final_summary, @@ -231,14 +224,51 @@ def create_create_calendar_event_tool( ] try: - created = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .insert(calendarId="primary", body=event_body) - .execute() - ), - ) + if is_composio_calendar: + from app.services.composio_service import ComposioService + + composio_params = { + "calendar_id": "primary", + "summary": final_summary, + "start_datetime": final_start_datetime, + "end_datetime": final_end_datetime, + "timezone": tz, + "attendees": final_attendees or [], + } + if final_description: + composio_params["description"] = final_description + if final_location: + composio_params["location"] = final_location + + composio_result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLECALENDAR_CREATE_EVENT", + params=composio_params, + entity_id=f"surfsense_{user_id}", + ) + if not composio_result.get("success"): + raise RuntimeError( + composio_result.get( + "error", "Unknown Composio Calendar error" + ) + ) + created = composio_result.get("data", {}) + if isinstance(created, dict): + created = created.get("data", created) + if isinstance(created, dict): + created = created.get("response_data", created) + else: + service = await asyncio.get_event_loop().run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + created = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + service.events() + .insert(calendarId="primary", body=event_body) + .execute() + ), + ) except Exception as api_err: from googleapiclient.errors import HttpError diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py index 4d9d69b4b..53596ac0f 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py @@ -159,16 +159,13 @@ def create_delete_calendar_event_tool( f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}" ) - if ( + is_composio_calendar = ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - + ) + if is_composio_calendar: cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not cca_id: return { "status": "error", "message": "Composio connected account ID not found for this connector.", @@ -202,19 +199,34 @@ def create_delete_calendar_event_tool( expiry=datetime.fromisoformat(exp) if exp else None, ) - service = await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - try: - await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .delete(calendarId="primary", eventId=final_event_id) - .execute() - ), - ) + if is_composio_calendar: + from app.services.composio_service import ComposioService + + composio_result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLECALENDAR_DELETE_EVENT", + params={"calendar_id": "primary", "event_id": final_event_id}, + entity_id=f"surfsense_{user_id}", + ) + if not composio_result.get("success"): + raise RuntimeError( + composio_result.get( + "error", "Unknown Composio Calendar error" + ) + ) + else: + service = await asyncio.get_event_loop().run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + service.events() + .delete(calendarId="primary", eventId=final_event_id) + .execute() + ), + ) except Exception as api_err: from googleapiclient.errors import HttpError diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py index dc6adb822..b5194d15f 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py @@ -16,6 +16,35 @@ _CALENDAR_TYPES = [ ] +def _to_calendar_boundary(value: str, *, is_end: bool) -> str: + if "T" in value: + return value + time = "23:59:59" if is_end else "00:00:00" + return f"{value}T{time}Z" + + +def _format_calendar_events(events_raw: list[dict[str, Any]]) -> list[dict[str, Any]]: + events = [] + for ev in events_raw: + start = ev.get("start", {}) + end = ev.get("end", {}) + attendees_raw = ev.get("attendees", []) + events.append( + { + "event_id": ev.get("id"), + "summary": ev.get("summary", "No Title"), + "start": start.get("dateTime") or start.get("date", ""), + "end": end.get("dateTime") or end.get("date", ""), + "location": ev.get("location", ""), + "description": ev.get("description", ""), + "html_link": ev.get("htmlLink", ""), + "attendees": [a.get("email", "") for a in attendees_raw[:10]], + "status": ev.get("status", ""), + } + ) + return events + + def create_search_calendar_events_tool( db_session: AsyncSession | None = None, search_space_id: int | None = None, @@ -61,22 +90,47 @@ def create_search_calendar_events_tool( "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.", } - creds = _build_credentials(connector) + if ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR + ): + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this connector.", + } - from app.connectors.google_calendar_connector import GoogleCalendarConnector + from app.services.composio_service import ComposioService - cal = GoogleCalendarConnector( - credentials=creds, - session=db_session, - user_id=user_id, - connector_id=connector.id, - ) + events_raw, error = await ComposioService().get_calendar_events( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + time_min=_to_calendar_boundary(start_date, is_end=False), + time_max=_to_calendar_boundary(end_date, is_end=True), + max_results=max_results, + ) + if not events_raw and not error: + error = "No events found in the specified date range." + else: + creds = _build_credentials(connector) - events_raw, error = await cal.get_all_primary_calendar_events( - start_date=start_date, - end_date=end_date, - max_results=max_results, - ) + from app.connectors.google_calendar_connector import ( + GoogleCalendarConnector, + ) + + cal = GoogleCalendarConnector( + credentials=creds, + session=db_session, + user_id=user_id, + connector_id=connector.id, + ) + + events_raw, error = await cal.get_all_primary_calendar_events( + start_date=start_date, + end_date=end_date, + max_results=max_results, + ) if error: if ( @@ -97,24 +151,7 @@ def create_search_calendar_events_tool( } return {"status": "error", "message": error} - events = [] - for ev in events_raw: - start = ev.get("start", {}) - end = ev.get("end", {}) - attendees_raw = ev.get("attendees", []) - events.append( - { - "event_id": ev.get("id"), - "summary": ev.get("summary", "No Title"), - "start": start.get("dateTime") or start.get("date", ""), - "end": end.get("dateTime") or end.get("date", ""), - "location": ev.get("location", ""), - "description": ev.get("description", ""), - "html_link": ev.get("htmlLink", ""), - "attendees": [a.get("email", "") for a in attendees_raw[:10]], - "status": ev.get("status", ""), - } - ) + events = _format_calendar_events(events_raw) return {"status": "success", "events": events, "total": len(events)} diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py index 259f52bba..1dba36c20 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py @@ -192,16 +192,13 @@ def create_update_calendar_event_tool( f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}" ) - if ( + is_composio_calendar = ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - + ) + if is_composio_calendar: cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not cca_id: return { "status": "error", "message": "Composio connected account ID not found for this connector.", @@ -235,10 +232,6 @@ def create_update_calendar_event_tool( expiry=datetime.fromisoformat(exp) if exp else None, ) - service = await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - update_body: dict[str, Any] = {} if final_new_summary is not None: update_body["summary"] = final_new_summary @@ -264,18 +257,65 @@ def create_update_calendar_event_tool( } try: - updated = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .patch( - calendarId="primary", - eventId=final_event_id, - body=update_body, + if is_composio_calendar: + from app.services.composio_service import ComposioService + + composio_params: dict[str, Any] = { + "calendar_id": "primary", + "event_id": final_event_id, + } + if final_new_summary is not None: + composio_params["summary"] = final_new_summary + if final_new_start_datetime is not None: + composio_params["start_time"] = final_new_start_datetime + if final_new_end_datetime is not None: + composio_params["end_time"] = final_new_end_datetime + if final_new_description is not None: + composio_params["description"] = final_new_description + if final_new_location is not None: + composio_params["location"] = final_new_location + if final_new_attendees is not None: + composio_params["attendees"] = [ + e.strip() for e in final_new_attendees if e.strip() + ] + if not _is_date_only( + final_new_start_datetime or final_new_end_datetime or "" + ): + composio_params["timezone"] = context.get("timezone", "UTC") + + composio_result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLECALENDAR_PATCH_EVENT", + params=composio_params, + entity_id=f"surfsense_{user_id}", + ) + if not composio_result.get("success"): + raise RuntimeError( + composio_result.get( + "error", "Unknown Composio Calendar error" + ) ) - .execute() - ), - ) + updated = composio_result.get("data", {}) + if isinstance(updated, dict): + updated = updated.get("data", updated) + if isinstance(updated, dict): + updated = updated.get("response_data", updated) + else: + service = await asyncio.get_event_loop().run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + updated = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + service.events() + .patch( + calendarId="primary", + eventId=final_event_id, + body=update_body, + ) + .execute() + ), + ) except Exception as api_err: from googleapiclient.errors import HttpError diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py index f36db8f3f..2becec100 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py @@ -179,29 +179,59 @@ def create_create_google_drive_file_tool( f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}" ) - pre_built_creds = None - if ( + is_composio_drive = ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - + ) + if is_composio_drive: cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - pre_built_creds = build_composio_credentials(cca_id) - + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Drive connector.", + } client = GoogleDriveClient( session=db_session, connector_id=actual_connector_id, - credentials=pre_built_creds, ) try: - created = await client.create_file( - name=final_name, - mime_type=mime_type, - parent_folder_id=final_parent_folder_id, - content=final_content, - ) + if is_composio_drive: + from app.services.composio_service import ComposioService + + params: dict[str, Any] = { + "name": final_name, + "mimeType": mime_type, + "fields": "id,name,webViewLink,mimeType", + } + if final_parent_folder_id: + params["parents"] = [final_parent_folder_id] + if final_content: + params["description"] = final_content[:4096] + + result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLEDRIVE_CREATE_FILE", + params=params, + entity_id=f"surfsense_{user_id}", + ) + if not result.get("success"): + raise RuntimeError( + result.get("error", "Unknown Composio Drive error") + ) + created = result.get("data", {}) + if isinstance(created, dict): + created = created.get("data", created) + if isinstance(created, dict): + created = created.get("response_data", created) + if not isinstance(created, dict): + created = {} + else: + created = await client.create_file( + name=final_name, + mime_type=mime_type, + parent_folder_id=final_parent_folder_id, + content=final_content, + ) except HttpError as http_err: if http_err.resp.status == 403: logger.warning( diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py index 832afff0d..3c404527e 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py @@ -158,24 +158,38 @@ def create_delete_google_drive_file_tool( f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}" ) - pre_built_creds = None - if ( + is_composio_drive = ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ): - from app.utils.google_credentials import build_composio_credentials - + ) + if is_composio_drive: cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - pre_built_creds = build_composio_credentials(cca_id) + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Drive connector.", + } client = GoogleDriveClient( session=db_session, connector_id=connector.id, - credentials=pre_built_creds, ) try: - await client.trash_file(file_id=final_file_id) + if is_composio_drive: + from app.services.composio_service import ComposioService + + result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLEDRIVE_TRASH_FILE", + params={"file_id": final_file_id}, + entity_id=f"surfsense_{user_id}", + ) + if not result.get("success"): + raise RuntimeError( + result.get("error", "Unknown Composio Drive error") + ) + else: + await client.trash_file(file_id=final_file_id) except HttpError as http_err: if http_err.resp.status == 403: logger.warning( diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/new_chat/tools/hitl.py index 92248c2c9..5b64929de 100644 --- a/surfsense_backend/app/agents/new_chat/tools/hitl.py +++ b/surfsense_backend/app/agents/new_chat/tools/hitl.py @@ -50,6 +50,7 @@ DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset( { "create_gmail_draft", "update_gmail_draft", + "create_calendar_event", "create_notion_page", "create_confluence_page", "create_google_drive_file", diff --git a/surfsense_backend/app/routes/composio_routes.py b/surfsense_backend/app/routes/composio_routes.py index 4bf360365..7bc2addf8 100644 --- a/surfsense_backend/app/routes/composio_routes.py +++ b/surfsense_backend/app/routes/composio_routes.py @@ -649,13 +649,9 @@ async def list_composio_drive_folders( """ List folders AND files in user's Google Drive via Composio. - Uses the same GoogleDriveClient / list_folder_contents path as the native - connector, with Composio-sourced credentials. This means auth errors - propagate identically (Google returns 401 → exception → auth_expired flag). + Uses Composio's Google Drive tool execution path so managed OAuth tokens + do not need to be exposed through connected account state. """ - from app.connectors.google_drive import GoogleDriveClient, list_folder_contents - from app.utils.google_credentials import build_composio_credentials - if not ComposioService.is_enabled(): raise HTTPException( status_code=503, @@ -689,10 +685,37 @@ async def list_composio_drive_folders( detail="Composio connected account not found. Please reconnect the connector.", ) - credentials = build_composio_credentials(composio_connected_account_id) - drive_client = GoogleDriveClient(session, connector_id, credentials=credentials) + service = ComposioService() + entity_id = f"surfsense_{user.id}" + items = [] + page_token = None + error = None - items, error = await list_folder_contents(drive_client, parent_id=parent_id) + while True: + page_items, next_token, page_error = await service.get_drive_files( + connected_account_id=composio_connected_account_id, + entity_id=entity_id, + folder_id=parent_id, + page_token=page_token, + page_size=100, + ) + if page_error: + error = page_error + break + + items.extend(page_items) + if not next_token: + break + page_token = next_token + + for item in items: + item["isFolder"] = ( + item.get("mimeType") == "application/vnd.google-apps.folder" + ) + + items.sort( + key=lambda item: (not item["isFolder"], item.get("name", "").lower()) + ) if error: error_lower = error.lower() diff --git a/surfsense_backend/app/services/composio_service.py b/surfsense_backend/app/services/composio_service.py index a8abe4aa8..edfab1d15 100644 --- a/surfsense_backend/app/services/composio_service.py +++ b/surfsense_backend/app/services/composio_service.py @@ -408,12 +408,37 @@ class ComposioService: files = [] next_token = None if isinstance(data, dict): + inner_data = data.get("data", data) + response_data = ( + inner_data.get("response_data", {}) + if isinstance(inner_data, dict) + else {} + ) # Try direct access first, then nested - files = data.get("files", []) or data.get("data", {}).get("files", []) + files = ( + data.get("files", []) + or ( + inner_data.get("files", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("files", []) + ) next_token = ( data.get("nextPageToken") or data.get("next_page_token") - or data.get("data", {}).get("nextPageToken") + or ( + inner_data.get("nextPageToken") + if isinstance(inner_data, dict) + else None + ) + or ( + inner_data.get("next_page_token") + if isinstance(inner_data, dict) + else None + ) + or response_data.get("nextPageToken") + or response_data.get("next_page_token") ) elif isinstance(data, list): files = data @@ -819,24 +844,61 @@ class ComposioService: next_token = None result_size_estimate = None if isinstance(data, dict): + inner_data = data.get("data", data) + response_data = ( + inner_data.get("response_data", {}) + if isinstance(inner_data, dict) + else {} + ) messages = ( data.get("messages", []) - or data.get("data", {}).get("messages", []) + or ( + inner_data.get("messages", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("messages", []) or data.get("emails", []) + or ( + inner_data.get("emails", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("emails", []) ) # Check for pagination token in various possible locations next_token = ( data.get("nextPageToken") or data.get("next_page_token") - or data.get("data", {}).get("nextPageToken") - or data.get("data", {}).get("next_page_token") + or ( + inner_data.get("nextPageToken") + if isinstance(inner_data, dict) + else None + ) + or ( + inner_data.get("next_page_token") + if isinstance(inner_data, dict) + else None + ) + or response_data.get("nextPageToken") + or response_data.get("next_page_token") ) # Extract resultSizeEstimate if available (Gmail API provides this) result_size_estimate = ( data.get("resultSizeEstimate") or data.get("result_size_estimate") - or data.get("data", {}).get("resultSizeEstimate") - or data.get("data", {}).get("result_size_estimate") + or ( + inner_data.get("resultSizeEstimate") + if isinstance(inner_data, dict) + else None + ) + or ( + inner_data.get("result_size_estimate") + if isinstance(inner_data, dict) + else None + ) + or response_data.get("resultSizeEstimate") + or response_data.get("result_size_estimate") ) elif isinstance(data, list): messages = data @@ -864,7 +926,7 @@ class ComposioService: try: result = await self.execute_tool( connected_account_id=connected_account_id, - tool_name="GMAIL_GET_MESSAGE_BY_MESSAGE_ID", + tool_name="GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID", params={"message_id": message_id}, # snake_case entity_id=entity_id, ) @@ -872,7 +934,13 @@ class ComposioService: if not result.get("success"): return None, result.get("error", "Unknown error") - return result.get("data"), None + data = result.get("data") + if isinstance(data, dict): + inner_data = data.get("data", data) + if isinstance(inner_data, dict): + return inner_data.get("response_data", inner_data), None + + return data, None except Exception as e: logger.error(f"Failed to get Gmail message detail: {e!s}") @@ -928,10 +996,27 @@ class ComposioService: # Try different possible response structures events = [] if isinstance(data, dict): + inner_data = data.get("data", data) + response_data = ( + inner_data.get("response_data", {}) + if isinstance(inner_data, dict) + else {} + ) events = ( data.get("items", []) - or data.get("data", {}).get("items", []) + or ( + inner_data.get("items", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("items", []) or data.get("events", []) + or ( + inner_data.get("events", []) + if isinstance(inner_data, dict) + else [] + ) + or response_data.get("events", []) ) elif isinstance(data, list): events = data diff --git a/surfsense_backend/app/services/gmail/tool_metadata_service.py b/surfsense_backend/app/services/gmail/tool_metadata_service.py index c903e24af..4855c1cc9 100644 --- a/surfsense_backend/app/services/gmail/tool_metadata_service.py +++ b/surfsense_backend/app/services/gmail/tool_metadata_service.py @@ -17,7 +17,7 @@ from app.db import ( SearchSourceConnector, SearchSourceConnectorType, ) -from app.utils.google_credentials import build_composio_credentials +from app.services.composio_service import ComposioService logger = logging.getLogger(__name__) @@ -78,14 +78,49 @@ class GmailToolMetadataService: def __init__(self, db_session: AsyncSession): self._db_session = db_session - async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials: - if ( + def _is_composio_connector(self, connector: SearchSourceConnector) -> bool: + return ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - return build_composio_credentials(cca_id) + ) + + def _get_composio_connected_account_id( + self, connector: SearchSourceConnector + ) -> str: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + raise ValueError("Composio connected_account_id not found") + return cca_id + + def _unwrap_composio_data(self, data: Any) -> Any: + if isinstance(data, dict): + inner = data.get("data", data) + if isinstance(inner, dict): + return inner.get("response_data", inner) + return inner + return data + + async def _execute_composio_gmail_tool( + self, + connector: SearchSourceConnector, + tool_name: str, + params: dict[str, Any], + ) -> tuple[Any, str | None]: + result = await ComposioService().execute_tool( + connected_account_id=self._get_composio_connected_account_id(connector), + tool_name=tool_name, + params=params, + entity_id=f"surfsense_{connector.user_id}", + ) + if not result.get("success"): + return None, result.get("error", "Unknown Composio Gmail error") + return self._unwrap_composio_data(result.get("data")), None + + async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials: + if self._is_composio_connector(connector): + raise ValueError( + "Composio Gmail connectors must use Composio tool execution" + ) config_data = dict(connector.config) @@ -139,6 +174,12 @@ class GmailToolMetadataService: if not connector: return True + if self._is_composio_connector(connector): + _profile, error = await self._execute_composio_gmail_tool( + connector, "GMAIL_GET_PROFILE", {"user_id": "me"} + ) + return bool(error) + creds = await self._build_credentials(connector) service = build("gmail", "v1", credentials=creds) await asyncio.get_event_loop().run_in_executor( @@ -221,14 +262,21 @@ class GmailToolMetadataService: ) connector = result.scalar_one_or_none() if connector: - creds = await self._build_credentials(connector) - service = build("gmail", "v1", credentials=creds) - profile = await asyncio.get_event_loop().run_in_executor( - None, - lambda service=service: ( - service.users().getProfile(userId="me").execute() - ), - ) + if self._is_composio_connector(connector): + profile, error = await self._execute_composio_gmail_tool( + connector, "GMAIL_GET_PROFILE", {"user_id": "me"} + ) + if error: + raise RuntimeError(error) + else: + creds = await self._build_credentials(connector) + service = build("gmail", "v1", credentials=creds) + profile = await asyncio.get_event_loop().run_in_executor( + None, + lambda service=service: ( + service.users().getProfile(userId="me").execute() + ), + ) acc_dict["email"] = profile.get("emailAddress", "") except Exception: logger.warning( @@ -298,6 +346,23 @@ class GmailToolMetadataService: Returns ``None`` on any failure so callers can degrade gracefully. """ try: + if self._is_composio_connector(connector): + if not draft_id: + draft_id = await self._find_composio_draft_id(connector, message_id) + if not draft_id: + return None + + draft, error = await self._execute_composio_gmail_tool( + connector, + "GMAIL_GET_DRAFT", + {"user_id": "me", "draft_id": draft_id, "format": "full"}, + ) + if error or not isinstance(draft, dict): + return None + + payload = draft.get("message", {}).get("payload", {}) + return self._extract_body_from_payload(payload) + creds = await self._build_credentials(connector) service = build("gmail", "v1", credentials=creds) @@ -326,6 +391,33 @@ class GmailToolMetadataService: ) return None + async def _find_composio_draft_id( + self, connector: SearchSourceConnector, message_id: str + ) -> str | None: + page_token = "" + while True: + params: dict[str, Any] = { + "user_id": "me", + "max_results": 100, + "verbose": False, + } + if page_token: + params["page_token"] = page_token + + data, error = await self._execute_composio_gmail_tool( + connector, "GMAIL_LIST_DRAFTS", params + ) + if error or not isinstance(data, dict): + return None + + for draft in data.get("drafts", []): + if draft.get("message", {}).get("id") == message_id: + return draft.get("id") + + page_token = data.get("nextPageToken") or data.get("next_page_token") or "" + if not page_token: + return None + async def _find_draft_id(self, service: Any, message_id: str) -> str | None: """Resolve a draft ID from its message ID by scanning drafts.list.""" try: diff --git a/surfsense_backend/app/services/google_calendar/kb_sync_service.py b/surfsense_backend/app/services/google_calendar/kb_sync_service.py index 20426f3bc..602a55738 100644 --- a/surfsense_backend/app/services/google_calendar/kb_sync_service.py +++ b/surfsense_backend/app/services/google_calendar/kb_sync_service.py @@ -14,6 +14,7 @@ from app.db import ( SearchSourceConnector, SearchSourceConnectorType, ) +from app.services.composio_service import ComposioService from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -21,7 +22,6 @@ from app.utils.document_converters import ( generate_document_summary, generate_unique_identifier_hash, ) -from app.utils.google_credentials import build_composio_credentials logger = logging.getLogger(__name__) @@ -203,23 +203,46 @@ class GoogleCalendarKBSyncService: logger.warning("Document %s not found in KB", document_id) return {"status": "not_indexed"} - creds = await self._build_credentials_for_connector(connector_id) - loop = asyncio.get_event_loop() - service = await loop.run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - calendar_id = (document.document_metadata or {}).get( "calendar_id" ) or "primary" - live_event = await loop.run_in_executor( - None, - lambda: ( - service.events() - .get(calendarId=calendar_id, eventId=event_id) - .execute() - ), - ) + connector = await self._get_connector(connector_id) + if ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR + ): + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + raise ValueError("Composio connected_account_id not found") + composio_result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLECALENDAR_EVENTS_GET", + params={"calendar_id": calendar_id, "event_id": event_id}, + entity_id=f"surfsense_{user_id}", + ) + if not composio_result.get("success"): + raise RuntimeError( + composio_result.get("error", "Unknown Composio Calendar error") + ) + live_event = composio_result.get("data", {}) + if isinstance(live_event, dict): + live_event = live_event.get("data", live_event) + if isinstance(live_event, dict): + live_event = live_event.get("response_data", live_event) + else: + creds = await self._build_credentials_for_connector(connector_id) + loop = asyncio.get_event_loop() + service = await loop.run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + live_event = await loop.run_in_executor( + None, + lambda: ( + service.events() + .get(calendarId=calendar_id, eventId=event_id) + .execute() + ), + ) event_summary = live_event.get("summary", "") description = live_event.get("description", "") @@ -322,7 +345,7 @@ class GoogleCalendarKBSyncService: await self.db_session.rollback() return {"status": "error", "message": str(e)} - async def _build_credentials_for_connector(self, connector_id: int) -> Credentials: + async def _get_connector(self, connector_id: int) -> SearchSourceConnector: result = await self.db_session.execute( select(SearchSourceConnector).where( SearchSourceConnector.id == connector_id @@ -331,15 +354,17 @@ class GoogleCalendarKBSyncService: connector = result.scalar_one_or_none() if not connector: raise ValueError(f"Connector {connector_id} not found") + return connector + async def _build_credentials_for_connector(self, connector_id: int) -> Credentials: + connector = await self._get_connector(connector_id) if ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - return build_composio_credentials(cca_id) - raise ValueError("Composio connected_account_id not found") + raise ValueError( + "Composio Calendar connectors must use Composio tool execution" + ) config_data = dict(connector.config) diff --git a/surfsense_backend/app/services/google_calendar/tool_metadata_service.py b/surfsense_backend/app/services/google_calendar/tool_metadata_service.py index c7bfe1d50..7e50ab039 100644 --- a/surfsense_backend/app/services/google_calendar/tool_metadata_service.py +++ b/surfsense_backend/app/services/google_calendar/tool_metadata_service.py @@ -16,7 +16,7 @@ from app.db import ( SearchSourceConnector, SearchSourceConnectorType, ) -from app.utils.google_credentials import build_composio_credentials +from app.services.composio_service import ComposioService logger = logging.getLogger(__name__) @@ -94,15 +94,49 @@ class GoogleCalendarToolMetadataService: def __init__(self, db_session: AsyncSession): self._db_session = db_session - async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials: - if ( + def _is_composio_connector(self, connector: SearchSourceConnector) -> bool: + return ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - return build_composio_credentials(cca_id) + ) + + def _get_composio_connected_account_id( + self, connector: SearchSourceConnector + ) -> str: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: raise ValueError("Composio connected_account_id not found") + return cca_id + + async def _execute_composio_calendar_tool( + self, + connector: SearchSourceConnector, + tool_name: str, + params: dict, + ) -> tuple[dict | list | None, str | None]: + service = ComposioService() + result = await service.execute_tool( + connected_account_id=self._get_composio_connected_account_id(connector), + tool_name=tool_name, + params=params, + entity_id=f"surfsense_{connector.user_id}", + ) + if not result.get("success"): + return None, result.get("error", "Unknown Composio Calendar error") + + data = result.get("data") + if isinstance(data, dict): + inner = data.get("data", data) + if isinstance(inner, dict): + return inner.get("response_data", inner), None + return inner, None + return data, None + + async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials: + if self._is_composio_connector(connector): + raise ValueError( + "Composio Calendar connectors must use Composio tool execution" + ) config_data = dict(connector.config) @@ -156,6 +190,14 @@ class GoogleCalendarToolMetadataService: if not connector: return True + if self._is_composio_connector(connector): + _data, error = await self._execute_composio_calendar_tool( + connector, + "GOOGLECALENDAR_GET_CALENDAR", + {"calendar_id": "primary"}, + ) + return bool(error) + creds = await self._build_credentials(connector) loop = asyncio.get_event_loop() await loop.run_in_executor( @@ -255,16 +297,48 @@ class GoogleCalendarToolMetadataService: timezone_str = "" if connector: try: - creds = await self._build_credentials(connector) - loop = asyncio.get_event_loop() - service = await loop.run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) + if self._is_composio_connector(connector): + cal_list, cal_error = await self._execute_composio_calendar_tool( + connector, "GOOGLECALENDAR_LIST_CALENDARS", {} + ) + if cal_error: + raise RuntimeError(cal_error) + ( + settings, + settings_error, + ) = await self._execute_composio_calendar_tool( + connector, + "GOOGLECALENDAR_SETTINGS_GET", + {"setting": "timezone"}, + ) + if not settings_error and isinstance(settings, dict): + timezone_str = settings.get("value", "") + else: + creds = await self._build_credentials(connector) + loop = asyncio.get_event_loop() + service = await loop.run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) - cal_list = await loop.run_in_executor( - None, lambda: service.calendarList().list().execute() - ) - for cal in cal_list.get("items", []): + cal_list = await loop.run_in_executor( + None, lambda: service.calendarList().list().execute() + ) + + tz_setting = await loop.run_in_executor( + None, + lambda: service.settings().get(setting="timezone").execute(), + ) + timezone_str = tz_setting.get("value", "") + + calendar_items = [] + if isinstance(cal_list, dict): + calendar_items = ( + cal_list.get("items") or cal_list.get("calendars") or [] + ) + elif isinstance(cal_list, list): + calendar_items = cal_list + + for cal in calendar_items: calendars.append( { "id": cal.get("id", ""), @@ -272,12 +346,6 @@ class GoogleCalendarToolMetadataService: "primary": cal.get("primary", False), } ) - - tz_setting = await loop.run_in_executor( - None, - lambda: service.settings().get(setting="timezone").execute(), - ) - timezone_str = tz_setting.get("value", "") except Exception: logger.warning( "Failed to fetch calendars/timezone for connector %s", @@ -321,20 +389,29 @@ class GoogleCalendarToolMetadataService: event_dict = event.to_dict() try: - creds = await self._build_credentials(connector) - loop = asyncio.get_event_loop() - service = await loop.run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) calendar_id = event.calendar_id or "primary" - live_event = await loop.run_in_executor( - None, - lambda: ( - service.events() - .get(calendarId=calendar_id, eventId=event.event_id) - .execute() - ), - ) + if self._is_composio_connector(connector): + live_event, error = await self._execute_composio_calendar_tool( + connector, + "GOOGLECALENDAR_EVENTS_GET", + {"calendar_id": calendar_id, "event_id": event.event_id}, + ) + if error: + raise RuntimeError(error) + else: + creds = await self._build_credentials(connector) + loop = asyncio.get_event_loop() + service = await loop.run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + live_event = await loop.run_in_executor( + None, + lambda: ( + service.events() + .get(calendarId=calendar_id, eventId=event.event_id) + .execute() + ), + ) event_dict["summary"] = live_event.get("summary", event_dict["summary"]) event_dict["description"] = live_event.get( @@ -376,12 +453,30 @@ class GoogleCalendarToolMetadataService: ) -> dict: resolved = await self._resolve_event(search_space_id, user_id, event_ref) if not resolved: + live_resolved = await self._resolve_live_event( + search_space_id, user_id, event_ref + ) + if not live_resolved: + return { + "error": ( + f"Event '{event_ref}' not found in your indexed or live Google Calendar events. " + "This could mean: (1) the event doesn't exist, " + "(2) the event name is different, or " + "(3) the connected calendar account cannot access it." + ) + } + + connector, live_event = live_resolved + account = GoogleCalendarAccount.from_connector(connector) + acc_dict = account.to_dict() + auth_expired = await self._check_account_health(connector.id) + acc_dict["auth_expired"] = auth_expired + if auth_expired: + await self._persist_auth_expired(connector.id) + return { - "error": ( - f"Event '{event_ref}' not found in your indexed Google Calendar events. " - "This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, " - "or (3) the event name is different." - ) + "account": acc_dict, + "event": self._event_dict_from_live_event(live_event), } document, connector = resolved @@ -429,3 +524,110 @@ class GoogleCalendarToolMetadataService: if row: return row[0], row[1] return None + + async def _resolve_live_event( + self, search_space_id: int, user_id: str, event_ref: str + ) -> tuple[SearchSourceConnector, dict] | None: + result = await self._db_session.execute( + select(SearchSourceConnector) + .filter( + and_( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(CALENDAR_CONNECTOR_TYPES), + ) + ) + .order_by(SearchSourceConnector.last_indexed_at.desc()) + ) + connectors = result.scalars().all() + + for connector in connectors: + try: + events = await self._search_live_events(connector, event_ref) + except Exception: + logger.warning( + "Failed to search live calendar events for connector %s", + connector.id, + exc_info=True, + ) + continue + + if not events: + continue + + normalized_ref = event_ref.strip().lower() + exact_match = next( + ( + event + for event in events + if event.get("summary", "").strip().lower() == normalized_ref + ), + None, + ) + return connector, exact_match or events[0] + + return None + + async def _search_live_events( + self, connector: SearchSourceConnector, event_ref: str + ) -> list[dict]: + if self._is_composio_connector(connector): + data, error = await self._execute_composio_calendar_tool( + connector, + "GOOGLECALENDAR_EVENTS_LIST", + { + "calendar_id": "primary", + "q": event_ref, + "max_results": 10, + "single_events": True, + "order_by": "startTime", + }, + ) + if error: + raise RuntimeError(error) + if isinstance(data, dict): + return data.get("items") or data.get("events") or [] + return data if isinstance(data, list) else [] + + creds = await self._build_credentials(connector) + loop = asyncio.get_event_loop() + service = await loop.run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + response = await loop.run_in_executor( + None, + lambda: ( + service.events() + .list( + calendarId="primary", + q=event_ref, + maxResults=10, + singleEvents=True, + orderBy="startTime", + ) + .execute() + ), + ) + return response.get("items", []) + + def _event_dict_from_live_event(self, event: dict) -> dict: + start_data = event.get("start", {}) + end_data = event.get("end", {}) + return { + "event_id": event.get("id", ""), + "summary": event.get("summary", "No Title"), + "start": start_data.get("dateTime", start_data.get("date", "")), + "end": end_data.get("dateTime", end_data.get("date", "")), + "description": event.get("description", ""), + "location": event.get("location", ""), + "attendees": [ + { + "email": attendee.get("email", ""), + "responseStatus": attendee.get("responseStatus", ""), + } + for attendee in event.get("attendees", []) + ], + "calendar_id": event.get("calendarId", "primary"), + "document_id": None, + "indexed_at": None, + } diff --git a/surfsense_backend/app/services/google_drive/tool_metadata_service.py b/surfsense_backend/app/services/google_drive/tool_metadata_service.py index 221bee14a..0f654bc78 100644 --- a/surfsense_backend/app/services/google_drive/tool_metadata_service.py +++ b/surfsense_backend/app/services/google_drive/tool_metadata_service.py @@ -13,7 +13,7 @@ from app.db import ( SearchSourceConnector, SearchSourceConnectorType, ) -from app.utils.google_credentials import build_composio_credentials +from app.services.composio_service import ComposioService logger = logging.getLogger(__name__) @@ -67,6 +67,42 @@ class GoogleDriveToolMetadataService: def __init__(self, db_session: AsyncSession): self._db_session = db_session + def _is_composio_connector(self, connector: SearchSourceConnector) -> bool: + return ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR + ) + + def _get_composio_connected_account_id( + self, connector: SearchSourceConnector + ) -> str: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + raise ValueError("Composio connected_account_id not found") + return cca_id + + async def _execute_composio_drive_tool( + self, + connector: SearchSourceConnector, + tool_name: str, + params: dict, + ) -> tuple[dict | list | None, str | None]: + result = await ComposioService().execute_tool( + connected_account_id=self._get_composio_connected_account_id(connector), + tool_name=tool_name, + params=params, + entity_id=f"surfsense_{connector.user_id}", + ) + if not result.get("success"): + return None, result.get("error", "Unknown Composio Drive error") + data = result.get("data") + if isinstance(data, dict): + inner = data.get("data", data) + if isinstance(inner, dict): + return inner.get("response_data", inner), None + return inner, None + return data, None + async def get_creation_context(self, search_space_id: int, user_id: str) -> dict: accounts = await self._get_google_drive_accounts(search_space_id, user_id) @@ -200,19 +236,21 @@ class GoogleDriveToolMetadataService: if not connector: return True - pre_built_creds = None - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - pre_built_creds = build_composio_credentials(cca_id) + if self._is_composio_connector(connector): + _data, error = await self._execute_composio_drive_tool( + connector, + "GOOGLEDRIVE_LIST_FILES", + { + "q": "trashed = false", + "page_size": 1, + "fields": "files(id)", + }, + ) + return bool(error) client = GoogleDriveClient( session=self._db_session, connector_id=connector_id, - credentials=pre_built_creds, ) await client.list_files( query="trashed = false", page_size=1, fields="files(id)" @@ -274,19 +312,39 @@ class GoogleDriveToolMetadataService: parent_folders[connector_id] = [] continue - pre_built_creds = None - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - pre_built_creds = build_composio_credentials(cca_id) + if self._is_composio_connector(connector): + data, error = await self._execute_composio_drive_tool( + connector, + "GOOGLEDRIVE_LIST_FILES", + { + "q": "mimeType = 'application/vnd.google-apps.folder' and trashed = false and 'root' in parents", + "fields": "files(id,name)", + "page_size": 50, + }, + ) + if error: + logger.warning( + "Failed to list folders for connector %s: %s", + connector_id, + error, + ) + parent_folders[connector_id] = [] + continue + folders = [] + if isinstance(data, dict): + folders = data.get("files", []) + elif isinstance(data, list): + folders = data + parent_folders[connector_id] = [ + {"folder_id": f["id"], "name": f["name"]} + for f in folders + if f.get("id") and f.get("name") + ] + continue client = GoogleDriveClient( session=self._db_session, connector_id=connector_id, - credentials=pre_built_creds, ) folders, _, error = await client.list_files( diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index c6ac3311a..5eb35f8b1 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -96,6 +96,46 @@ def _compute_turn_cancelling_retry_delay(attempt: int) -> int: return min(delay, TURN_CANCELLING_MAX_DELAY_MS) +def _first_interrupt_value(state: Any) -> dict[str, Any] | None: + """Return the first LangGraph interrupt payload across all snapshot tasks.""" + def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None: + if isinstance(candidate, dict): + value = candidate.get("value", candidate) + return value if isinstance(value, dict) else None + value = getattr(candidate, "value", None) + if isinstance(value, dict): + return value + if isinstance(candidate, (list, tuple)): + for item in candidate: + extracted = _extract_interrupt_value(item) + if extracted is not None: + return extracted + return None + + for task in getattr(state, "tasks", ()) or (): + try: + interrupts = getattr(task, "interrupts", ()) or () + except (AttributeError, IndexError, TypeError): + interrupts = () + if not interrupts: + extracted = _extract_interrupt_value(task) + if extracted is not None: + return extracted + continue + for interrupt_item in interrupts: + extracted = _extract_interrupt_value(interrupt_item) + if extracted is not None: + return extracted + try: + state_interrupts = getattr(state, "interrupts", ()) or () + except (AttributeError, IndexError, TypeError): + state_interrupts = () + extracted = _extract_interrupt_value(state_interrupts) + if extracted is not None: + return extracted + return None + + def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: """Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts. @@ -2178,10 +2218,10 @@ async def _stream_agent_events( result.agent_called_update_memory = called_update_memory _log_file_contract("turn_outcome", result) - is_interrupted = state.tasks and any(task.interrupts for task in state.tasks) - if is_interrupted: + interrupt_value = _first_interrupt_value(state) + if interrupt_value is not None: result.is_interrupted = True - result.interrupt_value = state.tasks[0].interrupts[0].value + result.interrupt_value = interrupt_value yield streaming_service.format_interrupt_request(result.interrupt_value) diff --git a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py index 6912ffe5a..3c9f27303 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py @@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import ( IndexingPipelineService, PlaceholderInfo, ) +from app.services.composio_service import ComposioService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.utils.google_credentials import ( - COMPOSIO_GOOGLE_CONNECTOR_TYPES, - build_composio_credentials, -) +from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES from .base import ( check_duplicate_document_by_hash, @@ -44,6 +42,10 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]] HEARTBEAT_INTERVAL_SECONDS = 30 +def _format_calendar_event_to_markdown(event: dict) -> str: + return GoogleCalendarConnector.format_event_to_markdown(None, event) + + def _build_connector_doc( event: dict, event_markdown: str, @@ -150,7 +152,14 @@ async def index_google_calendar_events( ) return 0, 0, f"Connector with ID {connector_id} not found" - # ── Credential building ─────────────────────────────────────── + is_composio_connector = ( + connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES + ) + calendar_client = None + composio_service = None + connected_account_id = None + + # ── Credential/client building ──────────────────────────────── if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: connected_account_id = connector.config.get("composio_connected_account_id") if not connected_account_id: @@ -161,7 +170,7 @@ async def index_google_calendar_events( {"error_type": "MissingComposioAccount"}, ) return 0, 0, "Composio connected_account_id not found" - credentials = build_composio_credentials(connected_account_id) + composio_service = ComposioService() else: config_data = connector.config @@ -229,12 +238,13 @@ async def index_google_calendar_events( {"stage": "client_initialization"}, ) - calendar_client = GoogleCalendarConnector( - credentials=credentials, - session=session, - user_id=user_id, - connector_id=connector_id, - ) + if not is_composio_connector: + calendar_client = GoogleCalendarConnector( + credentials=credentials, + session=session, + user_id=user_id, + connector_id=connector_id, + ) # Handle 'undefined' string from frontend (treat as None) if start_date == "undefined" or start_date == "": @@ -300,9 +310,26 @@ async def index_google_calendar_events( ) try: - events, error = await calendar_client.get_all_primary_calendar_events( - start_date=start_date_str, end_date=end_date_str - ) + if is_composio_connector: + start_dt = parse_date_flexible(start_date_str).replace( + hour=0, minute=0, second=0, microsecond=0 + ) + end_dt = parse_date_flexible(end_date_str).replace( + hour=23, minute=59, second=59, microsecond=0 + ) + events, error = await composio_service.get_calendar_events( + connected_account_id=connected_account_id, + entity_id=f"surfsense_{user_id}", + time_min=start_dt.isoformat(), + time_max=end_dt.isoformat(), + max_results=250, + ) + if not events and not error: + error = "No events found in the specified date range." + else: + events, error = await calendar_client.get_all_primary_calendar_events( + start_date=start_date_str, end_date=end_date_str + ) if error: if "No events found" in error: @@ -381,7 +408,7 @@ async def index_google_calendar_events( documents_skipped += 1 continue - event_markdown = calendar_client.format_event_to_markdown(event) + event_markdown = _format_calendar_event_to_markdown(event) if not event_markdown.strip(): logger.warning(f"Skipping event with no content: {event_summary}") documents_skipped += 1 diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index 21cdbd29f..686f13d9e 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -9,6 +9,8 @@ import asyncio import logging import time from collections.abc import Awaitable, Callable +from pathlib import Path +from typing import Any from sqlalchemy import String, cast, select from sqlalchemy.exc import SQLAlchemyError @@ -37,6 +39,7 @@ from app.indexing_pipeline.indexing_pipeline_service import ( IndexingPipelineService, PlaceholderInfo, ) +from app.services.composio_service import ComposioService from app.services.llm_service import get_user_long_context_llm from app.services.page_limit_service import PageLimitService from app.services.task_logging_service import TaskLoggingService @@ -45,10 +48,7 @@ from app.tasks.connector_indexers.base import ( get_connector_by_id, update_connector_last_indexed, ) -from app.utils.google_credentials import ( - COMPOSIO_GOOGLE_CONNECTOR_TYPES, - build_composio_credentials, -) +from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES ACCEPTED_DRIVE_CONNECTOR_TYPES = { SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, @@ -61,6 +61,209 @@ HEARTBEAT_INTERVAL_SECONDS = 30 logger = logging.getLogger(__name__) +class ComposioDriveClient: + """Google Drive client facade backed by Composio tool execution. + + Composio-managed OAuth connections can execute tools without exposing raw + OAuth tokens through connected account state. + """ + + def __init__( + self, + session: AsyncSession, + connector_id: int, + connected_account_id: str, + entity_id: str, + ): + self.session = session + self.connector_id = connector_id + self.connected_account_id = connected_account_id + self.entity_id = entity_id + self.composio = ComposioService() + + async def list_files( + self, + query: str = "", + fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, owners, createdTime, description)", + page_size: int = 100, + page_token: str | None = None, + ) -> tuple[list[dict[str, Any]], str | None, str | None]: + params: dict[str, Any] = { + "page_size": min(page_size, 100), + "fields": fields, + } + if query: + params["q"] = query + if page_token: + params["page_token"] = page_token + + result = await self.composio.execute_tool( + connected_account_id=self.connected_account_id, + tool_name="GOOGLEDRIVE_LIST_FILES", + params=params, + entity_id=self.entity_id, + ) + if not result.get("success"): + return [], None, result.get("error", "Unknown error") + + data = result.get("data", {}) + files = [] + next_token = None + if isinstance(data, dict): + inner_data = data.get("data", data) + if isinstance(inner_data, dict): + files = inner_data.get("files", []) + next_token = inner_data.get("nextPageToken") or inner_data.get( + "next_page_token" + ) + elif isinstance(data, list): + files = data + + return files, next_token, None + + async def get_file_metadata( + self, file_id: str, fields: str = "*" + ) -> tuple[dict[str, Any] | None, str | None]: + result = await self.composio.execute_tool( + connected_account_id=self.connected_account_id, + tool_name="GOOGLEDRIVE_GET_FILE_METADATA", + params={"file_id": file_id, "fields": fields}, + entity_id=self.entity_id, + ) + if not result.get("success"): + return None, result.get("error", "Unknown error") + + data = result.get("data", {}) + if isinstance(data, dict): + inner_data = data.get("data", data) + if isinstance(inner_data, dict): + return inner_data, None + + return None, "Could not extract metadata from Composio response" + + async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]: + return await self._download_file_content(file_id) + + async def download_file_to_disk( + self, + file_id: str, + dest_path: str, + chunksize: int = 5 * 1024 * 1024, + ) -> str | None: + del chunksize + content, error = await self.download_file(file_id) + if error: + return error + if content is None: + return "No content returned from Composio" + Path(dest_path).write_bytes(content) + return None + + async def export_google_file( + self, file_id: str, mime_type: str + ) -> tuple[bytes | None, str | None]: + return await self._download_file_content(file_id, mime_type=mime_type) + + async def _download_file_content( + self, file_id: str, mime_type: str | None = None + ) -> tuple[bytes | None, str | None]: + params: dict[str, Any] = {"file_id": file_id} + if mime_type: + params["mime_type"] = mime_type + + result = await self.composio.execute_tool( + connected_account_id=self.connected_account_id, + tool_name="GOOGLEDRIVE_DOWNLOAD_FILE", + params=params, + entity_id=self.entity_id, + ) + if not result.get("success"): + return None, result.get("error", "Unknown error") + + return self._read_download_result(result.get("data")) + + def _read_download_result(self, data: Any) -> tuple[bytes | None, str | None]: + if isinstance(data, bytes): + return data, None + + file_path: str | None = None + if isinstance(data, str): + file_path = data + elif isinstance(data, dict): + inner_data = data.get("data", data) + if isinstance(inner_data, dict): + for key in ("file_path", "downloaded_file_content", "path", "uri"): + value = inner_data.get(key) + if isinstance(value, str): + file_path = value + break + if isinstance(value, dict): + nested = ( + value.get("file_path") + or value.get("downloaded_file_content") + or value.get("path") + or value.get("uri") + or value.get("s3url") + ) + if isinstance(nested, str): + file_path = nested + break + + if not file_path: + return None, "No file path/content returned from Composio" + + if file_path.startswith(("http://", "https://")): + try: + import urllib.request + + with urllib.request.urlopen(file_path, timeout=60) as response: + return response.read(), None + except Exception as e: + return None, f"Failed to download Composio file URL: {e!s}" + + path_obj = Path(file_path) + if path_obj.is_absolute() or ".composio" in str(path_obj): + if not path_obj.exists(): + return None, f"File not found at path: {file_path}" + return path_obj.read_bytes(), None + + try: + import base64 + + return base64.b64decode(file_path), None + except Exception: + return file_path.encode("utf-8"), None + + +def _build_drive_client_for_connector( + session: AsyncSession, + connector_id: int, + connector: object, + user_id: str, +) -> tuple[GoogleDriveClient | ComposioDriveClient | None, str | None]: + if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: + connected_account_id = connector.config.get("composio_connected_account_id") + if not connected_account_id: + return None, ( + f"Composio connected_account_id not found for connector {connector_id}" + ) + return ( + ComposioDriveClient( + session, + connector_id, + connected_account_id, + entity_id=f"surfsense_{user_id}", + ), + None, + ) + + token_encrypted = connector.config.get("_token_encrypted", False) + if token_encrypted and not config.SECRET_KEY: + return None, "SECRET_KEY not configured but credentials are marked as encrypted" + + return GoogleDriveClient(session, connector_id), None + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -927,34 +1130,17 @@ async def index_google_drive_files( {"stage": "client_initialization"}, ) - pre_built_credentials = None - if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: - connected_account_id = connector.config.get("composio_connected_account_id") - if not connected_account_id: - error_msg = f"Composio connected_account_id not found for connector {connector_id}" - await task_logger.log_task_failure( - log_entry, - error_msg, - "Missing Composio account", - {"error_type": "MissingComposioAccount"}, - ) - return 0, 0, error_msg, 0 - pre_built_credentials = build_composio_credentials(connected_account_id) - else: - token_encrypted = connector.config.get("_token_encrypted", False) - if token_encrypted and not config.SECRET_KEY: - await task_logger.log_task_failure( - log_entry, - "SECRET_KEY not configured but credentials are encrypted", - "Missing SECRET_KEY", - {"error_type": "MissingSecretKey"}, - ) - return ( - 0, - 0, - "SECRET_KEY not configured but credentials are marked as encrypted", - 0, - ) + drive_client, client_error = _build_drive_client_for_connector( + session, connector_id, connector, user_id + ) + if client_error or not drive_client: + await task_logger.log_task_failure( + log_entry, + client_error or "Failed to initialize Google Drive client", + "Missing connector credentials", + {"error_type": "ClientInitializationError"}, + ) + return 0, 0, client_error, 0 connector_enable_summary = getattr(connector, "enable_summary", True) connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False) @@ -963,10 +1149,6 @@ async def index_google_drive_files( from app.services.llm_service import get_vision_llm vision_llm = await get_vision_llm(session, search_space_id) - drive_client = GoogleDriveClient( - session, connector_id, credentials=pre_built_credentials - ) - if not folder_id: error_msg = "folder_id is required for Google Drive indexing" await task_logger.log_task_failure( @@ -979,8 +1161,14 @@ async def index_google_drive_files( folder_tokens = connector.config.get("folder_tokens", {}) start_page_token = folder_tokens.get(target_folder_id) + is_composio_connector = ( + connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES + ) can_use_delta = ( - use_delta_sync and start_page_token and connector.last_indexed_at + not is_composio_connector + and use_delta_sync + and start_page_token + and connector.last_indexed_at ) documents_unsupported = 0 @@ -1051,7 +1239,16 @@ async def index_google_drive_files( ) if documents_indexed > 0 or can_use_delta: - new_token, token_error = await get_start_page_token(drive_client) + if isinstance(drive_client, ComposioDriveClient): + ( + new_token, + token_error, + ) = await drive_client.composio.get_drive_start_page_token( + drive_client.connected_account_id, + drive_client.entity_id, + ) + else: + new_token, token_error = await get_start_page_token(drive_client) if new_token and not token_error: await session.refresh(connector) if "folder_tokens" not in connector.config: @@ -1137,32 +1334,17 @@ async def index_google_drive_single_file( ) return 0, error_msg - pre_built_credentials = None - if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: - connected_account_id = connector.config.get("composio_connected_account_id") - if not connected_account_id: - error_msg = f"Composio connected_account_id not found for connector {connector_id}" - await task_logger.log_task_failure( - log_entry, - error_msg, - "Missing Composio account", - {"error_type": "MissingComposioAccount"}, - ) - return 0, error_msg - pre_built_credentials = build_composio_credentials(connected_account_id) - else: - token_encrypted = connector.config.get("_token_encrypted", False) - if token_encrypted and not config.SECRET_KEY: - await task_logger.log_task_failure( - log_entry, - "SECRET_KEY not configured but credentials are encrypted", - "Missing SECRET_KEY", - {"error_type": "MissingSecretKey"}, - ) - return ( - 0, - "SECRET_KEY not configured but credentials are marked as encrypted", - ) + drive_client, client_error = _build_drive_client_for_connector( + session, connector_id, connector, user_id + ) + if client_error or not drive_client: + await task_logger.log_task_failure( + log_entry, + client_error or "Failed to initialize Google Drive client", + "Missing connector credentials", + {"error_type": "ClientInitializationError"}, + ) + return 0, client_error connector_enable_summary = getattr(connector, "enable_summary", True) connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False) @@ -1171,10 +1353,6 @@ async def index_google_drive_single_file( from app.services.llm_service import get_vision_llm vision_llm = await get_vision_llm(session, search_space_id) - drive_client = GoogleDriveClient( - session, connector_id, credentials=pre_built_credentials - ) - file, error = await get_file_by_id(drive_client, file_id) if error or not file: error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}" @@ -1276,32 +1454,18 @@ async def index_google_drive_selected_files( ) return 0, 0, [error_msg] - pre_built_credentials = None - if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: - connected_account_id = connector.config.get("composio_connected_account_id") - if not connected_account_id: - error_msg = f"Composio connected_account_id not found for connector {connector_id}" - await task_logger.log_task_failure( - log_entry, - error_msg, - "Missing Composio account", - {"error_type": "MissingComposioAccount"}, - ) - return 0, 0, [error_msg] - pre_built_credentials = build_composio_credentials(connected_account_id) - else: - token_encrypted = connector.config.get("_token_encrypted", False) - if token_encrypted and not config.SECRET_KEY: - error_msg = ( - "SECRET_KEY not configured but credentials are marked as encrypted" - ) - await task_logger.log_task_failure( - log_entry, - error_msg, - "Missing SECRET_KEY", - {"error_type": "MissingSecretKey"}, - ) - return 0, 0, [error_msg] + drive_client, client_error = _build_drive_client_for_connector( + session, connector_id, connector, user_id + ) + if client_error or not drive_client: + error_msg = client_error or "Failed to initialize Google Drive client" + await task_logger.log_task_failure( + log_entry, + error_msg, + "Missing connector credentials", + {"error_type": "ClientInitializationError"}, + ) + return 0, 0, [error_msg] connector_enable_summary = getattr(connector, "enable_summary", True) connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False) @@ -1310,10 +1474,6 @@ async def index_google_drive_selected_files( from app.services.llm_service import get_vision_llm vision_llm = await get_vision_llm(session, search_space_id) - drive_client = GoogleDriveClient( - session, connector_id, credentials=pre_built_credentials - ) - indexed, skipped, unsupported, errors = await _index_selected_files( drive_client, session, diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py index ef226087b..6697c0eb1 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py @@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import ( IndexingPipelineService, PlaceholderInfo, ) +from app.services.composio_service import ComposioService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.utils.google_credentials import ( - COMPOSIO_GOOGLE_CONNECTOR_TYPES, - build_composio_credentials, -) +from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES from .base import ( calculate_date_range, @@ -44,6 +42,62 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]] HEARTBEAT_INTERVAL_SECONDS = 30 +def _normalize_composio_gmail_message(message: dict) -> dict: + if message.get("payload"): + return message + + headers = [] + header_values = { + "Subject": message.get("subject"), + "From": message.get("from") or message.get("sender"), + "To": message.get("to") or message.get("recipient"), + "Date": message.get("date"), + } + for name, value in header_values.items(): + if value: + headers.append({"name": name, "value": value}) + + return { + **message, + "id": message.get("id") + or message.get("message_id") + or message.get("messageId"), + "threadId": message.get("threadId") or message.get("thread_id"), + "payload": {"headers": headers}, + "snippet": message.get("snippet", ""), + "messageText": message.get("messageText") or message.get("body") or "", + } + + +def _format_gmail_message_to_markdown(message: dict) -> str: + headers = { + header.get("name", "").lower(): header.get("value", "") + for header in message.get("payload", {}).get("headers", []) + if isinstance(header, dict) + } + subject = headers.get("subject", "No Subject") + from_email = headers.get("from", "Unknown Sender") + to_email = headers.get("to", "Unknown Recipient") + date_str = headers.get("date", "Unknown Date") + message_text = ( + message.get("messageText") + or message.get("body") + or message.get("text") + or message.get("snippet", "") + ) + + return ( + f"# {subject}\n\n" + f"**From:** {from_email}\n" + f"**To:** {to_email}\n" + f"**Date:** {date_str}\n\n" + f"## Message Content\n\n{message_text}\n\n" + f"## Message Details\n\n" + f"- **Message ID:** {message.get('id', 'Unknown')}\n" + f"- **Thread ID:** {message.get('threadId', 'Unknown')}\n" + ) + + def _build_connector_doc( message: dict, markdown_content: str, @@ -162,7 +216,14 @@ async def index_google_gmail_messages( ) return 0, 0, error_msg - # ── Credential building ─────────────────────────────────────── + is_composio_connector = ( + connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES + ) + gmail_connector = None + composio_service = None + connected_account_id = None + + # ── Credential/client building ──────────────────────────────── if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: connected_account_id = connector.config.get("composio_connected_account_id") if not connected_account_id: @@ -173,7 +234,7 @@ async def index_google_gmail_messages( {"error_type": "MissingComposioAccount"}, ) return 0, 0, "Composio connected_account_id not found" - credentials = build_composio_credentials(connected_account_id) + composio_service = ComposioService() else: config_data = connector.config @@ -241,9 +302,10 @@ async def index_google_gmail_messages( {"stage": "client_initialization"}, ) - gmail_connector = GoogleGmailConnector( - credentials, session, user_id, connector_id - ) + if not is_composio_connector: + gmail_connector = GoogleGmailConnector( + credentials, session, user_id, connector_id + ) calculated_start_date, calculated_end_date = calculate_date_range( connector, start_date, end_date, default_days_back=365 @@ -254,11 +316,60 @@ async def index_google_gmail_messages( f"Fetching emails for connector {connector_id} " f"from {calculated_start_date} to {calculated_end_date}" ) - messages, error = await gmail_connector.get_recent_messages( - max_results=max_messages, - start_date=calculated_start_date, - end_date=calculated_end_date, - ) + if is_composio_connector: + query_parts = [] + if calculated_start_date: + query_parts.append(f"after:{calculated_start_date.replace('-', '/')}") + if calculated_end_date: + query_parts.append(f"before:{calculated_end_date.replace('-', '/')}") + query = " ".join(query_parts) + + messages = [] + page_token = None + error = None + while len(messages) < max_messages: + page_size = min(50, max_messages - len(messages)) + ( + page_messages, + page_token, + _estimate, + page_error, + ) = await composio_service.get_gmail_messages( + connected_account_id=connected_account_id, + entity_id=f"surfsense_{user_id}", + query=query, + max_results=page_size, + page_token=page_token, + ) + if page_error: + error = page_error + break + for page_message in page_messages: + message_id = ( + page_message.get("id") + or page_message.get("message_id") + or page_message.get("messageId") + ) + if message_id: + ( + detail, + detail_error, + ) = await composio_service.get_gmail_message_detail( + connected_account_id=connected_account_id, + entity_id=f"surfsense_{user_id}", + message_id=message_id, + ) + if not detail_error and isinstance(detail, dict): + page_message = detail + messages.append(_normalize_composio_gmail_message(page_message)) + if not page_token: + break + else: + messages, error = await gmail_connector.get_recent_messages( + max_results=max_messages, + start_date=calculated_start_date, + end_date=calculated_end_date, + ) if error: error_message = error @@ -326,7 +437,12 @@ async def index_google_gmail_messages( documents_skipped += 1 continue - markdown_content = gmail_connector.format_message_to_markdown(message) + if is_composio_connector: + markdown_content = _format_gmail_message_to_markdown(message) + else: + markdown_content = gmail_connector.format_message_to_markdown( + message + ) if not markdown_content.strip(): logger.warning(f"Skipping message with no content: {message_id}") documents_skipped += 1 diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py index 9258d5cfe..0693dfebb 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py @@ -51,22 +51,34 @@ class _FakeToolMessage: tool_call_id: str | None = None +@dataclass +class _FakeInterrupt: + value: dict[str, Any] + + +@dataclass +class _FakeTask: + interrupts: tuple[_FakeInterrupt, ...] = () + + class _FakeAgentState: """Stand-in for ``StateSnapshot`` returned by ``aget_state``.""" - def __init__(self) -> None: + def __init__(self, tasks: list[Any] | None = None) -> None: # Empty values keeps the cloud-fallback safety-net branch a no-op, - # and an empty ``tasks`` list keeps the post-stream interrupt - # check a no-op too. + # and empty ``tasks`` keep the post-stream interrupt check a no-op too. self.values: dict[str, Any] = {} - self.tasks: list[Any] = [] + self.tasks: list[Any] = tasks or [] class _FakeAgent: """Replays a list of ``astream_events`` events.""" - def __init__(self, events: list[dict[str, Any]]) -> None: + def __init__( + self, events: list[dict[str, Any]], state: _FakeAgentState | None = None + ) -> None: self._events = events + self._state = state or _FakeAgentState() async def astream_events( # type: ignore[no-untyped-def] self, _input_data: Any, *, config: dict[str, Any], version: str @@ -79,7 +91,7 @@ class _FakeAgent: # Called once after astream_events drains so the cloud-fallback # safety net can inspect staged filesystem work. The fake stays # empty so the safety net is a no-op. - return _FakeAgentState() + return self._state def _model_stream( @@ -170,11 +182,13 @@ def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None: ) -async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]: +async def _drain( + events: list[dict[str, Any]], state: _FakeAgentState | None = None +) -> list[dict[str, Any]]: """Run ``_stream_agent_events`` against a fake agent and return the SSE payloads (parsed JSON) it yielded. """ - agent = _FakeAgent(events) + agent = _FakeAgent(events, state=state) service = VercelStreamingService() result = StreamResult() config = {"configurable": {"thread_id": "test-thread"}} @@ -525,3 +539,29 @@ async def test_unmatched_fallback_still_attaches_lc_id( assert len(starts) == 1 assert starts[0]["toolCallId"].startswith("call_run-1") assert starts[0]["langchainToolCallId"] == "lc-orphan" + + +@pytest.mark.asyncio +async def test_interrupt_request_uses_task_that_contains_interrupt( + parity_v2_on: None, +) -> None: + interrupt_payload = { + "type": "calendar_event_create", + "action": { + "tool": "create_calendar_event", + "params": {"summary": "mom bday"}, + }, + "context": {}, + } + state = _FakeAgentState( + tasks=[ + _FakeTask(interrupts=()), + _FakeTask(interrupts=(_FakeInterrupt(value=interrupt_payload),)), + ] + ) + + payloads = await _drain([], state=state) + + interrupts = _of_type(payloads, "data-interrupt-request") + assert len(interrupts) == 1 + assert interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event" From bdb97a0888543ea5d5b8b3902efe1c3a808abf3f Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Sat, 2 May 2026 22:25:04 -0700 Subject: [PATCH 296/299] chore: linting --- surfsense_backend/app/tasks/chat/stream_new_chat.py | 1 + .../tests/unit/tasks/chat/test_tool_input_streaming.py | 4 +++- .../config/connector-status-config.json | 10 ++++++++++ .../connector-popup/constants/connector-constants.ts | 4 ++-- 4 files changed, 16 insertions(+), 3 deletions(-) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 5eb35f8b1..268a4401e 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -98,6 +98,7 @@ def _compute_turn_cancelling_retry_delay(attempt: int) -> int: def _first_interrupt_value(state: Any) -> dict[str, Any] | None: """Return the first LangGraph interrupt payload across all snapshot tasks.""" + def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None: if isinstance(candidate, dict): value = candidate.get("value", candidate) diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py index 0693dfebb..60750396c 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py @@ -564,4 +564,6 @@ async def test_interrupt_request_uses_task_that_contains_interrupt( interrupts = _of_type(payloads, "data-interrupt-request") assert len(interrupts) == 1 - assert interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event" + assert ( + interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event" + ) diff --git a/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json b/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json index f62758256..b4e85eab0 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json +++ b/surfsense_web/components/assistant-ui/connector-popup/config/connector-status-config.json @@ -9,6 +9,16 @@ "enabled": true, "status": "warning", "statusMessage": "Some requests may be blocked if not using Firecrawl." + }, + "JIRA_CONNECTOR": { + "enabled": false, + "status": "maintenance", + "statusMessage": "Rework in progress." + }, + "CONFLUENCE_CONNECTOR": { + "enabled": false, + "status": "maintenance", + "statusMessage": "Rework in progress." } }, "globalSettings": { diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index ae2c413cf..2f9605ea7 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -105,14 +105,14 @@ export const OAUTH_CONNECTORS = [ { id: "jira-connector", title: "Jira", - description: "Search, read, and manage issues", + description: "Rework in progress.", connectorType: EnumConnectorName.JIRA_CONNECTOR, authEndpoint: "/api/v1/auth/mcp/jira/connector/add/", }, { id: "confluence-connector", title: "Confluence", - description: "Search documentation", + description: "Rework in progress.", connectorType: EnumConnectorName.CONFLUENCE_CONNECTOR, authEndpoint: "/api/v1/auth/confluence/connector/add/", }, From c938d39277225f425cda24ea56ca50a0ed93e30a Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Sat, 2 May 2026 23:10:48 -0700 Subject: [PATCH 297/299] feat: moved most things behind correct feature flag --- docker/.env.example | 18 +++ .../app/agents/new_chat/feature_flags.py | 106 +++++++++++------- .../app/routes/agent_flags_route.py | 8 +- .../app/services/auto_model_pin_service.py | 2 +- .../agents/new_chat/test_feature_flags.py | 38 ++++--- .../services/test_auto_model_pin_service.py | 55 ++++++++- surfsense_web/app/(home)/pricing/page.tsx | 2 +- .../new-chat/[[...chat_id]]/page.tsx | 18 ++- .../components/AgentStatusContent.tsx | 13 +++ .../layout/ui/sidebar/DocumentsSidebar.tsx | 10 +- .../components/pricing/pricing-section.tsx | 37 +++--- surfsense_web/lib/agent-filesystem.ts | 13 ++- .../lib/apis/agent-flags-api.service.ts | 2 + 13 files changed, 237 insertions(+), 85 deletions(-) diff --git a/docker/.env.example b/docker/.env.example index c2e87a619..fd56bdccc 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -308,6 +308,24 @@ STT_SERVICE=local/base # Advanced (optional) # ------------------------------------------------------------------------------ +# New-chat agent feature flags +SURFSENSE_ENABLE_CONTEXT_EDITING=true +SURFSENSE_ENABLE_COMPACTION_V2=true +SURFSENSE_ENABLE_RETRY_AFTER=true +SURFSENSE_ENABLE_MODEL_FALLBACK=false +SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true +SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true +SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true +SURFSENSE_ENABLE_BUSY_MUTEX=true +SURFSENSE_ENABLE_SKILLS=true +SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=true +SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=true +SURFSENSE_ENABLE_ACTION_LOG=true +SURFSENSE_ENABLE_REVERT_ROUTE=true +SURFSENSE_ENABLE_PERMISSION=true +SURFSENSE_ENABLE_DOOM_LOOP=true +SURFSENSE_ENABLE_STREAM_PARITY_V2=true + # Periodic connector sync interval (default: 5m) # SCHEDULE_CHECKER_INTERVAL=5m diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py index f58bf0dd7..5007d89a5 100644 --- a/surfsense_backend/app/agents/new_chat/feature_flags.py +++ b/surfsense_backend/app/agents/new_chat/feature_flags.py @@ -3,8 +3,10 @@ Feature flags for the SurfSense new_chat agent stack. These flags gate the newer agent middleware (some ported from OpenCode, some sourced from ``langchain.agents.middleware`` / ``deepagents``, some -SurfSense-native). They follow a "default-OFF for risky things, -default-ON for safe upgrades, master kill-switch for everything new" model. +SurfSense-native). Most shipped agent-stack upgrades default ON so Docker +image updates work even when older installs do not have newly introduced +environment variables. Risky/experimental integrations stay default OFF, +and the master kill-switch can still disable everything new. All new middleware checks its flag at agent build time. If the master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new @@ -14,16 +16,19 @@ operators a single switch to revert to pre-port behavior. Examples -------- -Local development (recommended for trying everything except doom-loop / selector): +Defaults: SURFSENSE_ENABLE_CONTEXT_EDITING=true SURFSENSE_ENABLE_COMPACTION_V2=true SURFSENSE_ENABLE_RETRY_AFTER=true + SURFSENSE_ENABLE_MODEL_FALLBACK=false + SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true + SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true - SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy - SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships - SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false - SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events + SURFSENSE_ENABLE_PERMISSION=true + SURFSENSE_ENABLE_DOOM_LOOP=true + SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call + SURFSENSE_ENABLE_STREAM_PARITY_V2=true Master kill-switch (overrides everything else): @@ -60,32 +65,28 @@ class AgentFeatureFlags: disable_new_agent_stack: bool = False # Agent quality — context budget, retry/limits, name-repair, doom-loop - enable_context_editing: bool = False - enable_compaction_v2: bool = False - enable_retry_after: bool = False + enable_context_editing: bool = True + enable_compaction_v2: bool = True + enable_retry_after: bool = True enable_model_fallback: bool = False - enable_model_call_limit: bool = False - enable_tool_call_limit: bool = False - enable_tool_call_repair: bool = False - enable_doom_loop: bool = ( - False # Default OFF until UI handles permission='doom_loop' - ) + enable_model_call_limit: bool = True + enable_tool_call_limit: bool = True + enable_tool_call_repair: bool = True + enable_doom_loop: bool = True # Safety — permissions, concurrency, tool-set narrowing - enable_permission: bool = False # Default OFF for first deploy - enable_busy_mutex: bool = False + enable_permission: bool = True + enable_busy_mutex: bool = True enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost # Skills + subagents - enable_skills: bool = False - enable_specialized_subagents: bool = False - enable_kb_planner_runnable: bool = False + enable_skills: bool = True + enable_specialized_subagents: bool = True + enable_kb_planner_runnable: bool = True # Snapshot / revert - enable_action_log: bool = False - enable_revert_route: bool = ( - False # Backend ships before UI; route returns 503 until this flips - ) + enable_action_log: bool = True + enable_revert_route: bool = True # Streaming parity v2 — opt in to LangChain's structured # ``AIMessageChunk`` content (typed reasoning blocks, tool-input @@ -94,7 +95,7 @@ class AgentFeatureFlags: # text path and the synthetic ``call_<run_id>`` tool-call id (no # ``langchainToolCallId`` propagation). Schema migrations 135/136 # ship unconditionally because they're forward-compatible. - enable_stream_parity_v2: bool = False + enable_stream_parity_v2: bool = True # Plugins enable_plugin_loader: bool = False @@ -115,43 +116,64 @@ class AgentFeatureFlags: "SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent " "middleware is forced OFF for this build." ) - return cls(disable_new_agent_stack=True) + return cls( + disable_new_agent_stack=True, + enable_context_editing=False, + enable_compaction_v2=False, + enable_retry_after=False, + enable_model_fallback=False, + enable_model_call_limit=False, + enable_tool_call_limit=False, + enable_tool_call_repair=False, + enable_doom_loop=False, + enable_permission=False, + enable_busy_mutex=False, + enable_llm_tool_selector=False, + enable_skills=False, + enable_specialized_subagents=False, + enable_kb_planner_runnable=False, + enable_action_log=False, + enable_revert_route=False, + enable_stream_parity_v2=False, + enable_plugin_loader=False, + enable_otel=False, + ) return cls( disable_new_agent_stack=False, # Agent quality - enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False), - enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False), - enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False), + enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", True), + enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", True), + enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", True), enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False), enable_model_call_limit=_env_bool( - "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False + "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", True ), - enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False), + enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", True), enable_tool_call_repair=_env_bool( - "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False + "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", True ), - enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False), + enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", True), # Safety - enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False), - enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False), + enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", True), + enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", True), enable_llm_tool_selector=_env_bool( "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False ), # Skills + subagents - enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False), + enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", True), enable_specialized_subagents=_env_bool( - "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False + "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", True ), enable_kb_planner_runnable=_env_bool( - "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False + "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", True ), # Snapshot / revert - enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False), - enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False), + enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True), + enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True), # Streaming parity v2 enable_stream_parity_v2=_env_bool( - "SURFSENSE_ENABLE_STREAM_PARITY_V2", False + "SURFSENSE_ENABLE_STREAM_PARITY_V2", True ), # Plugins enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False), diff --git a/surfsense_backend/app/routes/agent_flags_route.py b/surfsense_backend/app/routes/agent_flags_route.py index 5732a8dfb..99388af66 100644 --- a/surfsense_backend/app/routes/agent_flags_route.py +++ b/surfsense_backend/app/routes/agent_flags_route.py @@ -23,6 +23,7 @@ from fastapi import APIRouter, Depends from pydantic import BaseModel from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags +from app.config import config from app.db import User from app.users import current_active_user @@ -58,10 +59,15 @@ class AgentFeatureFlagsRead(BaseModel): enable_otel: bool + enable_desktop_local_filesystem: bool + @classmethod def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead: # asdict() avoids missing-field bugs when AgentFeatureFlags grows. - return cls(**asdict(flags)) + return cls( + **asdict(flags), + enable_desktop_local_filesystem=config.ENABLE_DESKTOP_LOCAL_FILESYSTEM, + ) @router.get("/agent/flags", response_model=AgentFeatureFlagsRead) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 4f045ba02..185035b8a 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -399,7 +399,7 @@ async def resolve_or_get_pinned_llm_config_id( False if force_repin_free else await _is_premium_eligible(session, user_id) ) if premium_eligible: - eligible = candidates + eligible = [c for c in candidates if _tier_of(c) == "premium"] else: eligible = [c for c in candidates if _tier_of(c) != "premium"] diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py index 38a70a443..df60a4816 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py @@ -31,18 +31,38 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None: "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "SURFSENSE_ENABLE_ACTION_LOG", "SURFSENSE_ENABLE_REVERT_ROUTE", + "SURFSENSE_ENABLE_STREAM_PARITY_V2", "SURFSENSE_ENABLE_PLUGIN_LOADER", "SURFSENSE_ENABLE_OTEL", ]: monkeypatch.delenv(name, raising=False) -def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None: +def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> None: _clear_all(monkeypatch) flags = reload_for_tests() assert isinstance(flags, AgentFeatureFlags) assert flags.disable_new_agent_stack is False - assert flags.any_new_middleware_enabled() is False + assert flags.enable_context_editing is True + assert flags.enable_compaction_v2 is True + assert flags.enable_retry_after is True + assert flags.enable_model_fallback is False + assert flags.enable_model_call_limit is True + assert flags.enable_tool_call_limit is True + assert flags.enable_tool_call_repair is True + assert flags.enable_doom_loop is True + assert flags.enable_permission is True + assert flags.enable_busy_mutex is True + assert flags.enable_llm_tool_selector is False + assert flags.enable_skills is True + assert flags.enable_specialized_subagents is True + assert flags.enable_kb_planner_runnable is True + assert flags.enable_action_log is True + assert flags.enable_revert_route is True + assert flags.enable_stream_parity_v2 is True + assert flags.enable_plugin_loader is False + assert flags.enable_otel is False + assert flags.any_new_middleware_enabled() is True def test_master_kill_switch_overrides_individual_flags( @@ -100,21 +120,13 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) -> "enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG", "enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE", + "enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2", "enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER", "enable_otel": "SURFSENSE_ENABLE_OTEL", } - # `enable_otel` is intentionally orthogonal — it does NOT count toward - # ``any_new_middleware_enabled`` because OTel is observability-only and - # ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement. - counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"} - for attr, env_name in flag_to_env.items(): _clear_all(monkeypatch) - monkeypatch.setenv(env_name, "true") + monkeypatch.setenv(env_name, "false") flags = reload_for_tests() - assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}" - if attr in counts_toward_middleware: - assert flags.any_new_middleware_enabled() is True - else: - assert flags.any_new_middleware_enabled() is False + assert getattr(flags, attr) is False, f"{attr} did not flip off for {env_name}" diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index 49b3621c7..c8d6dc1ca 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -101,11 +101,58 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): user_id="00000000-0000-0000-0000-000000000001", selected_llm_config_id=0, ) - assert result.resolved_llm_config_id in {-1, -2} + assert result.resolved_llm_config_id == -1 assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id assert session.commit_count == 1 +@pytest.mark.asyncio +async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-free", + "api_key": "k1", + "billing_tier": "free", + "quality_score": 100, + }, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + "quality_score": 10, + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.resolved_tier == "premium" + + @pytest.mark.asyncio async def test_next_turn_reuses_existing_pin(monkeypatch): from app.config import config @@ -361,12 +408,12 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): ], ) - async def _allowed(*_args, **_kwargs): - return _FakeQuotaResult(allowed=True) + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) monkeypatch.setattr( "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", - _allowed, + _blocked, ) result = await resolve_or_get_pinned_llm_config_id( diff --git a/surfsense_web/app/(home)/pricing/page.tsx b/surfsense_web/app/(home)/pricing/page.tsx index 6f332be70..2a413b9a9 100644 --- a/surfsense_web/app/(home)/pricing/page.tsx +++ b/surfsense_web/app/(home)/pricing/page.tsx @@ -5,7 +5,7 @@ import { BreadcrumbNav } from "@/components/seo/breadcrumb-nav"; export const metadata: Metadata = { title: "Pricing | SurfSense - Free AI Search Plans", description: - "Explore SurfSense plans and pricing. Start free with 500 pages & $5 of premium credit. Use ChatGPT, Claude AI, and premium AI models. Pay as you go at provider cost — $1 buys $1 of credit.", + "Explore SurfSense plans and pricing. Start free with 500 pages & $5 in premium credits. Use ChatGPT, Claude AI, and premium AI models. Pay as you go at provider cost.", alternates: { canonical: "https://surfsense.com/pricing", }, diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 39201e5cc..4c8e4fe93 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -13,6 +13,7 @@ import { useParams } from "next/navigation"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { z } from "zod"; +import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms"; import { clearTargetCommentIdAtom, @@ -393,6 +394,8 @@ export default function NewChatPage() { // Get current user for author info in shared chats const { data: currentUser } = useAtomValue(currentUserAtom); + const { data: agentFlags } = useAtomValue(agentFlagsAtom); + const localFilesystemEnabled = agentFlags?.enable_desktop_local_filesystem === true; // Live collaboration: sync session state and messages via Zero useChatSessionStateSync(threadId); @@ -989,7 +992,9 @@ export default function NewChatPage() { try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - const selection = await getAgentFilesystemSelection(searchSpaceId); + const selection = await getAgentFilesystemSelection(searchSpaceId, { + localFilesystemEnabled, + }); if ( selection.filesystem_mode === "desktop_local_folder" && (!selection.local_filesystem_mounts || selection.local_filesystem_mounts.length === 0) @@ -1311,6 +1316,7 @@ export default function NewChatPage() { setAgentCreatedDocuments, queryClient, currentUser, + localFilesystemEnabled, disabledTools, updateChatTabTitle, tokenUsageStore, @@ -1413,7 +1419,9 @@ export default function NewChatPage() { try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - const selection = await getAgentFilesystemSelection(searchSpaceId); + const selection = await getAgentFilesystemSelection(searchSpaceId, { + localFilesystemEnabled, + }); const response = await fetchWithTurnCancellingRetry(() => fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { method: "POST", @@ -1561,6 +1569,7 @@ export default function NewChatPage() { pendingInterrupt, messages, searchSpaceId, + localFilesystemEnabled, queryClient, tokenUsageStore, fetchWithTurnCancellingRetry, @@ -1746,7 +1755,9 @@ export default function NewChatPage() { ? messageDocumentsMap[sourceUserMessageId] : []; try { - const selection = await getAgentFilesystemSelection(searchSpaceId); + const selection = await getAgentFilesystemSelection(searchSpaceId, { + localFilesystemEnabled, + }); const requestBody: Record<string, unknown> = { search_space_id: searchSpaceId, user_query: newUserQuery, @@ -2016,6 +2027,7 @@ export default function NewChatPage() { searchSpaceId, messages, disabledTools, + localFilesystemEnabled, messageDocumentsMap, setMessageDocumentsMap, queryClient, diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx index bd8f03a70..17d8aa50c 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx @@ -178,6 +178,19 @@ const FLAG_GROUPS: FlagGroup[] = [ }, ], }, + { + id: "desktop", + title: "Desktop", + subtitle: "Desktop-only capabilities exposed by the backend deployment.", + flags: [ + { + key: "enable_desktop_local_filesystem", + label: "Local filesystem", + description: "Allow Desktop chat sessions to operate directly on selected local folders.", + envVar: "ENABLE_DESKTOP_LOCAL_FILESYSTEM", + }, + ], + }, ]; function FlagRow({ def, value }: { def: FlagDef; value: boolean }) { diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index bf4de6454..8d59363a6 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -23,6 +23,7 @@ import { useTranslations } from "next-intl"; import type React from "react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; +import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; import { mentionedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom"; import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms"; import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; @@ -197,6 +198,7 @@ function AuthenticatedDocumentsSidebarBase({ const setConnectorDialogOpen = useSetAtom(connectorDialogOpenAtom); const setRightPanelCollapsed = useSetAtom(rightPanelCollapsedAtom); const openEditorPanel = useSetAtom(openEditorPanelAtom); + const { data: agentFlags } = useAtomValue(agentFlagsAtom); const { data: connectors } = useAtomValue(connectorsAtom); const connectorCount = connectors?.length ?? 0; @@ -209,6 +211,7 @@ function AuthenticatedDocumentsSidebarBase({ const [watchedFolderIds, setWatchedFolderIds] = useState<Set<number>>(new Set()); const [folderWatchOpen, setFolderWatchOpen] = useAtom(folderWatchDialogOpenAtom); const [watchInitialFolder, setWatchInitialFolder] = useAtom(folderWatchInitialFolderAtom); + const localFilesystemEnabled = agentFlags?.enable_desktop_local_filesystem === true; const isElectron = desktopFeaturesEnabled && typeof window !== "undefined" && !!window.electronAPI; @@ -1036,9 +1039,12 @@ function AuthenticatedDocumentsSidebarBase({ return () => document.removeEventListener("keydown", handleEscape); }, [open, onOpenChange, isMobile, setRightPanelCollapsed]); - const showFilesystemTabs = !isMobile && !!electronAPI && !!filesystemSettings; + const showFilesystemTabs = + !isMobile && !!electronAPI && !!filesystemSettings && localFilesystemEnabled; const currentFilesystemTab = - filesystemSettings?.mode === "desktop_local_folder" ? "local" : "cloud"; + localFilesystemEnabled && filesystemSettings?.mode === "desktop_local_folder" + ? "local" + : "cloud"; const showCloudSkeleton = currentFilesystemTab === "cloud" && (zeroFoldersResult.type !== "complete" || zeroAllDocsResult.type !== "complete"); diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx index 156ef9134..4ba1ecc1e 100644 --- a/surfsense_web/components/pricing/pricing-section.tsx +++ b/surfsense_web/components/pricing/pricing-section.tsx @@ -12,11 +12,11 @@ const demoPlans = [ price: "0", yearlyPrice: "0", period: "", - billingText: "500 pages + $5 of premium credit included", + billingText: "500 pages + $5 in premium credits included", features: [ "Self Hostable", "500 pages included to start", - "$5 of premium credit to start, billed at provider cost", + "$5 in premium credits for paid AI models and premium AI features", "Includes access to OpenAI text, audio and image models", "Realtime Collaborative Group Chats with teammates", "Community support on Discord", @@ -35,7 +35,7 @@ const demoPlans = [ features: [ "Everything in Free", "Buy 1,000-page packs at $1 each", - "Top up premium credit at $1 per $1 of credit, billed at provider cost", + "Top up premium credits at $1 per $1 of credit, billed at provider cost", "Use premium AI models like GPT-5.4, Claude Sonnet 4.6, Gemini 2.5 Pro & 100+ more via OpenRouter", "Priority support on Discord", ], @@ -89,7 +89,7 @@ const faqData: FAQSection[] = [ { question: "What are Basic and Premium processing modes?", answer: - "When uploading documents, you can choose between two processing modes. Basic mode uses standard extraction and costs 1 page credit per page, great for most documents. Premium mode uses advanced extraction optimized for complex financial, medical, and legal documents with intricate tables, layouts, and formatting. Premium costs 10 page credits per page but delivers significantly higher fidelity output for these specialized document types.", + "When uploading documents, you can choose between two processing modes. Basic mode uses standard extraction and costs 1 page credit per page, great for most documents. Premium processing mode uses advanced extraction optimized for complex financial, medical, and legal documents with intricate tables, layouts, and formatting. It costs 10 page credits per page and does not use your premium AI credits.", }, { question: "How does the Pay As You Go plan work?", @@ -129,27 +129,32 @@ const faqData: FAQSection[] = [ ], }, { - title: "Premium Credit", + title: "Premium Credits", items: [ { - question: 'What is "premium credit"?', + question: 'What are "premium credits"?', answer: - "Premium credit is your USD balance for using premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro in SurfSense. Each AI request debits the actual USD cost the provider charges, so cheap and expensive models bill proportionally. Non-premium models (such as the free-tier models available without login) don't touch your premium credit.", + "Premium credits are your USD balance for paid AI usage in SurfSense, including premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro, plus premium AI features such as image generation, podcasts, and video presentations when they use paid models. Each request debits the actual USD provider cost, so cheaper and more expensive models bill proportionally.", }, { - question: "How much premium credit do I get for free?", + question: "How many premium credits do I get for free?", answer: - "Every registered SurfSense account starts with $5 of premium credit at no cost. Anonymous users (no login) get 500,000 free tokens across all free models. Once your free credit runs out, you can top up at any time.", + "Every registered SurfSense account starts with $5 in premium credits at no cost. Anonymous users (no login) get 500,000 free tokens across free models before creating an account. Once your included premium credits run out, you can top up at any time.", }, { - question: "How does buying premium credit work?", + question: "How does buying premium credits work?", answer: - "Just like pages, there's no subscription. Top-ups buy $1 of credit for $1 — every cent you pay is spent at provider cost, no markup. Purchased credit is added to your account immediately. You can buy up to $100 at a time.", + "Premium credit top-ups are pay as you go, with no subscription. $1 buys $1 of credit, and your balance is spent at provider cost. Purchased credit is added to your account immediately. You can buy up to $100 at a time.", }, { - question: "What happens if I run out of premium credit?", + question: "Are premium credits the same as page credits?", answer: - "When your premium credit balance runs low (below 20%), you'll see a warning. Once you run out, premium model requests are paused until you top up. You can always switch to non-premium models, which don't touch your premium credit.", + "No. Page credits pay for document indexing and file-based connector processing. Premium credits pay for paid AI usage, such as premium model chats and premium AI generation features. Premium document processing mode sounds similar, but it consumes page credits, not premium credits.", + }, + { + question: "What happens if I run out of premium credits?", + answer: + "When your premium credit balance runs low, you'll see a warning. Once you run out, paid model requests and premium AI features are paused until you top up. You can still use non-premium models and features that do not consume premium credits.", }, ], }, @@ -159,7 +164,7 @@ const faqData: FAQSection[] = [ { question: "Can I self-host SurfSense with unlimited pages and credit?", answer: - "Yes! When self-hosting, you have full control over your page and premium-credit limits. The default self-hosted setup gives you effectively unlimited pages and premium credit, so you can index as much data and use as many AI queries as your infrastructure supports.", + "Yes! When self-hosting, you have full control over your page and premium credit limits. The default self-hosted setup gives you effectively unlimited pages and premium credits, so you can index as much data and use as many AI queries as your infrastructure supports.", }, ], }, @@ -250,7 +255,7 @@ function PricingFAQ() { Frequently Asked Questions </h2> <p className="mx-auto mt-4 max-w-2xl text-lg text-muted-foreground"> - Everything you need to know about SurfSense pages, premium credit, and billing. Can't + Everything you need to know about SurfSense pages, premium credits, and billing. Can't find what you need? Reach out at{" "} <a href="mailto:rohan@surfsense.com" className="text-blue-500 underline"> rohan@surfsense.com @@ -335,7 +340,7 @@ function PricingBasic() { <Pricing plans={demoPlans} title="SurfSense Pricing" - description="Start free with 500 pages & $5 of premium credit. Pay as you go, billed at provider cost." + description="Start free with 500 pages & $5 in premium credits. Pay as you go." /> <PricingFAQ /> </> diff --git a/surfsense_web/lib/agent-filesystem.ts b/surfsense_web/lib/agent-filesystem.ts index da5fc1b1d..5f8066d27 100644 --- a/surfsense_web/lib/agent-filesystem.ts +++ b/surfsense_web/lib/agent-filesystem.ts @@ -12,6 +12,10 @@ export interface AgentFilesystemSelection { local_filesystem_mounts?: AgentFilesystemMountSelection[]; } +export interface AgentFilesystemSelectionOptions { + localFilesystemEnabled: boolean; +} + const DEFAULT_SELECTION: AgentFilesystemSelection = { filesystem_mode: "cloud", client_platform: "web", @@ -23,10 +27,15 @@ export function getClientPlatform(): ClientPlatform { } export async function getAgentFilesystemSelection( - searchSpaceId?: number | null + searchSpaceId?: number | null, + options?: AgentFilesystemSelectionOptions ): Promise<AgentFilesystemSelection> { const platform = getClientPlatform(); - if (platform !== "desktop" || !window.electronAPI?.getAgentFilesystemSettings) { + if ( + platform !== "desktop" || + !options?.localFilesystemEnabled || + !window.electronAPI?.getAgentFilesystemSettings + ) { return { ...DEFAULT_SELECTION, client_platform: platform }; } try { diff --git a/surfsense_web/lib/apis/agent-flags-api.service.ts b/surfsense_web/lib/apis/agent-flags-api.service.ts index 87332ca9f..534810c0e 100644 --- a/surfsense_web/lib/apis/agent-flags-api.service.ts +++ b/surfsense_web/lib/apis/agent-flags-api.service.ts @@ -27,6 +27,8 @@ const AgentFeatureFlagsSchema = z.object({ enable_plugin_loader: z.boolean(), enable_otel: z.boolean(), + + enable_desktop_local_filesystem: z.boolean(), }); export type AgentFeatureFlags = z.infer<typeof AgentFeatureFlagsSchema>; From e4f9d79635d827adcc7e70279ccec9ac31482fa1 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Sat, 2 May 2026 23:35:47 -0700 Subject: [PATCH 298/299] feat: add preferred premium auto configuration logic and corresponding tests --- .../app/services/auto_model_pin_service.py | 15 ++++- .../services/test_auto_model_pin_service.py | 58 +++++++++++++++++++ .../components/pricing/pricing-section.tsx | 3 +- 3 files changed, 73 insertions(+), 3 deletions(-) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 185035b8a..9bbca8669 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -220,6 +220,15 @@ def _tier_of(cfg: dict) -> str: return str(cfg.get("billing_tier", "free")).lower() +def _is_preferred_premium_auto_config(cfg: dict) -> bool: + """Return True for the operator-preferred premium Auto model.""" + return ( + _tier_of(cfg) == "premium" + and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI" + and str(cfg.get("model_name", "")).lower() == "gpt-5.4" + ) + + def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: """Pick a config with quality-first ranking + deterministic spread. @@ -399,7 +408,11 @@ async def resolve_or_get_pinned_llm_config_id( False if force_repin_free else await _is_premium_eligible(session, user_id) ) if premium_eligible: - eligible = [c for c in candidates if _tier_of(c) == "premium"] + premium_candidates = [c for c in candidates if _tier_of(c) == "premium"] + preferred_premium = [ + c for c in premium_candidates if _is_preferred_premium_auto_config(c) + ] + eligible = preferred_premium or premium_candidates else: eligible = [c for c in candidates if _tier_of(c) != "premium"] diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index c8d6dc1ca..d1af29aeb 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -153,6 +153,64 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): assert result.resolved_tier == "premium" +@pytest.mark.asyncio +async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5.1", + "api_key": "k1", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 100, + }, + { + "id": -2, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5.4", + "api_key": "k2", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 10, + }, + { + "id": -3, + "provider": "OPENROUTER", + "model_name": "openai/gpt-5.4", + "api_key": "k3", + "billing_tier": "premium", + "auto_pin_tier": "B", + "quality_score": 100, + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.resolved_tier == "premium" + + @pytest.mark.asyncio async def test_next_turn_reuses_existing_pin(monkeypatch): from app.config import config diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx index 4ba1ecc1e..07c11b4d6 100644 --- a/surfsense_web/components/pricing/pricing-section.tsx +++ b/surfsense_web/components/pricing/pricing-section.tsx @@ -34,8 +34,7 @@ const demoPlans = [ billingText: "No subscription, buy only when you need more", features: [ "Everything in Free", - "Buy 1,000-page packs at $1 each", - "Top up premium credits at $1 per $1 of credit, billed at provider cost", + "Buy 1,000-page packs or $1 in premium credits at $1 each", "Use premium AI models like GPT-5.4, Claude Sonnet 4.6, Gemini 2.5 Pro & 100+ more via OpenRouter", "Priority support on Discord", ], From 30d06affdc7deba50fbe00a38ca2dd4ae564394d Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Sat, 2 May 2026 23:40:44 -0700 Subject: [PATCH 299/299] chore: bumped version to 0.0.20 --- VERSION | 2 +- surfsense_backend/pyproject.toml | 2 +- surfsense_backend/uv.lock | 2 +- surfsense_browser_extension/package.json | 2 +- surfsense_desktop/package.json | 2 +- surfsense_web/package.json | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/VERSION b/VERSION index 44517d518..fe04e7f67 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.19 +0.0.20 diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index cd683e2e1..b9c389734 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "surf-new-backend" -version = "0.0.19" +version = "0.0.20" description = "SurfSense Backend" requires-python = ">=3.12" dependencies = [ diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock index efe670d05..46dd0b613 100644 --- a/surfsense_backend/uv.lock +++ b/surfsense_backend/uv.lock @@ -7947,7 +7947,7 @@ wheels = [ [[package]] name = "surf-new-backend" -version = "0.0.19" +version = "0.0.20" source = { editable = "." } dependencies = [ { name = "alembic" }, diff --git a/surfsense_browser_extension/package.json b/surfsense_browser_extension/package.json index 146dd177e..1ffc4dd87 100644 --- a/surfsense_browser_extension/package.json +++ b/surfsense_browser_extension/package.json @@ -1,7 +1,7 @@ { "name": "surfsense_browser_extension", "displayName": "Surfsense Browser Extension", - "version": "0.0.19", + "version": "0.0.20", "description": "Extension to collect Browsing History for SurfSense.", "author": "https://github.com/MODSetter", "engines": { diff --git a/surfsense_desktop/package.json b/surfsense_desktop/package.json index e2712d8ea..960267e16 100644 --- a/surfsense_desktop/package.json +++ b/surfsense_desktop/package.json @@ -1,6 +1,6 @@ { "name": "surfsense-desktop", - "version": "0.0.19", + "version": "0.0.20", "description": "SurfSense Desktop App", "main": "dist/main.js", "scripts": { diff --git a/surfsense_web/package.json b/surfsense_web/package.json index 41175daeb..399544019 100644 --- a/surfsense_web/package.json +++ b/surfsense_web/package.json @@ -1,6 +1,6 @@ { "name": "surfsense_web", - "version": "0.0.19", + "version": "0.0.20", "private": true, "description": "SurfSense Frontend", "scripts": {